diff --git a/README.md b/README.md index bec4258..6b5988f 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ clients. Here is a simple client using TOFU to authenticate certificates: ```go client := &gemini.Client{ - KnownHosts: gemini.LoadKnownHosts(".local/share/gemini/known_hosts"), + KnownHosts: gemini.LoadKnownHosts(), TrustCertificate: func(cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { // If the certificate is in the known hosts list, allow the connection if err := knownHosts.Lookup(cert); { diff --git a/tofu.go b/tofu.go index 5c92ea2..aecde48 100644 --- a/tofu.go +++ b/tofu.go @@ -20,10 +20,22 @@ type KnownHosts struct { file *os.File } +// LoadKnownHosts loads the known hosts from the default known hosts path. +// The default path is $XDG_DATA_HOME/gemini/known_hosts +// It creates the path and any of its parent directories if they do not exist. +// The returned KnownHosts appends to the file whenever a certificate is added. +func LoadKnownHosts() (*KnownHosts, error) { + path, err := defaultKnownHostsPath() + if err != nil { + return nil, err + } + return LoadKnownHostsFrom(path) +} + // LoadKnownHosts loads the known hosts from the provided path. // It creates the path and any of its parent directories if they do not exist. // The returned KnownHosts appends to the file whenever a certificate is added. -func LoadKnownHosts(path string) (*KnownHosts, error) { +func LoadKnownHostsFrom(path string) (*KnownHosts, error) { if dir := filepath.Dir(path); dir != "." { err := os.MkdirAll(dir, 0755) if err != nil { @@ -149,3 +161,27 @@ func Fingerprint(cert *x509.Certificate) string { } return buf.String() } + +// defaultKnownHostsPath returns the default known_hosts path. +// The default path is $XDG_DATA_HOME/gemini/known_hosts +func defaultKnownHostsPath() (string, error) { + dataDir, err := userDataDir() + if err != nil { + return "", err + } + return filepath.Join(dataDir, "gemini", "known_hosts"), nil +} + +// userDataDir returns the user data directory. +func userDataDir() (string, error) { + dataDir, ok := os.LookupEnv("XDG_DATA_HOME") + if ok { + return dataDir, nil + } + + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".local", "share"), nil +}