From 9fe837ffaca805d9781a44f7760632083793a20a Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Thu, 14 Jan 2021 16:14:40 -0500 Subject: [PATCH] tofu: Refactor known hosts This commit introduces the KnownHosts struct, whose purpose is simply to store known hosts entries. The HostWriter struct is now in charge of appending hosts to files, and the two are not dependent on each other. Users are now responsible for opening the known hosts file and closing it when they are finished with it. --- tofu/tofu.go | 249 ++++++++++++++++++++------------------------------- 1 file changed, 96 insertions(+), 153 deletions(-) diff --git a/tofu/tofu.go b/tofu/tofu.go index ac08d63..7dd5069 100644 --- a/tofu/tofu.go +++ b/tofu/tofu.go @@ -9,64 +9,23 @@ import ( "errors" "fmt" "io" - "os" "strconv" "strings" "sync" "time" ) -// HostsFile represents a list of known hosts optionally loaded from a file. -// The zero value for HostsFile represents an empty list ready to use. +// KnownHosts represents a list of known hosts. +// The zero value for KnownHosts represents an empty list ready to use. // -// HostsFile is safe for concurrent use by multiple goroutines. -type HostsFile struct { - hosts map[string]Host - writer *bufio.Writer - closer io.Closer - mu sync.RWMutex +// KnownHosts is safe for concurrent use by multiple goroutines. +type KnownHosts struct { + hosts map[string]Host + mu sync.RWMutex } -// SetOutput sets the output to which new known hosts will be written to. -func (k *HostsFile) SetOutput(w io.WriteCloser) error { - k.mu.Lock() - defer k.mu.Unlock() - - if k.writer != nil { - err := k.closer.Close() - if err != nil { - return fmt.Errorf("failed to close previous output: %w", err) - } - } - - k.writer = bufio.NewWriter(w) - k.closer = w - - return nil -} - -// Close closes the output. -func (k *HostsFile) Close() error { - k.mu.Lock() - defer k.mu.Unlock() - - if k.writer == nil { - return nil - } - - err := k.closer.Close() - if err != nil { - return err - } - - k.writer = nil - k.closer = nil - - return nil -} - -// Add adds a known host to the list of known hosts. -func (k *HostsFile) Add(h Host) error { +// Add adds a host to the list of known hosts. +func (k *KnownHosts) Add(h Host) error { k.mu.Lock() defer k.mu.Unlock() if k.hosts == nil { @@ -74,77 +33,45 @@ func (k *HostsFile) Add(h Host) error { } k.hosts[h.Hostname] = h - - if k.writer != nil { - h.WriteTo(k.writer) - k.writer.WriteByte('\n') - - if err := k.writer.Flush(); err != nil { - return fmt.Errorf("failed to write to known host file: %w", err) - } - } - return nil } -// Lookup returns the fingerprint of the certificate corresponding to -// the given hostname. -func (k *HostsFile) Lookup(hostname string) (Host, bool) { +// Lookup returns the known host entry corresponding to the given hostname. +func (k *KnownHosts) Lookup(hostname string) (Host, bool) { k.mu.RLock() defer k.mu.RUnlock() c, ok := k.hosts[hostname] return c, ok } -// WriteAll writes all of the known hosts to the provided io.Writer. -func (k *HostsFile) WriteTo(w io.Writer) (int64, error) { +// WriteTo writes the list of known hosts to the provided io.Writer. +func (k *KnownHosts) WriteTo(w io.Writer) (int64, error) { k.mu.RLock() defer k.mu.RUnlock() + var written int + bw := bufio.NewWriter(w) - - var written int64 - for _, h := range k.hosts { - n, err := h.WriteTo(bw) + n, err := bw.WriteString(h.String()) + written += n if err != nil { - return written, err + return int64(written), err } bw.WriteByte('\n') - written += n + 1 + written += 1 } - return written, bw.Flush() + return int64(written), bw.Flush() } -// Open loads the known hosts from the provided path. -// It creates the file if it does not exist. -// New known hosts will be appended to the file. -func (k *HostsFile) Open(path string) error { - f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644) - if err != nil { - return err - } - - err = k.Parse(f) - if err != nil { - _ = f.Close() - return err - } - - err = k.SetOutput(f) - if err != nil { - _ = f.Close() - return err - } - - return nil -} - -// Parse parses the provided reader and adds the parsed known hosts to the list. +// Parse parses the provided io.Reader and adds the parsed hosts to the list. // Invalid entries are ignored. -func (k *HostsFile) Parse(r io.Reader) error { +// +// For more control over errors encountered by parsing, scan the reader with a bufio.Scanner +// and call ParseHost with scanner.Bytes(). +func (k *KnownHosts) Parse(r io.Reader) error { k.mu.Lock() defer k.mu.Unlock() @@ -162,26 +89,19 @@ func (k *HostsFile) Parse(r io.Reader) error { continue } - var h Host - - err := h.UnmarshalText(text) + h, err := ParseHost(text) if err != nil { - return fmt.Errorf("error when parsing line %d: %w", - line, err) + continue } k.hosts[h.Hostname] = h } - err := scanner.Err() - if err != nil { - return fmt.Errorf("failed to read lines: %w", err) - } - - return nil + return scanner.Err() } -func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error { +// TOFU implements a basic Trust On First Use flow. +func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error { host := NewHost(hostname, cert.Raw, cert.NotAfter) knownHost, ok := k.Lookup(hostname) @@ -198,23 +118,30 @@ func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error { return nil } -type Fingerprint []byte - -func (f Fingerprint) String() string { - var sb strings.Builder - - for i, b := range f { - if i > 0 { - sb.WriteByte(':') - } - - fmt.Fprintf(&sb, "%02X", b) - } - - return sb.String() +// HostWriter writes host entries to an io.Writer. +type HostWriter struct { + bw *bufio.Writer } -// Host represents a known host entry for a fingerprint using a certain algorithm. +// NewHostsWriter returns a new host writer that writes to the provided writer. +func NewHostsWriter(w io.Writer) *HostWriter { + return &HostWriter{ + bw: bufio.NewWriter(w), + } +} + +// WriteHost writes the host to the underlying io.Writer. +func (f *HostWriter) WriteHost(host Host) error { + f.bw.WriteString(host.String()) + f.bw.WriteByte('\n') + + if err := f.bw.Flush(); err != nil { + return fmt.Errorf("failed to write to hosts file: %w", err) + } + return nil +} + +// Host represents a host entry with a fingerprint using a certain algorithm. type Host struct { Hostname string // hostname Algorithm string // fingerprint algorithm e.g. SHA-512 @@ -222,8 +149,8 @@ type Host struct { Expires time.Time // unix time of the fingerprint expiration date } -// NewHost returns the known host entry with a SHA-512 -// fingerprint of the provided raw data. +// NewHost returns a new host with a SHA-512 fingerprint of +// the provided raw data. func NewHost(hostname string, raw []byte, expires time.Time) Host { sum := sha512.Sum512(raw) @@ -235,6 +162,27 @@ func NewHost(hostname string, raw []byte, expires time.Time) Host { } } +// ParseHost parses a host from the provided text. +func ParseHost(text []byte) (Host, error) { + var h Host + err := h.UnmarshalText(text) + return h, err +} + +// String returns a string representation of the host. +func (h Host) String() string { + var b strings.Builder + b.WriteString(h.Hostname) + b.WriteByte(' ') + b.WriteString(h.Algorithm) + b.WriteByte(' ') + b.WriteString(h.Fingerprint.String()) + b.WriteByte(' ') + b.WriteString(strconv.FormatInt(h.Expires.Unix(), 10)) + return b.String() +} + +// UnmarshalText unmarshals the host from the provided text. func (h *Host) UnmarshalText(text []byte) error { const format = "hostname algorithm hex-fingerprint expiry-unix-ts" @@ -259,11 +207,11 @@ func (h *Host) UnmarshalText(text []byte) error { h.Algorithm = algorithm fingerprint := make([]byte, 0, sha512.Size) - scan := bufio.NewScanner(bytes.NewReader(parts[2])) - scan.Split(scanFingerprint) + scanner := bufio.NewScanner(bytes.NewReader(parts[2])) + scanner.Split(scanFingerprint) - for scan.Scan() { - b, err := strconv.ParseUint(scan.Text(), 16, 8) + for scanner.Scan() { + b, err := strconv.ParseUint(scanner.Text(), 16, 8) if err != nil { return fmt.Errorf("failed to parse fingerprint hash: %w", err) } @@ -288,29 +236,6 @@ func (h *Host) UnmarshalText(text []byte) error { return nil } -func (h *Host) WriteTo(w io.Writer) (int64, error) { - bw := bufio.NewWriter(w) - - var written, n int - - n, _ = bw.WriteString(h.Hostname) - bw.WriteByte(' ') - written += n + 1 - - n, _ = bw.WriteString(h.Algorithm) - bw.WriteByte(' ') - written += n + 1 - - n, _ = bw.WriteString(h.Fingerprint.String()) - bw.WriteByte(' ') - written += n + 1 - - n, _ = bw.WriteString(strconv.FormatInt(h.Expires.Unix(), 10)) - written += n - - return int64(written), bw.Flush() -} - func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil @@ -328,3 +253,21 @@ func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err er // Request more data. return 0, nil, nil } + +// Fingerprint represents a fingerprint. +type Fingerprint []byte + +// String returns a string representation of the fingerprint. +func (f Fingerprint) String() string { + var sb strings.Builder + + for i, b := range f { + if i > 0 { + sb.WriteByte(':') + } + + fmt.Fprintf(&sb, "%02X", b) + } + + return sb.String() +}