tofu: Refactor known hosts

This commit introduces the KnownHosts struct, whose purpose is simply to
store known hosts entries. The HostWriter struct is now in charge of
appending hosts to files, and the two are not dependent on each other.
Users are now responsible for opening the known hosts file and closing
it when they are finished with it.
This commit is contained in:
Adnan Maolood 2021-01-14 16:14:40 -05:00
parent 4b8bb16a3d
commit 9fe837ffac

View File

@ -9,64 +9,23 @@ import (
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"sync"
"time"
)
// HostsFile represents a list of known hosts optionally loaded from a file.
// The zero value for HostsFile represents an empty list ready to use.
// KnownHosts represents a list of known hosts.
// The zero value for KnownHosts represents an empty list ready to use.
//
// HostsFile is safe for concurrent use by multiple goroutines.
type HostsFile struct {
hosts map[string]Host
writer *bufio.Writer
closer io.Closer
mu sync.RWMutex
// KnownHosts is safe for concurrent use by multiple goroutines.
type KnownHosts struct {
hosts map[string]Host
mu sync.RWMutex
}
// SetOutput sets the output to which new known hosts will be written to.
func (k *HostsFile) SetOutput(w io.WriteCloser) error {
k.mu.Lock()
defer k.mu.Unlock()
if k.writer != nil {
err := k.closer.Close()
if err != nil {
return fmt.Errorf("failed to close previous output: %w", err)
}
}
k.writer = bufio.NewWriter(w)
k.closer = w
return nil
}
// Close closes the output.
func (k *HostsFile) Close() error {
k.mu.Lock()
defer k.mu.Unlock()
if k.writer == nil {
return nil
}
err := k.closer.Close()
if err != nil {
return err
}
k.writer = nil
k.closer = nil
return nil
}
// Add adds a known host to the list of known hosts.
func (k *HostsFile) Add(h Host) error {
// Add adds a host to the list of known hosts.
func (k *KnownHosts) Add(h Host) error {
k.mu.Lock()
defer k.mu.Unlock()
if k.hosts == nil {
@ -74,77 +33,45 @@ func (k *HostsFile) Add(h Host) error {
}
k.hosts[h.Hostname] = h
if k.writer != nil {
h.WriteTo(k.writer)
k.writer.WriteByte('\n')
if err := k.writer.Flush(); err != nil {
return fmt.Errorf("failed to write to known host file: %w", err)
}
}
return nil
}
// Lookup returns the fingerprint of the certificate corresponding to
// the given hostname.
func (k *HostsFile) Lookup(hostname string) (Host, bool) {
// Lookup returns the known host entry corresponding to the given hostname.
func (k *KnownHosts) Lookup(hostname string) (Host, bool) {
k.mu.RLock()
defer k.mu.RUnlock()
c, ok := k.hosts[hostname]
return c, ok
}
// WriteAll writes all of the known hosts to the provided io.Writer.
func (k *HostsFile) WriteTo(w io.Writer) (int64, error) {
// WriteTo writes the list of known hosts to the provided io.Writer.
func (k *KnownHosts) WriteTo(w io.Writer) (int64, error) {
k.mu.RLock()
defer k.mu.RUnlock()
var written int
bw := bufio.NewWriter(w)
var written int64
for _, h := range k.hosts {
n, err := h.WriteTo(bw)
n, err := bw.WriteString(h.String())
written += n
if err != nil {
return written, err
return int64(written), err
}
bw.WriteByte('\n')
written += n + 1
written += 1
}
return written, bw.Flush()
return int64(written), bw.Flush()
}
// Open loads the known hosts from the provided path.
// It creates the file if it does not exist.
// New known hosts will be appended to the file.
func (k *HostsFile) Open(path string) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return err
}
err = k.Parse(f)
if err != nil {
_ = f.Close()
return err
}
err = k.SetOutput(f)
if err != nil {
_ = f.Close()
return err
}
return nil
}
// Parse parses the provided reader and adds the parsed known hosts to the list.
// Parse parses the provided io.Reader and adds the parsed hosts to the list.
// Invalid entries are ignored.
func (k *HostsFile) Parse(r io.Reader) error {
//
// For more control over errors encountered by parsing, scan the reader with a bufio.Scanner
// and call ParseHost with scanner.Bytes().
func (k *KnownHosts) Parse(r io.Reader) error {
k.mu.Lock()
defer k.mu.Unlock()
@ -162,26 +89,19 @@ func (k *HostsFile) Parse(r io.Reader) error {
continue
}
var h Host
err := h.UnmarshalText(text)
h, err := ParseHost(text)
if err != nil {
return fmt.Errorf("error when parsing line %d: %w",
line, err)
continue
}
k.hosts[h.Hostname] = h
}
err := scanner.Err()
if err != nil {
return fmt.Errorf("failed to read lines: %w", err)
}
return nil
return scanner.Err()
}
func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error {
// TOFU implements a basic Trust On First Use flow.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter)
knownHost, ok := k.Lookup(hostname)
@ -198,23 +118,30 @@ func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error {
return nil
}
type Fingerprint []byte
func (f Fingerprint) String() string {
var sb strings.Builder
for i, b := range f {
if i > 0 {
sb.WriteByte(':')
}
fmt.Fprintf(&sb, "%02X", b)
}
return sb.String()
// HostWriter writes host entries to an io.Writer.
type HostWriter struct {
bw *bufio.Writer
}
// Host represents a known host entry for a fingerprint using a certain algorithm.
// NewHostsWriter returns a new host writer that writes to the provided writer.
func NewHostsWriter(w io.Writer) *HostWriter {
return &HostWriter{
bw: bufio.NewWriter(w),
}
}
// WriteHost writes the host to the underlying io.Writer.
func (f *HostWriter) WriteHost(host Host) error {
f.bw.WriteString(host.String())
f.bw.WriteByte('\n')
if err := f.bw.Flush(); err != nil {
return fmt.Errorf("failed to write to hosts file: %w", err)
}
return nil
}
// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct {
Hostname string // hostname
Algorithm string // fingerprint algorithm e.g. SHA-512
@ -222,8 +149,8 @@ type Host struct {
Expires time.Time // unix time of the fingerprint expiration date
}
// NewHost returns the known host entry with a SHA-512
// fingerprint of the provided raw data.
// NewHost returns a new host with a SHA-512 fingerprint of
// the provided raw data.
func NewHost(hostname string, raw []byte, expires time.Time) Host {
sum := sha512.Sum512(raw)
@ -235,6 +162,27 @@ func NewHost(hostname string, raw []byte, expires time.Time) Host {
}
}
// ParseHost parses a host from the provided text.
func ParseHost(text []byte) (Host, error) {
var h Host
err := h.UnmarshalText(text)
return h, err
}
// String returns a string representation of the host.
func (h Host) String() string {
var b strings.Builder
b.WriteString(h.Hostname)
b.WriteByte(' ')
b.WriteString(h.Algorithm)
b.WriteByte(' ')
b.WriteString(h.Fingerprint.String())
b.WriteByte(' ')
b.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
return b.String()
}
// UnmarshalText unmarshals the host from the provided text.
func (h *Host) UnmarshalText(text []byte) error {
const format = "hostname algorithm hex-fingerprint expiry-unix-ts"
@ -259,11 +207,11 @@ func (h *Host) UnmarshalText(text []byte) error {
h.Algorithm = algorithm
fingerprint := make([]byte, 0, sha512.Size)
scan := bufio.NewScanner(bytes.NewReader(parts[2]))
scan.Split(scanFingerprint)
scanner := bufio.NewScanner(bytes.NewReader(parts[2]))
scanner.Split(scanFingerprint)
for scan.Scan() {
b, err := strconv.ParseUint(scan.Text(), 16, 8)
for scanner.Scan() {
b, err := strconv.ParseUint(scanner.Text(), 16, 8)
if err != nil {
return fmt.Errorf("failed to parse fingerprint hash: %w", err)
}
@ -288,29 +236,6 @@ func (h *Host) UnmarshalText(text []byte) error {
return nil
}
func (h *Host) WriteTo(w io.Writer) (int64, error) {
bw := bufio.NewWriter(w)
var written, n int
n, _ = bw.WriteString(h.Hostname)
bw.WriteByte(' ')
written += n + 1
n, _ = bw.WriteString(h.Algorithm)
bw.WriteByte(' ')
written += n + 1
n, _ = bw.WriteString(h.Fingerprint.String())
bw.WriteByte(' ')
written += n + 1
n, _ = bw.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
written += n
return int64(written), bw.Flush()
}
func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
@ -328,3 +253,21 @@ func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err er
// Request more data.
return 0, nil, nil
}
// Fingerprint represents a fingerprint.
type Fingerprint []byte
// String returns a string representation of the fingerprint.
func (f Fingerprint) String() string {
var sb strings.Builder
for i, b := range f {
if i > 0 {
sb.WriteByte(':')
}
fmt.Fprintf(&sb, "%02X", b)
}
return sb.String()
}