diff --git a/cert.go b/cert.go index 2650082..3cfcb41 100644 --- a/cert.go +++ b/cert.go @@ -18,19 +18,22 @@ import ( "time" ) -// CertificateStore maps certificate scopes to certificates. -// The zero value of CertificateStore is an empty store ready to use. -type CertificateStore struct { - store map[string]tls.Certificate - dir bool - path string +// CertificateDir maps certificate scopes to certificates. +type CertificateStore map[string]tls.Certificate + +// CertificateDir represents a certificate store optionally loaded from a directory. +// The zero value of CertificateDir is an empty store ready to use. +type CertificateDir struct { + CertificateStore + dir bool + path string } // Add adds a certificate for the given scope to the store. // It tries to parse the certificate if it is not already parsed. -func (c *CertificateStore) Add(scope string, cert tls.Certificate) { - if c.store == nil { - c.store = map[string]tls.Certificate{} +func (c *CertificateDir) Add(scope string, cert tls.Certificate) { + if c.CertificateStore == nil { + c.CertificateStore = CertificateStore{} } // Parse certificate if not already parsed if cert.Leaf == nil { @@ -39,11 +42,11 @@ func (c *CertificateStore) Add(scope string, cert tls.Certificate) { cert.Leaf = parsed } } - c.store[scope] = cert + c.CertificateStore[scope] = cert } // Write writes the provided certificate to the certificate directory. -func (c *CertificateStore) Write(scope string, cert tls.Certificate) error { +func (c *CertificateDir) Write(scope string, cert tls.Certificate) error { if c.dir { certPath := filepath.Join(c.path, scope+".crt") keyPath := filepath.Join(c.path, scope+".key") @@ -55,8 +58,8 @@ func (c *CertificateStore) Write(scope string, cert tls.Certificate) error { } // Lookup returns the certificate for the given scope. -func (c *CertificateStore) Lookup(scope string) (tls.Certificate, bool) { - cert, ok := c.store[scope] +func (c *CertificateDir) Lookup(scope string) (tls.Certificate, bool) { + cert, ok := c.CertificateStore[scope] return cert, ok } @@ -66,7 +69,7 @@ func (c *CertificateStore) Lookup(scope string) (tls.Certificate, bool) { // For example, the hostname "localhost" would have the corresponding files // localhost.crt (certificate) and localhost.key (private key). // New certificates will be written to this directory. -func (c *CertificateStore) Load(path string) error { +func (c *CertificateDir) Load(path string) error { matches, err := filepath.Glob(filepath.Join(path, "*.crt")) if err != nil { return err @@ -85,8 +88,8 @@ func (c *CertificateStore) Load(path string) error { return nil } -// SetOutput sets the directory that new certificates will be written to. -func (c *CertificateStore) SetOutput(path string) { +// SetDir sets the directory that new certificates will be written to. +func (c *CertificateDir) SetDir(path string) { c.dir = true c.path = path } diff --git a/client.go b/client.go index 29eb01f..5f8eac1 100644 --- a/client.go +++ b/client.go @@ -15,10 +15,10 @@ import ( // Client is a Gemini client. type Client struct { // KnownHosts is a list of known hosts. - KnownHosts KnownHosts + KnownHosts KnownHostsFile // Certificates stores client-side certificates. - Certificates CertificateStore + Certificates CertificateDir // Timeout specifies a time limit for requests made by this // Client. The timeout includes connection time and reading diff --git a/server.go b/server.go index 52de33a..e687aa8 100644 --- a/server.go +++ b/server.go @@ -26,7 +26,7 @@ type Server struct { WriteTimeout time.Duration // Certificates contains the certificates used by the server. - Certificates CertificateStore + Certificates CertificateDir // CreateCertificate, if not nil, will be called to create a new certificate // if the current one is expired or missing. diff --git a/tofu.go b/tofu.go index b6e3be2..de971b6 100644 --- a/tofu.go +++ b/tofu.go @@ -21,42 +21,45 @@ const ( ) // KnownHosts represents a list of known hosts. -// The zero value for KnownHosts is an empty list ready to use. -type KnownHosts struct { - hosts map[string]Fingerprint - out io.Writer +type KnownHosts map[string]Fingerprint + +// KnownHostsFile represents a list of known hosts optionally loaded from a file. +// The zero value for KnownHostsFile represents an empty list ready to use. +type KnownHostsFile struct { + KnownHosts + out io.Writer } // SetOutput sets the output to which new known hosts will be written to. -func (k *KnownHosts) SetOutput(w io.Writer) { +func (k *KnownHostsFile) SetOutput(w io.Writer) { k.out = w } // Add adds a known host to the list of known hosts. -func (k *KnownHosts) Add(hostname string, fingerprint Fingerprint) { - if k.hosts == nil { - k.hosts = map[string]Fingerprint{} +func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) { + if k.KnownHosts == nil { + k.KnownHosts = KnownHosts{} } - k.hosts[hostname] = fingerprint + k.KnownHosts[hostname] = fingerprint } // Lookup returns the fingerprint of the certificate corresponding to // the given hostname. -func (k *KnownHosts) Lookup(hostname string) (Fingerprint, bool) { - c, ok := k.hosts[hostname] +func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) { + c, ok := k.KnownHosts[hostname] return c, ok } // Write writes a known hosts entry to the configured output. -func (k *KnownHosts) Write(hostname string, fingerprint Fingerprint) { +func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) { if k.out != nil { k.writeKnownHost(k.out, hostname, fingerprint) } } // WriteAll writes all of the known hosts to the provided io.Writer. -func (k *KnownHosts) WriteAll(w io.Writer) error { - for h, c := range k.hosts { +func (k *KnownHostsFile) WriteAll(w io.Writer) error { + for h, c := range k.KnownHosts { if _, err := k.writeKnownHost(w, h, c); err != nil { return err } @@ -65,14 +68,14 @@ func (k *KnownHosts) WriteAll(w io.Writer) error { } // writeKnownHost writes a known host to the provided io.Writer. -func (k *KnownHosts) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { +func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires) } // Load loads the known hosts from the provided path. // It creates the file if it does not exist. // New known hosts will be appended to the file. -func (k *KnownHosts) Load(path string) error { +func (k *KnownHostsFile) Load(path string) error { f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0644) if err != nil { return err @@ -90,9 +93,9 @@ func (k *KnownHosts) Load(path string) error { // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid entries are ignored. -func (k *KnownHosts) Parse(r io.Reader) { - if k.hosts == nil { - k.hosts = map[string]Fingerprint{} +func (k *KnownHostsFile) Parse(r io.Reader) { + if k.KnownHosts == nil { + k.KnownHosts = map[string]Fingerprint{} } scanner := bufio.NewScanner(r) for scanner.Scan() { @@ -114,7 +117,7 @@ func (k *KnownHosts) Parse(r io.Reader) { continue } - k.hosts[hostname] = Fingerprint{ + k.KnownHosts[hostname] = Fingerprint{ Algorithm: algorithm, Hex: fingerprint, Expires: expires,