diff --git a/tofu/tofu.go b/tofu/tofu.go index bc4fb9d..ac08d63 100644 --- a/tofu/tofu.go +++ b/tofu/tofu.go @@ -16,41 +16,41 @@ import ( "time" ) -// KnownHostsFile represents a list of known hosts optionally loaded from a file. -// The zero value for KnownHostsFile represents an empty list ready to use. +// HostsFile represents a list of known hosts optionally loaded from a file. +// The zero value for HostsFile represents an empty list ready to use. // -// KnownHostsFile is safe for concurrent use by multiple goroutines. -type KnownHostsFile struct { - hosts map[string]KnownHost - out *bufio.Writer +// HostsFile is safe for concurrent use by multiple goroutines. +type HostsFile struct { + hosts map[string]Host + writer *bufio.Writer closer io.Closer mu sync.RWMutex } // SetOutput sets the output to which new known hosts will be written to. -func (k *KnownHostsFile) SetOutput(w io.WriteCloser) error { +func (k *HostsFile) SetOutput(w io.WriteCloser) error { k.mu.Lock() defer k.mu.Unlock() - if k.out != nil { + if k.writer != nil { err := k.closer.Close() if err != nil { return fmt.Errorf("failed to close previous output: %w", err) } } - k.out = bufio.NewWriter(w) + k.writer = bufio.NewWriter(w) k.closer = w return nil } // Close closes the output. -func (k *KnownHostsFile) Close() error { +func (k *HostsFile) Close() error { k.mu.Lock() defer k.mu.Unlock() - if k.out == nil { + if k.writer == nil { return nil } @@ -59,27 +59,27 @@ func (k *KnownHostsFile) Close() error { return err } - k.out = nil + k.writer = nil k.closer = nil return nil } // Add adds a known host to the list of known hosts. -func (k *KnownHostsFile) Add(h KnownHost) error { +func (k *HostsFile) Add(h Host) error { k.mu.Lock() defer k.mu.Unlock() if k.hosts == nil { - k.hosts = map[string]KnownHost{} + k.hosts = map[string]Host{} } k.hosts[h.Hostname] = h - if k.out != nil { - h.WriteTo(k.out) - k.out.WriteRune('\n') + if k.writer != nil { + h.WriteTo(k.writer) + k.writer.WriteByte('\n') - if err := k.out.Flush(); err != nil { + if err := k.writer.Flush(); err != nil { return fmt.Errorf("failed to write to known host file: %w", err) } } @@ -89,7 +89,7 @@ func (k *KnownHostsFile) Add(h KnownHost) error { // Lookup returns the fingerprint of the certificate corresponding to // the given hostname. -func (k *KnownHostsFile) Lookup(hostname string) (KnownHost, bool) { +func (k *HostsFile) Lookup(hostname string) (Host, bool) { k.mu.RLock() defer k.mu.RUnlock() c, ok := k.hosts[hostname] @@ -97,7 +97,7 @@ func (k *KnownHostsFile) Lookup(hostname string) (KnownHost, bool) { } // WriteAll writes all of the known hosts to the provided io.Writer. -func (k *KnownHostsFile) WriteTo(w io.Writer) (int64, error) { +func (k *HostsFile) WriteTo(w io.Writer) (int64, error) { k.mu.RLock() defer k.mu.RUnlock() @@ -121,7 +121,7 @@ func (k *KnownHostsFile) WriteTo(w io.Writer) (int64, error) { // Open loads the known hosts from the provided path. // It creates the file if it does not exist. // New known hosts will be appended to the file. -func (k *KnownHostsFile) Open(path string) error { +func (k *HostsFile) Open(path string) error { f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644) if err != nil { return err @@ -144,12 +144,12 @@ func (k *KnownHostsFile) Open(path string) error { // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid entries are ignored. -func (k *KnownHostsFile) Parse(r io.Reader) error { +func (k *HostsFile) Parse(r io.Reader) error { k.mu.Lock() defer k.mu.Unlock() if k.hosts == nil { - k.hosts = map[string]KnownHost{} + k.hosts = map[string]Host{} } scanner := bufio.NewScanner(r) @@ -162,7 +162,7 @@ func (k *KnownHostsFile) Parse(r io.Reader) error { continue } - var h KnownHost + var h Host err := h.UnmarshalText(text) if err != nil { @@ -181,8 +181,8 @@ func (k *KnownHostsFile) Parse(r io.Reader) error { return nil } -func (k *KnownHostsFile) TOFU(hostname string, cert *x509.Certificate) error { - host := NewKnownHost(hostname, cert.Raw, cert.NotAfter) +func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error { + host := NewHost(hostname, cert.Raw, cert.NotAfter) knownHost, ok := k.Lookup(hostname) if !ok || time.Now().After(knownHost.Expires) { @@ -214,15 +214,28 @@ func (f Fingerprint) String() string { return sb.String() } -// KnownHost represents a known host entry for a fingerprint using a certain algorithm. -type KnownHost struct { +// Host represents a known host entry for a fingerprint using a certain algorithm. +type Host struct { Hostname string // hostname Algorithm string // fingerprint algorithm e.g. SHA-512 Fingerprint Fingerprint // fingerprint Expires time.Time // unix time of the fingerprint expiration date } -func (f *KnownHost) UnmarshalText(text []byte) error { +// NewHost returns the known host entry with a SHA-512 +// fingerprint of the provided raw data. +func NewHost(hostname string, raw []byte, expires time.Time) Host { + sum := sha512.Sum512(raw) + + return Host{ + Hostname: hostname, + Algorithm: "SHA-512", + Fingerprint: sum[:], + Expires: expires, + } +} + +func (h *Host) UnmarshalText(text []byte) error { const format = "hostname algorithm hex-fingerprint expiry-unix-ts" parts := bytes.Split(text, []byte(" ")) @@ -235,7 +248,7 @@ func (f *KnownHost) UnmarshalText(text []byte) error { return errors.New("empty hostname") } - f.Hostname = string(parts[0]) + h.Hostname = string(parts[0]) algorithm := string(parts[1]) if algorithm != "SHA-512" { @@ -243,7 +256,7 @@ func (f *KnownHost) UnmarshalText(text []byte) error { "unsupported algorithm %q", algorithm) } - f.Algorithm = algorithm + h.Algorithm = algorithm fingerprint := make([]byte, 0, sha512.Size) scan := bufio.NewScanner(bytes.NewReader(parts[2])) @@ -262,7 +275,7 @@ func (f *KnownHost) UnmarshalText(text []byte) error { len(fingerprint), sha512.Size) } - f.Fingerprint = fingerprint + h.Fingerprint = fingerprint unix, err := strconv.ParseInt(string(parts[3]), 10, 0) if err != nil { @@ -270,12 +283,12 @@ func (f *KnownHost) UnmarshalText(text []byte) error { "invalid unix timestamp: %w", err) } - f.Expires = time.Unix(unix, 0) + h.Expires = time.Unix(unix, 0) return nil } -func (h *KnownHost) WriteTo(w io.Writer) (int64, error) { +func (h *Host) WriteTo(w io.Writer) (int64, error) { bw := bufio.NewWriter(w) var written, n int @@ -315,16 +328,3 @@ func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err er // Request more data. return 0, nil, nil } - -// NewKnownHost returns the known host entry with a SHA-512 -// fingerprint of the provided raw data. -func NewKnownHost(hostname string, raw []byte, expires time.Time) KnownHost { - sum := sha512.Sum512(raw) - - return KnownHost{ - Hostname: hostname, - Algorithm: "SHA-512", - Fingerprint: sum[:], - Expires: expires, - } -}