tofu: Refactor known hosts

This commit introduces the KnownHosts struct, whose purpose is simply to
store known hosts entries. The HostWriter struct is now in charge of
appending hosts to files, and the two are not dependent on each other.
Users are now responsible for opening the known hosts file and closing
it when they are finished with it.
This commit is contained in:
Adnan Maolood 2021-01-14 16:14:40 -05:00
parent 4b8bb16a3d
commit 9fe837ffac

View File

@ -9,64 +9,23 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
// HostsFile represents a list of known hosts optionally loaded from a file. // KnownHosts represents a list of known hosts.
// The zero value for HostsFile represents an empty list ready to use. // The zero value for KnownHosts represents an empty list ready to use.
// //
// HostsFile is safe for concurrent use by multiple goroutines. // KnownHosts is safe for concurrent use by multiple goroutines.
type HostsFile struct { type KnownHosts struct {
hosts map[string]Host hosts map[string]Host
writer *bufio.Writer mu sync.RWMutex
closer io.Closer
mu sync.RWMutex
} }
// SetOutput sets the output to which new known hosts will be written to. // Add adds a host to the list of known hosts.
func (k *HostsFile) SetOutput(w io.WriteCloser) error { func (k *KnownHosts) Add(h Host) error {
k.mu.Lock()
defer k.mu.Unlock()
if k.writer != nil {
err := k.closer.Close()
if err != nil {
return fmt.Errorf("failed to close previous output: %w", err)
}
}
k.writer = bufio.NewWriter(w)
k.closer = w
return nil
}
// Close closes the output.
func (k *HostsFile) Close() error {
k.mu.Lock()
defer k.mu.Unlock()
if k.writer == nil {
return nil
}
err := k.closer.Close()
if err != nil {
return err
}
k.writer = nil
k.closer = nil
return nil
}
// Add adds a known host to the list of known hosts.
func (k *HostsFile) Add(h Host) error {
k.mu.Lock() k.mu.Lock()
defer k.mu.Unlock() defer k.mu.Unlock()
if k.hosts == nil { if k.hosts == nil {
@ -74,77 +33,45 @@ func (k *HostsFile) Add(h Host) error {
} }
k.hosts[h.Hostname] = h k.hosts[h.Hostname] = h
if k.writer != nil {
h.WriteTo(k.writer)
k.writer.WriteByte('\n')
if err := k.writer.Flush(); err != nil {
return fmt.Errorf("failed to write to known host file: %w", err)
}
}
return nil return nil
} }
// Lookup returns the fingerprint of the certificate corresponding to // Lookup returns the known host entry corresponding to the given hostname.
// the given hostname. func (k *KnownHosts) Lookup(hostname string) (Host, bool) {
func (k *HostsFile) Lookup(hostname string) (Host, bool) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
c, ok := k.hosts[hostname] c, ok := k.hosts[hostname]
return c, ok return c, ok
} }
// WriteAll writes all of the known hosts to the provided io.Writer. // WriteTo writes the list of known hosts to the provided io.Writer.
func (k *HostsFile) WriteTo(w io.Writer) (int64, error) { func (k *KnownHosts) WriteTo(w io.Writer) (int64, error) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
var written int
bw := bufio.NewWriter(w) bw := bufio.NewWriter(w)
var written int64
for _, h := range k.hosts { for _, h := range k.hosts {
n, err := h.WriteTo(bw) n, err := bw.WriteString(h.String())
written += n
if err != nil { if err != nil {
return written, err return int64(written), err
} }
bw.WriteByte('\n') bw.WriteByte('\n')
written += n + 1 written += 1
} }
return written, bw.Flush() return int64(written), bw.Flush()
} }
// Open loads the known hosts from the provided path. // Parse parses the provided io.Reader and adds the parsed hosts to the list.
// It creates the file if it does not exist.
// New known hosts will be appended to the file.
func (k *HostsFile) Open(path string) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return err
}
err = k.Parse(f)
if err != nil {
_ = f.Close()
return err
}
err = k.SetOutput(f)
if err != nil {
_ = f.Close()
return err
}
return nil
}
// Parse parses the provided reader and adds the parsed known hosts to the list.
// Invalid entries are ignored. // Invalid entries are ignored.
func (k *HostsFile) Parse(r io.Reader) error { //
// For more control over errors encountered by parsing, scan the reader with a bufio.Scanner
// and call ParseHost with scanner.Bytes().
func (k *KnownHosts) Parse(r io.Reader) error {
k.mu.Lock() k.mu.Lock()
defer k.mu.Unlock() defer k.mu.Unlock()
@ -162,26 +89,19 @@ func (k *HostsFile) Parse(r io.Reader) error {
continue continue
} }
var h Host h, err := ParseHost(text)
err := h.UnmarshalText(text)
if err != nil { if err != nil {
return fmt.Errorf("error when parsing line %d: %w", continue
line, err)
} }
k.hosts[h.Hostname] = h k.hosts[h.Hostname] = h
} }
err := scanner.Err() return scanner.Err()
if err != nil {
return fmt.Errorf("failed to read lines: %w", err)
}
return nil
} }
func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error { // TOFU implements a basic Trust On First Use flow.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter) host := NewHost(hostname, cert.Raw, cert.NotAfter)
knownHost, ok := k.Lookup(hostname) knownHost, ok := k.Lookup(hostname)
@ -198,23 +118,30 @@ func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error {
return nil return nil
} }
type Fingerprint []byte // HostWriter writes host entries to an io.Writer.
type HostWriter struct {
func (f Fingerprint) String() string { bw *bufio.Writer
var sb strings.Builder
for i, b := range f {
if i > 0 {
sb.WriteByte(':')
}
fmt.Fprintf(&sb, "%02X", b)
}
return sb.String()
} }
// Host represents a known host entry for a fingerprint using a certain algorithm. // NewHostsWriter returns a new host writer that writes to the provided writer.
func NewHostsWriter(w io.Writer) *HostWriter {
return &HostWriter{
bw: bufio.NewWriter(w),
}
}
// WriteHost writes the host to the underlying io.Writer.
func (f *HostWriter) WriteHost(host Host) error {
f.bw.WriteString(host.String())
f.bw.WriteByte('\n')
if err := f.bw.Flush(); err != nil {
return fmt.Errorf("failed to write to hosts file: %w", err)
}
return nil
}
// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct { type Host struct {
Hostname string // hostname Hostname string // hostname
Algorithm string // fingerprint algorithm e.g. SHA-512 Algorithm string // fingerprint algorithm e.g. SHA-512
@ -222,8 +149,8 @@ type Host struct {
Expires time.Time // unix time of the fingerprint expiration date Expires time.Time // unix time of the fingerprint expiration date
} }
// NewHost returns the known host entry with a SHA-512 // NewHost returns a new host with a SHA-512 fingerprint of
// fingerprint of the provided raw data. // the provided raw data.
func NewHost(hostname string, raw []byte, expires time.Time) Host { func NewHost(hostname string, raw []byte, expires time.Time) Host {
sum := sha512.Sum512(raw) sum := sha512.Sum512(raw)
@ -235,6 +162,27 @@ func NewHost(hostname string, raw []byte, expires time.Time) Host {
} }
} }
// ParseHost parses a host from the provided text.
func ParseHost(text []byte) (Host, error) {
var h Host
err := h.UnmarshalText(text)
return h, err
}
// String returns a string representation of the host.
func (h Host) String() string {
var b strings.Builder
b.WriteString(h.Hostname)
b.WriteByte(' ')
b.WriteString(h.Algorithm)
b.WriteByte(' ')
b.WriteString(h.Fingerprint.String())
b.WriteByte(' ')
b.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
return b.String()
}
// UnmarshalText unmarshals the host from the provided text.
func (h *Host) UnmarshalText(text []byte) error { func (h *Host) UnmarshalText(text []byte) error {
const format = "hostname algorithm hex-fingerprint expiry-unix-ts" const format = "hostname algorithm hex-fingerprint expiry-unix-ts"
@ -259,11 +207,11 @@ func (h *Host) UnmarshalText(text []byte) error {
h.Algorithm = algorithm h.Algorithm = algorithm
fingerprint := make([]byte, 0, sha512.Size) fingerprint := make([]byte, 0, sha512.Size)
scan := bufio.NewScanner(bytes.NewReader(parts[2])) scanner := bufio.NewScanner(bytes.NewReader(parts[2]))
scan.Split(scanFingerprint) scanner.Split(scanFingerprint)
for scan.Scan() { for scanner.Scan() {
b, err := strconv.ParseUint(scan.Text(), 16, 8) b, err := strconv.ParseUint(scanner.Text(), 16, 8)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse fingerprint hash: %w", err) return fmt.Errorf("failed to parse fingerprint hash: %w", err)
} }
@ -288,29 +236,6 @@ func (h *Host) UnmarshalText(text []byte) error {
return nil return nil
} }
func (h *Host) WriteTo(w io.Writer) (int64, error) {
bw := bufio.NewWriter(w)
var written, n int
n, _ = bw.WriteString(h.Hostname)
bw.WriteByte(' ')
written += n + 1
n, _ = bw.WriteString(h.Algorithm)
bw.WriteByte(' ')
written += n + 1
n, _ = bw.WriteString(h.Fingerprint.String())
bw.WriteByte(' ')
written += n + 1
n, _ = bw.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
written += n
return int64(written), bw.Flush()
}
func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) { func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
@ -328,3 +253,21 @@ func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err er
// Request more data. // Request more data.
return 0, nil, nil return 0, nil, nil
} }
// Fingerprint represents a fingerprint.
type Fingerprint []byte
// String returns a string representation of the fingerprint.
func (f Fingerprint) String() string {
var sb strings.Builder
for i, b := range f {
if i > 0 {
sb.WriteByte(':')
}
fmt.Fprintf(&sb, "%02X", b)
}
return sb.String()
}