tofu: Rename KnownHost to Host

This commit is contained in:
Adnan Maolood 2021-01-14 14:15:08 -05:00
parent 95aff9c573
commit 4b8bb16a3d

View File

@ -16,41 +16,41 @@ import (
"time" "time"
) )
// KnownHostsFile represents a list of known hosts optionally loaded from a file. // HostsFile represents a list of known hosts optionally loaded from a file.
// The zero value for KnownHostsFile represents an empty list ready to use. // The zero value for HostsFile represents an empty list ready to use.
// //
// KnownHostsFile is safe for concurrent use by multiple goroutines. // HostsFile is safe for concurrent use by multiple goroutines.
type KnownHostsFile struct { type HostsFile struct {
hosts map[string]KnownHost hosts map[string]Host
out *bufio.Writer writer *bufio.Writer
closer io.Closer closer io.Closer
mu sync.RWMutex mu sync.RWMutex
} }
// SetOutput sets the output to which new known hosts will be written to. // SetOutput sets the output to which new known hosts will be written to.
func (k *KnownHostsFile) SetOutput(w io.WriteCloser) error { func (k *HostsFile) SetOutput(w io.WriteCloser) error {
k.mu.Lock() k.mu.Lock()
defer k.mu.Unlock() defer k.mu.Unlock()
if k.out != nil { if k.writer != nil {
err := k.closer.Close() err := k.closer.Close()
if err != nil { if err != nil {
return fmt.Errorf("failed to close previous output: %w", err) return fmt.Errorf("failed to close previous output: %w", err)
} }
} }
k.out = bufio.NewWriter(w) k.writer = bufio.NewWriter(w)
k.closer = w k.closer = w
return nil return nil
} }
// Close closes the output. // Close closes the output.
func (k *KnownHostsFile) Close() error { func (k *HostsFile) Close() error {
k.mu.Lock() k.mu.Lock()
defer k.mu.Unlock() defer k.mu.Unlock()
if k.out == nil { if k.writer == nil {
return nil return nil
} }
@ -59,27 +59,27 @@ func (k *KnownHostsFile) Close() error {
return err return err
} }
k.out = nil k.writer = nil
k.closer = nil k.closer = nil
return nil return nil
} }
// Add adds a known host to the list of known hosts. // Add adds a known host to the list of known hosts.
func (k *KnownHostsFile) Add(h KnownHost) error { func (k *HostsFile) Add(h Host) error {
k.mu.Lock() k.mu.Lock()
defer k.mu.Unlock() defer k.mu.Unlock()
if k.hosts == nil { if k.hosts == nil {
k.hosts = map[string]KnownHost{} k.hosts = map[string]Host{}
} }
k.hosts[h.Hostname] = h k.hosts[h.Hostname] = h
if k.out != nil { if k.writer != nil {
h.WriteTo(k.out) h.WriteTo(k.writer)
k.out.WriteRune('\n') k.writer.WriteByte('\n')
if err := k.out.Flush(); err != nil { if err := k.writer.Flush(); err != nil {
return fmt.Errorf("failed to write to known host file: %w", err) return fmt.Errorf("failed to write to known host file: %w", err)
} }
} }
@ -89,7 +89,7 @@ func (k *KnownHostsFile) Add(h KnownHost) error {
// Lookup returns the fingerprint of the certificate corresponding to // Lookup returns the fingerprint of the certificate corresponding to
// the given hostname. // the given hostname.
func (k *KnownHostsFile) Lookup(hostname string) (KnownHost, bool) { func (k *HostsFile) Lookup(hostname string) (Host, bool) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
c, ok := k.hosts[hostname] c, ok := k.hosts[hostname]
@ -97,7 +97,7 @@ func (k *KnownHostsFile) Lookup(hostname string) (KnownHost, bool) {
} }
// WriteAll writes all of the known hosts to the provided io.Writer. // WriteAll writes all of the known hosts to the provided io.Writer.
func (k *KnownHostsFile) WriteTo(w io.Writer) (int64, error) { func (k *HostsFile) WriteTo(w io.Writer) (int64, error) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
@ -121,7 +121,7 @@ func (k *KnownHostsFile) WriteTo(w io.Writer) (int64, error) {
// Open loads the known hosts from the provided path. // Open loads the known hosts from the provided path.
// It creates the file if it does not exist. // It creates the file if it does not exist.
// New known hosts will be appended to the file. // New known hosts will be appended to the file.
func (k *KnownHostsFile) Open(path string) error { func (k *HostsFile) Open(path string) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644) f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
if err != nil { if err != nil {
return err return err
@ -144,12 +144,12 @@ func (k *KnownHostsFile) Open(path string) error {
// Parse parses the provided reader and adds the parsed known hosts to the list. // Parse parses the provided reader and adds the parsed known hosts to the list.
// Invalid entries are ignored. // Invalid entries are ignored.
func (k *KnownHostsFile) Parse(r io.Reader) error { func (k *HostsFile) Parse(r io.Reader) error {
k.mu.Lock() k.mu.Lock()
defer k.mu.Unlock() defer k.mu.Unlock()
if k.hosts == nil { if k.hosts == nil {
k.hosts = map[string]KnownHost{} k.hosts = map[string]Host{}
} }
scanner := bufio.NewScanner(r) scanner := bufio.NewScanner(r)
@ -162,7 +162,7 @@ func (k *KnownHostsFile) Parse(r io.Reader) error {
continue continue
} }
var h KnownHost var h Host
err := h.UnmarshalText(text) err := h.UnmarshalText(text)
if err != nil { if err != nil {
@ -181,8 +181,8 @@ func (k *KnownHostsFile) Parse(r io.Reader) error {
return nil return nil
} }
func (k *KnownHostsFile) TOFU(hostname string, cert *x509.Certificate) error { func (k *HostsFile) TOFU(hostname string, cert *x509.Certificate) error {
host := NewKnownHost(hostname, cert.Raw, cert.NotAfter) host := NewHost(hostname, cert.Raw, cert.NotAfter)
knownHost, ok := k.Lookup(hostname) knownHost, ok := k.Lookup(hostname)
if !ok || time.Now().After(knownHost.Expires) { if !ok || time.Now().After(knownHost.Expires) {
@ -214,15 +214,28 @@ func (f Fingerprint) String() string {
return sb.String() return sb.String()
} }
// KnownHost represents a known host entry for a fingerprint using a certain algorithm. // Host represents a known host entry for a fingerprint using a certain algorithm.
type KnownHost struct { type Host struct {
Hostname string // hostname Hostname string // hostname
Algorithm string // fingerprint algorithm e.g. SHA-512 Algorithm string // fingerprint algorithm e.g. SHA-512
Fingerprint Fingerprint // fingerprint Fingerprint Fingerprint // fingerprint
Expires time.Time // unix time of the fingerprint expiration date Expires time.Time // unix time of the fingerprint expiration date
} }
func (f *KnownHost) UnmarshalText(text []byte) error { // NewHost returns the known host entry with a SHA-512
// fingerprint of the provided raw data.
func NewHost(hostname string, raw []byte, expires time.Time) Host {
sum := sha512.Sum512(raw)
return Host{
Hostname: hostname,
Algorithm: "SHA-512",
Fingerprint: sum[:],
Expires: expires,
}
}
func (h *Host) UnmarshalText(text []byte) error {
const format = "hostname algorithm hex-fingerprint expiry-unix-ts" const format = "hostname algorithm hex-fingerprint expiry-unix-ts"
parts := bytes.Split(text, []byte(" ")) parts := bytes.Split(text, []byte(" "))
@ -235,7 +248,7 @@ func (f *KnownHost) UnmarshalText(text []byte) error {
return errors.New("empty hostname") return errors.New("empty hostname")
} }
f.Hostname = string(parts[0]) h.Hostname = string(parts[0])
algorithm := string(parts[1]) algorithm := string(parts[1])
if algorithm != "SHA-512" { if algorithm != "SHA-512" {
@ -243,7 +256,7 @@ func (f *KnownHost) UnmarshalText(text []byte) error {
"unsupported algorithm %q", algorithm) "unsupported algorithm %q", algorithm)
} }
f.Algorithm = algorithm h.Algorithm = algorithm
fingerprint := make([]byte, 0, sha512.Size) fingerprint := make([]byte, 0, sha512.Size)
scan := bufio.NewScanner(bytes.NewReader(parts[2])) scan := bufio.NewScanner(bytes.NewReader(parts[2]))
@ -262,7 +275,7 @@ func (f *KnownHost) UnmarshalText(text []byte) error {
len(fingerprint), sha512.Size) len(fingerprint), sha512.Size)
} }
f.Fingerprint = fingerprint h.Fingerprint = fingerprint
unix, err := strconv.ParseInt(string(parts[3]), 10, 0) unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
if err != nil { if err != nil {
@ -270,12 +283,12 @@ func (f *KnownHost) UnmarshalText(text []byte) error {
"invalid unix timestamp: %w", err) "invalid unix timestamp: %w", err)
} }
f.Expires = time.Unix(unix, 0) h.Expires = time.Unix(unix, 0)
return nil return nil
} }
func (h *KnownHost) WriteTo(w io.Writer) (int64, error) { func (h *Host) WriteTo(w io.Writer) (int64, error) {
bw := bufio.NewWriter(w) bw := bufio.NewWriter(w)
var written, n int var written, n int
@ -315,16 +328,3 @@ func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err er
// Request more data. // Request more data.
return 0, nil, nil return 0, nil, nil
} }
// NewKnownHost returns the known host entry with a SHA-512
// fingerprint of the provided raw data.
func NewKnownHost(hostname string, raw []byte, expires time.Time) KnownHost {
sum := sha512.Sum512(raw)
return KnownHost{
Hostname: hostname,
Algorithm: "SHA-512",
Fingerprint: sum[:],
Expires: expires,
}
}