diff --git a/tofu.go b/tofu.go index 8932b2f..eedffb5 100644 --- a/tofu.go +++ b/tofu.go @@ -17,7 +17,7 @@ import ( // KnownHosts represents a list of known hosts. // The zero value for KnownHosts is an empty list ready to use. type KnownHosts struct { - hosts []KnownHost + hosts map[string]certInfo file *os.File } @@ -61,43 +61,50 @@ func (k *KnownHosts) Load(path string) error { // Add adds a certificate to the list of known hosts. // If KnownHosts was loaded from a file, Add will append to the file. func (k *KnownHosts) Add(hostname string, cert *x509.Certificate) { - host := NewKnownHost(hostname, cert) - k.hosts = append(k.hosts, host) - // Append to the file - if k.file != nil { - host.Write(k.file) - } + k.add(hostname, cert, true) } -// AddTemporary adds a certificate to the list of known hosts, -// without writing it to a file. +// AddTemporary adds a certificate to the list of known hosts +// without writing it to the known hosts file. func (k *KnownHosts) AddTemporary(hostname string, cert *x509.Certificate) { - host := NewKnownHost(hostname, cert) - k.hosts = append(k.hosts, host) + k.add(hostname, cert, false) +} + +func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) { + if k.hosts == nil { + k.hosts = map[string]certInfo{} + } + info := certInfo{ + Algorithm: "SHA-512", + Fingerprint: Fingerprint(cert), + Expires: cert.NotAfter.Unix(), + } + k.hosts[hostname] = info + // Append to the file + if write && k.file != nil { + appendKnownHost(k.file, hostname, info) + } } // Lookup looks for the provided certificate in the list of known hosts. // If the hostname is in the list, but the fingerprint differs, // Lookup returns ErrCertificateNotTrusted. -// If the hostname is not in the list, Lookup returns ErrUnknownCertificate. +// If the hostname is not in the list, Lookup returns ErrCertificateUnknown. // If the certificate is found and the fingerprint matches, error will be nil. func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { now := time.Now().Unix() fingerprint := Fingerprint(cert) - for i := range k.hosts { - if k.hosts[i].Hostname != hostname { - continue - } - if k.hosts[i].Expires <= now { + if c, ok := k.hosts[hostname]; ok { + if c.Expires <= now { // Certificate is expired - continue + return ErrCertificateUnknown } - if k.hosts[i].Fingerprint == fingerprint { - // Fingerprint matches - return nil + if c.Fingerprint != fingerprint { + // Fingerprint does not match + return ErrCertificateNotTrusted } - // Fingerprint does not match - return ErrCertificateNotTrusted + // Certificate is trusted + return nil } return ErrCertificateUnknown } @@ -105,10 +112,12 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid lines are ignored. func (k *KnownHosts) Parse(r io.Reader) { + if k.hosts == nil { + k.hosts = map[string]certInfo{} + } scanner := bufio.NewScanner(r) for scanner.Scan() { text := scanner.Text() - parts := strings.Split(text, " ") if len(parts) < 4 { continue @@ -116,50 +125,38 @@ func (k *KnownHosts) Parse(r io.Reader) { hostname := parts[0] algorithm := parts[1] + if algorithm != "SHA-512" { + continue + } fingerprint := parts[2] expires, err := strconv.ParseInt(parts[3], 10, 0) if err != nil { continue } - k.hosts = append(k.hosts, KnownHost{ - Hostname: hostname, + k.hosts[hostname] = certInfo{ Algorithm: algorithm, Fingerprint: fingerprint, Expires: expires, - }) + } } } // Write writes the known hosts to the provided io.Writer. func (k *KnownHosts) Write(w io.Writer) { - for _, h := range k.hosts { - h.Write(w) + for h, c := range k.hosts { + appendKnownHost(w, h, c) } } -// KnownHost represents a known host. -type KnownHost struct { - Hostname string // e.g. gemini.circumlunar.space +type certInfo struct { Algorithm string // fingerprint algorithm e.g. SHA-512 Fingerprint string // fingerprint in hexadecimal, with ':' between each octet Expires int64 // unix time of certificate notAfter date } -// NewKnownHost creates a new known host from a hostname and a certificate. -func NewKnownHost(hostname string, cert *x509.Certificate) KnownHost { - return KnownHost{ - Hostname: hostname, - Algorithm: "SHA-512", - Fingerprint: Fingerprint(cert), - Expires: cert.NotAfter.Unix(), - } -} - -// Write writes the known host to the provided io.Writer. -func (k KnownHost) Write(w io.Writer) (int, error) { - s := fmt.Sprintf("%s %s %s %d\n", k.Hostname, k.Algorithm, k.Fingerprint, k.Expires) - return w.Write([]byte(s)) +func appendKnownHost(w io.Writer, hostname string, c certInfo) (int, error) { + return fmt.Fprintf(w, "%s %s %s %d\n", hostname, c.Algorithm, c.Fingerprint, c.Expires) } // Fingerprint returns the SHA-512 fingerprint of the provided certificate.