diff --git a/tofu/tofu.go b/tofu/tofu.go index ac08d63..7dd5069 100644 --- a/tofu/tofu.go +++ b/tofu/tofu.go @@ -9,64 +9,23 @@ import ( "errors" "fmt" "io" - "os" "strconv" "strings" "sync" "time" ) -// HostsFile represents a list of known hosts optionally loaded from a file. -// The zero value for HostsFile represents an empty list ready to use. +// KnownHosts represents a list of known hosts. +// The zero value for KnownHosts represents an empty list ready to use. // -// HostsFile is safe for concurrent use by multiple goroutines. -type HostsFile struct { - hosts map[string]Host - writer *bufio.Writer - closer io.Closer - mu sync.RWMutex +// KnownHosts is safe for concurrent use by multiple goroutines. +type KnownHosts struct { + hosts map[string]Host + mu sync.RWMutex } -// SetOutput sets the output to which new known hosts will be written to. -func (k *HostsFile) SetOutput(w io.WriteCloser) error { - k.mu.Lock() - defer k.mu.Unlock() - - if k.writer != nil { - err := k.closer.Close() - if err != nil { - return fmt.Errorf("failed to close previous output: %w", err) - } - } - - k.writer = bufio.NewWriter(w) - k.closer = w - - return nil -} - -// Close closes the output. -func (k *HostsFile) Close() error { - k.mu.Lock() - defer k.mu.Unlock() - - if k.writer == nil { - return nil - } - - err := k.closer.Close() - if err != nil { - return err - } - - k.writer = nil - k.closer = nil - - return nil -} - -// Add adds a known host to the list of known hosts. -func (k *HostsFile) Add(h Host) error { +// Add adds a host to the list of known hosts. +func (k *KnownHosts) Add(h Host) error { k.mu.Lock() defer k.mu.Unlock() if k.hosts == nil { @@ -74,77 +33,45 @@ func (k *HostsFile) Add(h Host) error { } k.hosts[h.Hostname] = h - - if k.writer != nil { - h.WriteTo(k.writer) - k.writer.WriteByte('\n') - - if err := k.writer.Flush(); err != nil { - return fmt.Errorf("failed to write to known host file: %w", err) - } - } - return nil } -// Lookup returns the fingerprint of the certificate corresponding to -// the given hostname. -func (k *HostsFile) Lookup(hostname string) (Host, bool) { +// Lookup returns the known host entry corresponding to the given hostname. +func (k *KnownHosts) Lookup(hostname string) (Host, bool) { k.mu.RLock() defer k.mu.RUnlock() c, ok := k.hosts[hostname] return c, ok } -// WriteAll writes all of the known hosts to the provided io.Writer. -func (k *HostsFile) WriteTo(w io.Writer) (int64, error) { +// WriteTo writes the list of known hosts to the provided io.Writer. +func (k *KnownHosts) WriteTo(w io.Writer) (int64, error) { k.mu.RLock() defer k.mu.RUnlock() + var written int + bw := bufio.NewWriter(w) - - var written int64 - for _, h := range k.hosts { - n, err := h.WriteTo(bw) + n, err := bw.WriteString(h.String()) + written += n if err != nil { - return written, err + return int64(written), err } bw.WriteByte('\n') - written += n + 1 + written += 1 } - return written, bw.Flush() + return int64(written), bw.Flush() } -// Open loads the known hosts from the provided path. -// It creates the file if it does not exist. -// New known hosts will be appended to the file. -func (k *HostsFile) Open(path string) error { - f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644) - if err != nil { - return err - } - - err = k.Parse(f) - if err != nil { - _ = f.Close() - return err - } - - err = k.SetOutput(f) - if err != nil { - _ = f.Close() - return err - } - - return nil -} - -// Parse parses the provided reader and adds the parsed known hosts to the list. +// Parse parses the provided io.Reader and adds the parsed hosts to the list. // Invalid entries are ignored. -func (k *HostsFile) Parse(r io.Reader) error { +// +// For more control over errors encountered by parsing, scan the reader with a bufio.Scanner +// and call ParseHost with scanner.Bytes(). +func (k *KnownHosts) Parse(r io.Reader) error { k.mu.Lock() defer k.mu.Unlock() @@ -162,26 +89,19 @@ func (k *HostsFile) Parse(r io.Reader) error { continue } - var h Host - - err := h.UnmarshalText(text) + h, err := ParseHost(text) if err != nil { - return fmt.Errorf("error when parsing line %d: %w", - line, err) + continue } k.hosts[h.Hostname] = h } - err := scanner.Err() - if err != nil { - return fmt.Errorf("failed to read lines: %w", err) - } - - return nil + return scanner.Err() } -func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error { +// TOFU implements a basic Trust On First Use flow. +func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error { host := NewHost(hostname, cert.Raw, cert.NotAfter) knownHost, ok := k.Lookup(hostname) @@ -198,23 +118,30 @@ func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error { return nil } -type Fingerprint []byte - -func (f Fingerprint) String() string { - var sb strings.Builder - - for i, b := range f { - if i > 0 { - sb.WriteByte(':') - } - - fmt.Fprintf(&sb, "%02X", b) - } - - return sb.String() +// HostWriter writes host entries to an io.Writer. +type HostWriter struct { + bw *bufio.Writer } -// Host represents a known host entry for a fingerprint using a certain algorithm. +// NewHostsWriter returns a new host writer that writes to the provided writer. +func NewHostsWriter(w io.Writer) *HostWriter { + return &HostWriter{ + bw: bufio.NewWriter(w), + } +} + +// WriteHost writes the host to the underlying io.Writer. +func (f *HostWriter) WriteHost(host Host) error { + f.bw.WriteString(host.String()) + f.bw.WriteByte('\n') + + if err := f.bw.Flush(); err != nil { + return fmt.Errorf("failed to write to hosts file: %w", err) + } + return nil +} + +// Host represents a host entry with a fingerprint using a certain algorithm. type Host struct { Hostname string // hostname Algorithm string // fingerprint algorithm e.g. SHA-512 @@ -222,8 +149,8 @@ type Host struct { Expires time.Time // unix time of the fingerprint expiration date } -// NewHost returns the known host entry with a SHA-512 -// fingerprint of the provided raw data. +// NewHost returns a new host with a SHA-512 fingerprint of +// the provided raw data. func NewHost(hostname string, raw []byte, expires time.Time) Host { sum := sha512.Sum512(raw) @@ -235,6 +162,27 @@ func NewHost(hostname string, raw []byte, expires time.Time) Host { } } +// ParseHost parses a host from the provided text. +func ParseHost(text []byte) (Host, error) { + var h Host + err := h.UnmarshalText(text) + return h, err +} + +// String returns a string representation of the host. +func (h Host) String() string { + var b strings.Builder + b.WriteString(h.Hostname) + b.WriteByte(' ') + b.WriteString(h.Algorithm) + b.WriteByte(' ') + b.WriteString(h.Fingerprint.String()) + b.WriteByte(' ') + b.WriteString(strconv.FormatInt(h.Expires.Unix(), 10)) + return b.String() +} + +// UnmarshalText unmarshals the host from the provided text. func (h *Host) UnmarshalText(text []byte) error { const format = "hostname algorithm hex-fingerprint expiry-unix-ts" @@ -259,11 +207,11 @@ func (h *Host) UnmarshalText(text []byte) error { h.Algorithm = algorithm fingerprint := make([]byte, 0, sha512.Size) - scan := bufio.NewScanner(bytes.NewReader(parts[2])) - scan.Split(scanFingerprint) + scanner := bufio.NewScanner(bytes.NewReader(parts[2])) + scanner.Split(scanFingerprint) - for scan.Scan() { - b, err := strconv.ParseUint(scan.Text(), 16, 8) + for scanner.Scan() { + b, err := strconv.ParseUint(scanner.Text(), 16, 8) if err != nil { return fmt.Errorf("failed to parse fingerprint hash: %w", err) } @@ -288,29 +236,6 @@ func (h *Host) UnmarshalText(text []byte) error { return nil } -func (h *Host) WriteTo(w io.Writer) (int64, error) { - bw := bufio.NewWriter(w) - - var written, n int - - n, _ = bw.WriteString(h.Hostname) - bw.WriteByte(' ') - written += n + 1 - - n, _ = bw.WriteString(h.Algorithm) - bw.WriteByte(' ') - written += n + 1 - - n, _ = bw.WriteString(h.Fingerprint.String()) - bw.WriteByte(' ') - written += n + 1 - - n, _ = bw.WriteString(strconv.FormatInt(h.Expires.Unix(), 10)) - written += n - - return int64(written), bw.Flush() -} - func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil @@ -328,3 +253,21 @@ func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err er // Request more data. return 0, nil, nil } + +// Fingerprint represents a fingerprint. +type Fingerprint []byte + +// String returns a string representation of the fingerprint. +func (f Fingerprint) String() string { + var sb strings.Builder + + for i, b := range f { + if i > 0 { + sb.WriteByte(':') + } + + fmt.Fprintf(&sb, "%02X", b) + } + + return sb.String() +}