tofu: Implement PersistentHosts
This commit is contained in:
parent
3efa17f6fb
commit
62960266ac
@ -20,8 +20,7 @@ type Client struct {
|
||||
// If the returned error is not nil, the certificate will not be trusted
|
||||
// and the request will be aborted.
|
||||
//
|
||||
// For a basic trust on first use implementation, see (*KnownHosts).TOFU
|
||||
// in the tofu submodule.
|
||||
// See the tofu submodule for an implementation of trust on first use.
|
||||
TrustCertificate func(hostname string, cert *x509.Certificate) error
|
||||
|
||||
// Timeout specifies a time limit for requests made by this
|
||||
@ -88,8 +87,7 @@ func (c *Client) Do(req *Request) (*Response, error) {
|
||||
if c.Timeout != 0 {
|
||||
err := conn.SetDeadline(start.Add(c.Timeout))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to set connection deadline: %w", err)
|
||||
return nil, fmt.Errorf("failed to set connection deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,8 +112,7 @@ func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {
|
||||
|
||||
err := req.Write(w)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to write request data: %w", err)
|
||||
return nil, fmt.Errorf("failed to write request: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Flush(); err != nil {
|
||||
|
95
tofu/tofu.go
95
tofu/tofu.go
@ -27,7 +27,7 @@ type KnownHosts struct {
|
||||
}
|
||||
|
||||
// Add adds a host to the list of known hosts.
|
||||
func (k *KnownHosts) Add(h Host) error {
|
||||
func (k *KnownHosts) Add(h Host) {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
if k.hosts == nil {
|
||||
@ -35,7 +35,6 @@ func (k *KnownHosts) Add(h Host) error {
|
||||
}
|
||||
|
||||
k.hosts[h.Hostname] = h
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lookup returns the known host entry corresponding to the given hostname.
|
||||
@ -144,7 +143,7 @@ func (k *KnownHosts) Parse(r io.Reader) error {
|
||||
// TOFU implements basic trust on first use.
|
||||
//
|
||||
// If the host is not on file, it is added to the list.
|
||||
// If the host on file is expired, it is replaced with the provided host.
|
||||
// If the host on file is expired, a new entry is added to the list.
|
||||
// If the fingerprint does not match the one on file, an error is returned.
|
||||
func (k *KnownHosts) TOFU(hostname string, cert *x509.Certificate) error {
|
||||
host := NewHost(hostname, cert.Raw, cert.NotAfter)
|
||||
@ -181,9 +180,9 @@ func NewHostWriter(w io.WriteCloser) *HostWriter {
|
||||
}
|
||||
}
|
||||
|
||||
// NewHostsFile returns a new host writer that appends to the file at the given path.
|
||||
// OpenHostsFile returns a new host writer that appends to the file at the given path.
|
||||
// The file is created if it does not exist.
|
||||
func NewHostsFile(path string) (*HostWriter, error) {
|
||||
func OpenHostsFile(path string) (*HostWriter, error) {
|
||||
f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -212,6 +211,83 @@ func (h *HostWriter) Close() error {
|
||||
return h.cl.Close()
|
||||
}
|
||||
|
||||
// PersistentHosts represents a persistent set of known hosts.
|
||||
type PersistentHosts struct {
|
||||
hosts *KnownHosts
|
||||
writer *HostWriter
|
||||
}
|
||||
|
||||
// NewPersistentHosts returns a new persistent set of known hosts.
|
||||
func NewPersistentHosts(hosts *KnownHosts, writer *HostWriter) *PersistentHosts {
|
||||
return &PersistentHosts{
|
||||
hosts,
|
||||
writer,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadPersistentHosts loads persistent hosts from the file at the given path.
|
||||
func LoadPersistentHosts(path string) (*PersistentHosts, error) {
|
||||
hosts := &KnownHosts{}
|
||||
if err := hosts.Load(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
writer, err := OpenHostsFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PersistentHosts{
|
||||
hosts,
|
||||
writer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Add adds a host to the list of known hosts.
|
||||
// It returns an error if the host could not be persisted.
|
||||
func (p *PersistentHosts) Add(h Host) error {
|
||||
err := p.writer.WriteHost(h)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to persist host: %w", err)
|
||||
}
|
||||
p.hosts.Add(h)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Lookup returns the known host entry corresponding to the given hostname.
|
||||
func (p *PersistentHosts) Lookup(hostname string) (Host, bool) {
|
||||
return p.hosts.Lookup(hostname)
|
||||
}
|
||||
|
||||
// Entries returns the known host entries sorted by hostname.
|
||||
func (p *PersistentHosts) Entries() []Host {
|
||||
return p.hosts.Entries()
|
||||
}
|
||||
|
||||
// TOFU implements trust on first use with a persistent set of known hosts.
|
||||
//
|
||||
// If the host is not on file, it is added to the list.
|
||||
// If the host on file is expired, a new entry is added to the list.
|
||||
// If the fingerprint does not match the one on file, an error is returned.
|
||||
func (p *PersistentHosts) TOFU(hostname string, cert *x509.Certificate) error {
|
||||
host := NewHost(hostname, cert.Raw, cert.NotAfter)
|
||||
|
||||
knownHost, ok := p.Lookup(hostname)
|
||||
if !ok || time.Now().After(knownHost.Expires) {
|
||||
return p.Add(host)
|
||||
}
|
||||
|
||||
// Check fingerprint
|
||||
if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
|
||||
return fmt.Errorf("fingerprint for %q does not match", hostname)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the underlying HostWriter.
|
||||
func (p *PersistentHosts) Close() error {
|
||||
return p.writer.Close()
|
||||
}
|
||||
|
||||
// Host represents a host entry with a fingerprint using a certain algorithm.
|
||||
type Host struct {
|
||||
Hostname string // hostname
|
||||
@ -259,8 +335,7 @@ func (h *Host) UnmarshalText(text []byte) error {
|
||||
|
||||
parts := bytes.Split(text, []byte(" "))
|
||||
if len(parts) != 4 {
|
||||
return fmt.Errorf(
|
||||
"expected the format %q", format)
|
||||
return fmt.Errorf("expected the format %q", format)
|
||||
}
|
||||
|
||||
if len(parts[0]) == 0 {
|
||||
@ -271,8 +346,7 @@ func (h *Host) UnmarshalText(text []byte) error {
|
||||
|
||||
algorithm := string(parts[1])
|
||||
if algorithm != "SHA-512" {
|
||||
return fmt.Errorf(
|
||||
"unsupported algorithm %q", algorithm)
|
||||
return fmt.Errorf("unsupported algorithm %q", algorithm)
|
||||
}
|
||||
|
||||
h.Algorithm = algorithm
|
||||
@ -298,8 +372,7 @@ func (h *Host) UnmarshalText(text []byte) error {
|
||||
|
||||
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid unix timestamp: %w", err)
|
||||
return fmt.Errorf("invalid unix timestamp: %w", err)
|
||||
}
|
||||
|
||||
h.Expires = time.Unix(unix, 0)
|
||||
|
Loading…
Reference in New Issue
Block a user