KnownHosts: use map instead of slice for underlying storage

This commit is contained in:
adnano 2020-10-13 19:54:48 -04:00
parent 92789dff19
commit 2c1081966d

87
tofu.go
View File

@ -17,7 +17,7 @@ import (
// KnownHosts represents a list of known hosts. // KnownHosts represents a list of known hosts.
// The zero value for KnownHosts is an empty list ready to use. // The zero value for KnownHosts is an empty list ready to use.
type KnownHosts struct { type KnownHosts struct {
hosts []KnownHost hosts map[string]certInfo
file *os.File file *os.File
} }
@ -61,54 +61,63 @@ func (k *KnownHosts) Load(path string) error {
// Add adds a certificate to the list of known hosts. // Add adds a certificate to the list of known hosts.
// If KnownHosts was loaded from a file, Add will append to the file. // If KnownHosts was loaded from a file, Add will append to the file.
func (k *KnownHosts) Add(hostname string, cert *x509.Certificate) { func (k *KnownHosts) Add(hostname string, cert *x509.Certificate) {
host := NewKnownHost(hostname, cert) k.add(hostname, cert, true)
k.hosts = append(k.hosts, host)
// Append to the file
if k.file != nil {
host.Write(k.file)
}
} }
// AddTemporary adds a certificate to the list of known hosts, // AddTemporary adds a certificate to the list of known hosts
// without writing it to a file. // without writing it to the known hosts file.
func (k *KnownHosts) AddTemporary(hostname string, cert *x509.Certificate) { func (k *KnownHosts) AddTemporary(hostname string, cert *x509.Certificate) {
host := NewKnownHost(hostname, cert) k.add(hostname, cert, false)
k.hosts = append(k.hosts, host) }
func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) {
if k.hosts == nil {
k.hosts = map[string]certInfo{}
}
info := certInfo{
Algorithm: "SHA-512",
Fingerprint: Fingerprint(cert),
Expires: cert.NotAfter.Unix(),
}
k.hosts[hostname] = info
// Append to the file
if write && k.file != nil {
appendKnownHost(k.file, hostname, info)
}
} }
// Lookup looks for the provided certificate in the list of known hosts. // Lookup looks for the provided certificate in the list of known hosts.
// If the hostname is in the list, but the fingerprint differs, // If the hostname is in the list, but the fingerprint differs,
// Lookup returns ErrCertificateNotTrusted. // Lookup returns ErrCertificateNotTrusted.
// If the hostname is not in the list, Lookup returns ErrUnknownCertificate. // If the hostname is not in the list, Lookup returns ErrCertificateUnknown.
// If the certificate is found and the fingerprint matches, error will be nil. // If the certificate is found and the fingerprint matches, error will be nil.
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
now := time.Now().Unix() now := time.Now().Unix()
fingerprint := Fingerprint(cert) fingerprint := Fingerprint(cert)
for i := range k.hosts { if c, ok := k.hosts[hostname]; ok {
if k.hosts[i].Hostname != hostname { if c.Expires <= now {
continue
}
if k.hosts[i].Expires <= now {
// Certificate is expired // Certificate is expired
continue return ErrCertificateUnknown
}
if k.hosts[i].Fingerprint == fingerprint {
// Fingerprint matches
return nil
} }
if c.Fingerprint != fingerprint {
// Fingerprint does not match // Fingerprint does not match
return ErrCertificateNotTrusted return ErrCertificateNotTrusted
} }
// Certificate is trusted
return nil
}
return ErrCertificateUnknown return ErrCertificateUnknown
} }
// Parse parses the provided reader and adds the parsed known hosts to the list. // Parse parses the provided reader and adds the parsed known hosts to the list.
// Invalid lines are ignored. // Invalid lines are ignored.
func (k *KnownHosts) Parse(r io.Reader) { func (k *KnownHosts) Parse(r io.Reader) {
if k.hosts == nil {
k.hosts = map[string]certInfo{}
}
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
parts := strings.Split(text, " ") parts := strings.Split(text, " ")
if len(parts) < 4 { if len(parts) < 4 {
continue continue
@ -116,50 +125,38 @@ func (k *KnownHosts) Parse(r io.Reader) {
hostname := parts[0] hostname := parts[0]
algorithm := parts[1] algorithm := parts[1]
if algorithm != "SHA-512" {
continue
}
fingerprint := parts[2] fingerprint := parts[2]
expires, err := strconv.ParseInt(parts[3], 10, 0) expires, err := strconv.ParseInt(parts[3], 10, 0)
if err != nil { if err != nil {
continue continue
} }
k.hosts = append(k.hosts, KnownHost{ k.hosts[hostname] = certInfo{
Hostname: hostname,
Algorithm: algorithm, Algorithm: algorithm,
Fingerprint: fingerprint, Fingerprint: fingerprint,
Expires: expires, Expires: expires,
}) }
} }
} }
// Write writes the known hosts to the provided io.Writer. // Write writes the known hosts to the provided io.Writer.
func (k *KnownHosts) Write(w io.Writer) { func (k *KnownHosts) Write(w io.Writer) {
for _, h := range k.hosts { for h, c := range k.hosts {
h.Write(w) appendKnownHost(w, h, c)
} }
} }
// KnownHost represents a known host. type certInfo struct {
type KnownHost struct {
Hostname string // e.g. gemini.circumlunar.space
Algorithm string // fingerprint algorithm e.g. SHA-512 Algorithm string // fingerprint algorithm e.g. SHA-512
Fingerprint string // fingerprint in hexadecimal, with ':' between each octet Fingerprint string // fingerprint in hexadecimal, with ':' between each octet
Expires int64 // unix time of certificate notAfter date Expires int64 // unix time of certificate notAfter date
} }
// NewKnownHost creates a new known host from a hostname and a certificate. func appendKnownHost(w io.Writer, hostname string, c certInfo) (int, error) {
func NewKnownHost(hostname string, cert *x509.Certificate) KnownHost { return fmt.Fprintf(w, "%s %s %s %d\n", hostname, c.Algorithm, c.Fingerprint, c.Expires)
return KnownHost{
Hostname: hostname,
Algorithm: "SHA-512",
Fingerprint: Fingerprint(cert),
Expires: cert.NotAfter.Unix(),
}
}
// Write writes the known host to the provided io.Writer.
func (k KnownHost) Write(w io.Writer) (int, error) {
s := fmt.Sprintf("%s %s %s %d\n", k.Hostname, k.Algorithm, k.Fingerprint, k.Expires)
return w.Write([]byte(s))
} }
// Fingerprint returns the SHA-512 fingerprint of the provided certificate. // Fingerprint returns the SHA-512 fingerprint of the provided certificate.