tofu: Implement PersistentHosts

This commit is contained in:
Adnan Maolood 2021-01-25 12:02:09 -05:00
parent 3efa17f6fb
commit 62960266ac
2 changed files with 87 additions and 17 deletions

View File

@ -20,8 +20,7 @@ type Client struct {
// If the returned error is not nil, the certificate will not be trusted
// and the request will be aborted.
//
// For a basic trust on first use implementation, see (*KnownHosts).TOFU
// in the tofu submodule.
// See the tofu submodule for an implementation of trust on first use.
TrustCertificate func(hostname string, cert *x509.Certificate) error
// Timeout specifies a time limit for requests made by this
@ -88,8 +87,7 @@ func (c *Client) Do(req *Request) (*Response, error) {
if c.Timeout != 0 {
err := conn.SetDeadline(start.Add(c.Timeout))
if err != nil {
return nil, fmt.Errorf(
"failed to set connection deadline: %w", err)
return nil, fmt.Errorf("failed to set connection deadline: %w", err)
}
}
@ -114,8 +112,7 @@ func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {
err := req.Write(w)
if err != nil {
return nil, fmt.Errorf(
"failed to write request data: %w", err)
return nil, fmt.Errorf("failed to write request: %w", err)
}
if err := w.Flush(); err != nil {

View File

@ -27,7 +27,7 @@ type KnownHosts struct {
}
// Add adds a host to the list of known hosts.
func (k *KnownHosts) Add(h Host) error {
func (k *KnownHosts) Add(h Host) {
k.mu.Lock()
defer k.mu.Unlock()
if k.hosts == nil {
@ -35,7 +35,6 @@ func (k *KnownHosts) Add(h Host) error {
}
k.hosts[h.Hostname] = h
return nil
}
// Lookup returns the known host entry corresponding to the given hostname.
@ -144,7 +143,7 @@ func (k *KnownHosts) Parse(r io.Reader) error {
// TOFU implements basic trust on first use.
//
// If the host is not on file, it is added to the list.
// If the host on file is expired, it is replaced with the provided host.
// If the host on file is expired, a new entry is added to the list.
// If the fingerprint does not match the one on file, an error is returned.
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter)
@ -181,9 +180,9 @@ func NewHostWriter(w io.WriteCloser) *HostWriter {
}
}
// NewHostsFile returns a new host writer that appends to the file at the given path.
// OpenHostsFile returns a new host writer that appends to the file at the given path.
// The file is created if it does not exist.
func NewHostsFile(path string) (*HostWriter, error) {
func OpenHostsFile(path string) (*HostWriter, error) {
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return nil, err
@ -212,6 +211,83 @@ func (h *HostWriter) Close() error {
return h.cl.Close()
}
// PersistentHosts represents a persistent set of known hosts.
type PersistentHosts struct {
hosts *KnownHosts
writer *HostWriter
}
// NewPersistentHosts returns a new persistent set of known hosts.
func NewPersistentHosts(hosts *KnownHosts, writer *HostWriter) *PersistentHosts {
return &PersistentHosts{
hosts,
writer,
}
}
// LoadPersistentHosts loads persistent hosts from the file at the given path.
func LoadPersistentHosts(path string) (*PersistentHosts, error) {
hosts := &KnownHosts{}
if err := hosts.Load(path); err != nil {
return nil, err
}
writer, err := OpenHostsFile(path)
if err != nil {
return nil, err
}
return &PersistentHosts{
hosts,
writer,
}, nil
}
// Add adds a host to the list of known hosts.
// It returns an error if the host could not be persisted.
func (p *PersistentHosts) Add(h Host) error {
err := p.writer.WriteHost(h)
if err != nil {
return fmt.Errorf("failed to persist host: %w", err)
}
p.hosts.Add(h)
return nil
}
// Lookup returns the known host entry corresponding to the given hostname.
func (p *PersistentHosts) Lookup(hostname string) (Host, bool) {
return p.hosts.Lookup(hostname)
}
// Entries returns the known host entries sorted by hostname.
func (p *PersistentHosts) Entries() []Host {
return p.hosts.Entries()
}
// TOFU implements trust on first use with a persistent set of known hosts.
//
// If the host is not on file, it is added to the list.
// If the host on file is expired, a new entry is added to the list.
// If the fingerprint does not match the one on file, an error is returned.
func (p *PersistentHosts) TOFU(hostname string, cert *x509.Certificate) error {
host := NewHost(hostname, cert.Raw, cert.NotAfter)
knownHost, ok := p.Lookup(hostname)
if !ok || time.Now().After(knownHost.Expires) {
return p.Add(host)
}
// Check fingerprint
if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
return fmt.Errorf("fingerprint for %q does not match", hostname)
}
return nil
}
// Close closes the underlying HostWriter.
func (p *PersistentHosts) Close() error {
return p.writer.Close()
}
// Host represents a host entry with a fingerprint using a certain algorithm.
type Host struct {
Hostname string // hostname
@ -259,8 +335,7 @@ func (h *Host) UnmarshalText(text []byte) error {
parts := bytes.Split(text, []byte(" "))
if len(parts) != 4 {
return fmt.Errorf(
"expected the format %q", format)
return fmt.Errorf("expected the format %q", format)
}
if len(parts[0]) == 0 {
@ -271,8 +346,7 @@ func (h *Host) UnmarshalText(text []byte) error {
algorithm := string(parts[1])
if algorithm != "SHA-512" {
return fmt.Errorf(
"unsupported algorithm %q", algorithm)
return fmt.Errorf("unsupported algorithm %q", algorithm)
}
h.Algorithm = algorithm
@ -298,8 +372,7 @@ func (h *Host) UnmarshalText(text []byte) error {
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
if err != nil {
return fmt.Errorf(
"invalid unix timestamp: %w", err)
return fmt.Errorf("invalid unix timestamp: %w", err)
}
h.Expires = time.Unix(unix, 0)