diff --git a/client.go b/client.go index 57c144a..d1bd993 100644 --- a/client.go +++ b/client.go @@ -168,7 +168,7 @@ func (resp *Response) read(r *bufio.Reader) error { // Client represents a Gemini client. type Client struct { // KnownHosts is a list of known hosts that the client trusts. - KnownHosts *KnownHosts + KnownHosts KnownHosts // CertificateStore contains all the certificates that the client has stored. CertificateStore *CertificateStore @@ -210,12 +210,10 @@ func (c *Client) Send(req *Request) (*Response, error) { } // Check that the client trusts the certificate if c.TrustCertificate == nil { - if c.KnownHosts == nil { - return ErrCertificateNotTrusted - } else if err := c.KnownHosts.Lookup(cert); err != nil { + if err := c.KnownHosts.Lookup(cert); err != nil { return err } - } else if err := c.TrustCertificate(cert, c.KnownHosts); err != nil { + } else if err := c.TrustCertificate(cert, &c.KnownHosts); err != nil { return err } return nil diff --git a/examples/client/client.go b/examples/client/client.go index 69e68e6..f61e3b3 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -19,15 +19,8 @@ var ( ) func init() { - // Load the list of known hosts - knownHosts, err := gemini.LoadKnownHosts() - if err != nil { - log.Fatal(err) - } - - client = &gemini.Client{ - KnownHosts: knownHosts, - } + client = &gemini.Client{} + client.KnownHosts.Load() client.TrustCertificate = func(cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { err := knownHosts.Lookup(cert) @@ -53,12 +46,10 @@ func init() { } // Configure a client side certificate. - // To generate a certificate, run: - // - // openssl genrsa -out client.key 2048 - // openssl ecparam -genkey -name secp384r1 -out client.key - // openssl req -new -x509 -sha512 -key client.key -out client.crt -days 365 + // To generate a TLS key pair, run: // + // go run -tags=example ../cert + var cert tls.Certificate cert, err = tls.LoadX509KeyPair("examples/client/localhost.crt", "examples/client/localhost.key") if err != nil { log.Fatal(err) diff --git a/examples/server/server.go b/examples/server/server.go index 62608c7..bda87e6 100644 --- a/examples/server/server.go +++ b/examples/server/server.go @@ -13,10 +13,7 @@ func main() { // Load a TLS key pair. // To generate a TLS key pair, run: // - // openssl genrsa -out server.key 2048 - // openssl ecparam -genkey -name secp384r1 -out server.key - // openssl req -new -x509 -sha512 -key server.key -out server.crt -days 365 - // + // go run -tags=example ../cert cert, err := tls.LoadX509KeyPair("examples/server/localhost.crt", "examples/server/localhost.key") if err != nil { log.Fatal(err) diff --git a/gemini.go b/gemini.go index 409069c..99ae901 100644 --- a/gemini.go +++ b/gemini.go @@ -60,11 +60,7 @@ func init() { var setupDefaultClientOnce sync.Once func setupDefaultClient() { - knownHosts, err := LoadKnownHosts() - if err != nil { - knownHosts = &KnownHosts{} - } - DefaultClient.KnownHosts = knownHosts + DefaultClient.KnownHosts.Load() } // Send sends a Gemini request and returns a Gemini response. diff --git a/tofu.go b/tofu.go index aecde48..8052a35 100644 --- a/tofu.go +++ b/tofu.go @@ -15,50 +15,50 @@ import ( ) // KnownHosts represents a list of known hosts. +// The zero value for KnownHosts is an empty list ready to use. type KnownHosts struct { hosts []KnownHost file *os.File } -// LoadKnownHosts loads the known hosts from the default known hosts path. -// The default path is $XDG_DATA_HOME/gemini/known_hosts +// Load loads the known hosts from the default known hosts path, which is +// `$XDG_DATA_HOME/gemini/known_hosts`. // It creates the path and any of its parent directories if they do not exist. -// The returned KnownHosts appends to the file whenever a certificate is added. -func LoadKnownHosts() (*KnownHosts, error) { +// KnownHosts will append to the file whenever a certificate is added. +func (k *KnownHosts) Load() error { path, err := defaultKnownHostsPath() if err != nil { - return nil, err + return err } - return LoadKnownHostsFrom(path) + return k.LoadFrom(path) } -// LoadKnownHosts loads the known hosts from the provided path. +// LoadFrom loads the known hosts from the provided path. // It creates the path and any of its parent directories if they do not exist. -// The returned KnownHosts appends to the file whenever a certificate is added. -func LoadKnownHostsFrom(path string) (*KnownHosts, error) { +// KnownHosts will append to the file whenever a certificate is added. +func (k *KnownHosts) LoadFrom(path string) error { if dir := filepath.Dir(path); dir != "." { err := os.MkdirAll(dir, 0755) if err != nil { - return nil, err + return err } } f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644) if err != nil { - return nil, err + return err } - k := &KnownHosts{} k.Parse(f) f.Close() // Open the file for append-only use f, err = os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { - return nil, err + return err } k.file = f - return k, nil + return nil } -// Add adds a certificate to the KnownHosts. +// Add adds a certificate to the list of known hosts. // If KnownHosts was loaded from a file, Add will append to the file. func (k *KnownHosts) Add(cert *x509.Certificate) { host := NewKnownHost(cert) @@ -125,6 +125,13 @@ func (k *KnownHosts) Parse(r io.Reader) { } } +// Write writes the known hosts to the provided io.Writer. +func (k *KnownHosts) Write(w io.Writer) { + for _, h := range k.hosts { + h.Write(w) + } +} + // KnownHost represents a known host. type KnownHost struct { Hostname string // e.g. gemini.circumlunar.space