From 62960266acb31027f15881d0818430ae293cd0db Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Mon, 25 Jan 2021 12:02:09 -0500 Subject: [PATCH] tofu: Implement PersistentHosts --- client.go | 9 ++--- tofu/tofu.go | 95 ++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 87 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index ba84245..a45c659 100644 --- a/client.go +++ b/client.go @@ -20,8 +20,7 @@ type Client struct { // If the returned error is not nil, the certificate will not be trusted // and the request will be aborted. // - // For a basic trust on first use implementation, see (*KnownHosts).TOFU - // in the tofu submodule. + // See the tofu submodule for an implementation of trust on first use. TrustCertificate func(hostname string, cert *x509.Certificate) error // Timeout specifies a time limit for requests made by this @@ -88,8 +87,7 @@ func (c *Client) Do(req *Request) (*Response, error) { if c.Timeout != 0 { err := conn.SetDeadline(start.Add(c.Timeout)) if err != nil { - return nil, fmt.Errorf( - "failed to set connection deadline: %w", err) + return nil, fmt.Errorf("failed to set connection deadline: %w", err) } } @@ -114,8 +112,7 @@ func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) { err := req.Write(w) if err != nil { - return nil, fmt.Errorf( - "failed to write request data: %w", err) + return nil, fmt.Errorf("failed to write request: %w", err) } if err := w.Flush(); err != nil { diff --git a/tofu/tofu.go b/tofu/tofu.go index 2ea8ac8..a928be6 100644 --- a/tofu/tofu.go +++ b/tofu/tofu.go @@ -27,7 +27,7 @@ type KnownHosts struct { } // Add adds a host to the list of known hosts. -func (k *KnownHosts) Add(h Host) error { +func (k *KnownHosts) Add(h Host) { k.mu.Lock() defer k.mu.Unlock() if k.hosts == nil { @@ -35,7 +35,6 @@ func (k *KnownHosts) Add(h Host) error { } k.hosts[h.Hostname] = h - return nil } // Lookup returns the known host entry corresponding to the given hostname. @@ -144,7 +143,7 @@ func (k *KnownHosts) Parse(r io.Reader) error { // TOFU implements basic trust on first use. // // If the host is not on file, it is added to the list. -// If the host on file is expired, it is replaced with the provided host. +// If the host on file is expired, a new entry is added to the list. // If the fingerprint does not match the one on file, an error is returned. func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error { host := NewHost(hostname, cert.Raw, cert.NotAfter) @@ -181,9 +180,9 @@ func NewHostWriter(w io.WriteCloser) *HostWriter { } } -// NewHostsFile returns a new host writer that appends to the file at the given path. +// OpenHostsFile returns a new host writer that appends to the file at the given path. // The file is created if it does not exist. -func NewHostsFile(path string) (*HostWriter, error) { +func OpenHostsFile(path string) (*HostWriter, error) { f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return nil, err @@ -212,6 +211,83 @@ func (h *HostWriter) Close() error { return h.cl.Close() } +// PersistentHosts represents a persistent set of known hosts. +type PersistentHosts struct { + hosts *KnownHosts + writer *HostWriter +} + +// NewPersistentHosts returns a new persistent set of known hosts. +func NewPersistentHosts(hosts *KnownHosts, writer *HostWriter) *PersistentHosts { + return &PersistentHosts{ + hosts, + writer, + } +} + +// LoadPersistentHosts loads persistent hosts from the file at the given path. +func LoadPersistentHosts(path string) (*PersistentHosts, error) { + hosts := &KnownHosts{} + if err := hosts.Load(path); err != nil { + return nil, err + } + writer, err := OpenHostsFile(path) + if err != nil { + return nil, err + } + return &PersistentHosts{ + hosts, + writer, + }, nil +} + +// Add adds a host to the list of known hosts. +// It returns an error if the host could not be persisted. +func (p *PersistentHosts) Add(h Host) error { + err := p.writer.WriteHost(h) + if err != nil { + return fmt.Errorf("failed to persist host: %w", err) + } + p.hosts.Add(h) + return nil +} + +// Lookup returns the known host entry corresponding to the given hostname. +func (p *PersistentHosts) Lookup(hostname string) (Host, bool) { + return p.hosts.Lookup(hostname) +} + +// Entries returns the known host entries sorted by hostname. +func (p *PersistentHosts) Entries() []Host { + return p.hosts.Entries() +} + +// TOFU implements trust on first use with a persistent set of known hosts. +// +// If the host is not on file, it is added to the list. +// If the host on file is expired, a new entry is added to the list. +// If the fingerprint does not match the one on file, an error is returned. +func (p *PersistentHosts) TOFU(hostname string, cert *x509.Certificate) error { + host := NewHost(hostname, cert.Raw, cert.NotAfter) + + knownHost, ok := p.Lookup(hostname) + if !ok || time.Now().After(knownHost.Expires) { + return p.Add(host) + } + + // Check fingerprint + if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) { + return fmt.Errorf("fingerprint for %q does not match", hostname) + } + + return nil +} + +// Close closes the underlying HostWriter. +func (p *PersistentHosts) Close() error { + return p.writer.Close() +} + // Host represents a host entry with a fingerprint using a certain algorithm. type Host struct { Hostname string // hostname @@ -259,8 +335,7 @@ func (h *Host) UnmarshalText(text []byte) error { parts := bytes.Split(text, []byte(" ")) if len(parts) != 4 { - return fmt.Errorf( - "expected the format %q", format) + return fmt.Errorf("expected the format %q", format) } if len(parts[0]) == 0 { @@ -271,8 +346,7 @@ func (h *Host) UnmarshalText(text []byte) error { algorithm := string(parts[1]) if algorithm != "SHA-512" { - return fmt.Errorf( - "unsupported algorithm %q", algorithm) + return fmt.Errorf("unsupported algorithm %q", algorithm) } h.Algorithm = algorithm @@ -298,8 +372,7 @@ func (h *Host) UnmarshalText(text []byte) error { unix, err := strconv.ParseInt(string(parts[3]), 10, 0) if err != nil { - return fmt.Errorf( - "invalid unix timestamp: %w", err) + return fmt.Errorf("invalid unix timestamp: %w", err) } h.Expires = time.Unix(unix, 0)