9fe837ffac
This commit introduces the KnownHosts struct, whose purpose is simply to store known hosts entries. The HostWriter struct is now in charge of appending hosts to files, and the two are not dependent on each other. Users are now responsible for opening the known hosts file and closing it when they are finished with it.
274 lines
6.0 KiB
Go
274 lines
6.0 KiB
Go
// Package tofu implements trust on first use using hosts and fingerprints.
|
|
package tofu
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/sha512"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// KnownHosts represents a list of known hosts.
|
|
// The zero value for KnownHosts represents an empty list ready to use.
|
|
//
|
|
// KnownHosts is safe for concurrent use by multiple goroutines.
|
|
type KnownHosts struct {
|
|
hosts map[string]Host
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// Add adds a host to the list of known hosts.
|
|
func (k *KnownHosts) Add(h Host) error {
|
|
k.mu.Lock()
|
|
defer k.mu.Unlock()
|
|
if k.hosts == nil {
|
|
k.hosts = map[string]Host{}
|
|
}
|
|
|
|
k.hosts[h.Hostname] = h
|
|
return nil
|
|
}
|
|
|
|
// Lookup returns the known host entry corresponding to the given hostname.
|
|
func (k *KnownHosts) Lookup(hostname string) (Host, bool) {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
c, ok := k.hosts[hostname]
|
|
return c, ok
|
|
}
|
|
|
|
// WriteTo writes the list of known hosts to the provided io.Writer.
|
|
func (k *KnownHosts) WriteTo(w io.Writer) (int64, error) {
|
|
k.mu.RLock()
|
|
defer k.mu.RUnlock()
|
|
|
|
var written int
|
|
|
|
bw := bufio.NewWriter(w)
|
|
for _, h := range k.hosts {
|
|
n, err := bw.WriteString(h.String())
|
|
written += n
|
|
if err != nil {
|
|
return int64(written), err
|
|
}
|
|
|
|
bw.WriteByte('\n')
|
|
written += 1
|
|
}
|
|
|
|
return int64(written), bw.Flush()
|
|
}
|
|
|
|
// Parse parses the provided io.Reader and adds the parsed hosts to the list.
|
|
// Invalid entries are ignored.
|
|
//
|
|
// For more control over errors encountered by parsing, scan the reader with a bufio.Scanner
|
|
// and call ParseHost with scanner.Bytes().
|
|
func (k *KnownHosts) Parse(r io.Reader) error {
|
|
k.mu.Lock()
|
|
defer k.mu.Unlock()
|
|
|
|
if k.hosts == nil {
|
|
k.hosts = map[string]Host{}
|
|
}
|
|
|
|
scanner := bufio.NewScanner(r)
|
|
var line int
|
|
for scanner.Scan() {
|
|
line++
|
|
|
|
text := scanner.Bytes()
|
|
if len(text) == 0 {
|
|
continue
|
|
}
|
|
|
|
h, err := ParseHost(text)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
k.hosts[h.Hostname] = h
|
|
}
|
|
|
|
return scanner.Err()
|
|
}
|
|
|
|
// TOFU implements a basic Trust On First Use flow.
|
|
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
|
|
host := NewHost(hostname, cert.Raw, cert.NotAfter)
|
|
|
|
knownHost, ok := k.Lookup(hostname)
|
|
if !ok || time.Now().After(knownHost.Expires) {
|
|
k.Add(host)
|
|
return nil
|
|
}
|
|
|
|
// Check fingerprint
|
|
if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
|
|
return fmt.Errorf("fingerprint for %q does not match", hostname)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// HostWriter writes host entries to an io.Writer.
|
|
type HostWriter struct {
|
|
bw *bufio.Writer
|
|
}
|
|
|
|
// NewHostsWriter returns a new host writer that writes to the provided writer.
|
|
func NewHostsWriter(w io.Writer) *HostWriter {
|
|
return &HostWriter{
|
|
bw: bufio.NewWriter(w),
|
|
}
|
|
}
|
|
|
|
// WriteHost writes the host to the underlying io.Writer.
|
|
func (f *HostWriter) WriteHost(host Host) error {
|
|
f.bw.WriteString(host.String())
|
|
f.bw.WriteByte('\n')
|
|
|
|
if err := f.bw.Flush(); err != nil {
|
|
return fmt.Errorf("failed to write to hosts file: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Host represents a host entry with a fingerprint using a certain algorithm.
|
|
type Host struct {
|
|
Hostname string // hostname
|
|
Algorithm string // fingerprint algorithm e.g. SHA-512
|
|
Fingerprint Fingerprint // fingerprint
|
|
Expires time.Time // unix time of the fingerprint expiration date
|
|
}
|
|
|
|
// NewHost returns a new host with a SHA-512 fingerprint of
|
|
// the provided raw data.
|
|
func NewHost(hostname string, raw []byte, expires time.Time) Host {
|
|
sum := sha512.Sum512(raw)
|
|
|
|
return Host{
|
|
Hostname: hostname,
|
|
Algorithm: "SHA-512",
|
|
Fingerprint: sum[:],
|
|
Expires: expires,
|
|
}
|
|
}
|
|
|
|
// ParseHost parses a host from the provided text.
|
|
func ParseHost(text []byte) (Host, error) {
|
|
var h Host
|
|
err := h.UnmarshalText(text)
|
|
return h, err
|
|
}
|
|
|
|
// String returns a string representation of the host.
|
|
func (h Host) String() string {
|
|
var b strings.Builder
|
|
b.WriteString(h.Hostname)
|
|
b.WriteByte(' ')
|
|
b.WriteString(h.Algorithm)
|
|
b.WriteByte(' ')
|
|
b.WriteString(h.Fingerprint.String())
|
|
b.WriteByte(' ')
|
|
b.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
|
|
return b.String()
|
|
}
|
|
|
|
// UnmarshalText unmarshals the host from the provided text.
|
|
func (h *Host) UnmarshalText(text []byte) error {
|
|
const format = "hostname algorithm hex-fingerprint expiry-unix-ts"
|
|
|
|
parts := bytes.Split(text, []byte(" "))
|
|
if len(parts) != 4 {
|
|
return fmt.Errorf(
|
|
"expected the format %q", format)
|
|
}
|
|
|
|
if len(parts[0]) == 0 {
|
|
return errors.New("empty hostname")
|
|
}
|
|
|
|
h.Hostname = string(parts[0])
|
|
|
|
algorithm := string(parts[1])
|
|
if algorithm != "SHA-512" {
|
|
return fmt.Errorf(
|
|
"unsupported algorithm %q", algorithm)
|
|
}
|
|
|
|
h.Algorithm = algorithm
|
|
|
|
fingerprint := make([]byte, 0, sha512.Size)
|
|
scanner := bufio.NewScanner(bytes.NewReader(parts[2]))
|
|
scanner.Split(scanFingerprint)
|
|
|
|
for scanner.Scan() {
|
|
b, err := strconv.ParseUint(scanner.Text(), 16, 8)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse fingerprint hash: %w", err)
|
|
}
|
|
fingerprint = append(fingerprint, byte(b))
|
|
}
|
|
|
|
if len(fingerprint) != sha512.Size {
|
|
return fmt.Errorf("invalid fingerprint size %d, expected %d",
|
|
len(fingerprint), sha512.Size)
|
|
}
|
|
|
|
h.Fingerprint = fingerprint
|
|
|
|
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"invalid unix timestamp: %w", err)
|
|
}
|
|
|
|
h.Expires = time.Unix(unix, 0)
|
|
|
|
return nil
|
|
}
|
|
|
|
func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
if atEOF && len(data) == 0 {
|
|
return 0, nil, nil
|
|
}
|
|
if i := bytes.IndexByte(data, ':'); i >= 0 {
|
|
// We have a full newline-terminated line.
|
|
return i + 1, data[0:i], nil
|
|
}
|
|
|
|
// If we're at EOF, we have a final, non-terminated hex byte
|
|
if atEOF {
|
|
return len(data), data, nil
|
|
}
|
|
|
|
// Request more data.
|
|
return 0, nil, nil
|
|
}
|
|
|
|
// Fingerprint represents a fingerprint.
|
|
type Fingerprint []byte
|
|
|
|
// String returns a string representation of the fingerprint.
|
|
func (f Fingerprint) String() string {
|
|
var sb strings.Builder
|
|
|
|
for i, b := range f {
|
|
if i > 0 {
|
|
sb.WriteByte(':')
|
|
}
|
|
|
|
fmt.Fprintf(&sb, "%02X", b)
|
|
}
|
|
|
|
return sb.String()
|
|
}
|