tofu: Refactor
This commit changes underlying file handling and known hosts parsing. A known hosts file opened through Load() never closed the underlying file. During known hosts parsing most errors were unchecked, or just led to the line being skipped. I removed the KnownHosts type, which didn't really have a role after the refactor. The embedding of KnownHosts in KnownHosts file has been removed as it also leaked the map unprotected by the mutex. The Fingerprint type is now KnownHost and has taken over the responsibility of marshalling and unmarshalling. SetOutput now takes a WriteCloser so that we can close the underlying writer when it's replaced, or when it's explicitly closed through the new Close() function. KnownHostsFile.Add() now also writes the known host to the output if set. I think that makes sense expectation-wise for the type. Turned WriteAll() into WriteTo() to conform with the io.WriterTo interface. Load() is now Open() to better reflect the fact that a file is opened, and kept open. It can now also return errors from the parsing process. The parser does a lot more error checking, and this might be an area where I've changed a desired behaviour as invalid entries no longer are ignored, but aborts the parsing process. That could be changed to a warning, or some kind of parsing feedback. I added KnownHostsFile.TOFU() to fill the developer experience gap that was left after the client no longer knows about KnownHostsFile. It implements a basic non-interactive TOFU flow.
This commit is contained in:
parent
de042e4724
commit
95aff9c573
@ -6,6 +6,7 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -29,7 +30,7 @@ var (
|
||||
func init() {
|
||||
// Load known hosts file
|
||||
path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts")
|
||||
err := hosts.Load(path)
|
||||
err := hosts.Open(path)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
@ -47,25 +48,24 @@ Otherwise, this should be safe to trust.
|
||||
=> `
|
||||
|
||||
func trustCertificate(hostname string, cert *x509.Certificate) error {
|
||||
fingerprint := tofu.NewFingerprint(cert.Raw, cert.NotAfter)
|
||||
host := tofu.NewKnownHost(hostname, cert.Raw, cert.NotAfter)
|
||||
|
||||
knownHost, ok := hosts.Lookup(hostname)
|
||||
if ok && time.Now().Before(knownHost.Expires) {
|
||||
// Check fingerprint
|
||||
if knownHost.Hex == fingerprint.Hex {
|
||||
if bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
|
||||
return nil
|
||||
}
|
||||
return errors.New("error: fingerprint does not match!")
|
||||
}
|
||||
|
||||
fmt.Printf(trustPrompt, hostname, fingerprint.Hex)
|
||||
fmt.Printf(trustPrompt, hostname, host.Fingerprint)
|
||||
scanner.Scan()
|
||||
switch scanner.Text() {
|
||||
case "t":
|
||||
hosts.Add(hostname, fingerprint)
|
||||
hosts.Write(hostname, fingerprint)
|
||||
hosts.Add(host)
|
||||
return nil
|
||||
case "o":
|
||||
hosts.Add(hostname, fingerprint)
|
||||
return nil
|
||||
default:
|
||||
return errors.New("certificate not trusted")
|
||||
|
328
tofu/tofu.go
328
tofu/tofu.go
@ -3,7 +3,10 @@ package tofu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/sha512"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
@ -13,52 +16,70 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// KnownHosts maps hosts to fingerprints.
|
||||
type KnownHosts map[string]Fingerprint
|
||||
|
||||
// KnownHostsFile represents a list of known hosts optionally loaded from a file.
|
||||
// The zero value for KnownHostsFile represents an empty list ready to use.
|
||||
//
|
||||
// KnownHostsFile is safe for concurrent use by multiple goroutines.
|
||||
type KnownHostsFile struct {
|
||||
KnownHosts
|
||||
out io.Writer
|
||||
mu sync.RWMutex
|
||||
hosts map[string]KnownHost
|
||||
out *bufio.Writer
|
||||
closer io.Closer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetOutput sets the output to which new known hosts will be written to.
|
||||
func (k *KnownHostsFile) SetOutput(w io.Writer) {
|
||||
func (k *KnownHostsFile) SetOutput(w io.WriteCloser) error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
k.out = w
|
||||
|
||||
if k.out != nil {
|
||||
err := k.closer.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close previous output: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
k.out = bufio.NewWriter(w)
|
||||
k.closer = w
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the output.
|
||||
func (k *KnownHostsFile) Close() error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
|
||||
if k.out == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := k.closer.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
k.out = nil
|
||||
k.closer = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add adds a known host to the list of known hosts.
|
||||
func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) {
|
||||
func (k *KnownHostsFile) Add(h KnownHost) error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
if k.KnownHosts == nil {
|
||||
k.KnownHosts = KnownHosts{}
|
||||
if k.hosts == nil {
|
||||
k.hosts = map[string]KnownHost{}
|
||||
}
|
||||
k.KnownHosts[hostname] = fingerprint
|
||||
}
|
||||
|
||||
// Lookup returns the fingerprint of the certificate corresponding to
|
||||
// the given hostname.
|
||||
func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
c, ok := k.KnownHosts[hostname]
|
||||
return c, ok
|
||||
}
|
||||
k.hosts[h.Hostname] = h
|
||||
|
||||
// Write writes a known hosts entry to the configured output.
|
||||
func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) error {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
if k.out != nil {
|
||||
_, err := k.writeKnownHost(k.out, hostname, fingerprint)
|
||||
if err != nil {
|
||||
h.WriteTo(k.out)
|
||||
k.out.WriteRune('\n')
|
||||
|
||||
if err := k.out.Flush(); err != nil {
|
||||
return fmt.Errorf("failed to write to known host file: %w", err)
|
||||
}
|
||||
}
|
||||
@ -66,93 +87,244 @@ func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteAll writes all of the known hosts to the provided io.Writer.
|
||||
func (k *KnownHostsFile) WriteAll(w io.Writer) error {
|
||||
// Lookup returns the fingerprint of the certificate corresponding to
|
||||
// the given hostname.
|
||||
func (k *KnownHostsFile) Lookup(hostname string) (KnownHost, bool) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
for h, c := range k.KnownHosts {
|
||||
if _, err := k.writeKnownHost(w, h, c); err != nil {
|
||||
return err
|
||||
c, ok := k.hosts[hostname]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
// WriteAll writes all of the known hosts to the provided io.Writer.
|
||||
func (k *KnownHostsFile) WriteTo(w io.Writer) (int64, error) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
bw := bufio.NewWriter(w)
|
||||
|
||||
var written int64
|
||||
|
||||
for _, h := range k.hosts {
|
||||
n, err := h.WriteTo(bw)
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
|
||||
bw.WriteByte('\n')
|
||||
written += n + 1
|
||||
}
|
||||
return nil
|
||||
|
||||
return written, bw.Flush()
|
||||
}
|
||||
|
||||
// writeKnownHost writes a known host to the provided io.Writer.
|
||||
func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) {
|
||||
return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires.Unix())
|
||||
}
|
||||
|
||||
// Load loads the known hosts from the provided path.
|
||||
// Open loads the known hosts from the provided path.
|
||||
// It creates the file if it does not exist.
|
||||
// New known hosts will be appended to the file.
|
||||
func (k *KnownHostsFile) Load(path string) error {
|
||||
func (k *KnownHostsFile) Open(path string) error {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k.Parse(f)
|
||||
k.SetOutput(f)
|
||||
|
||||
err = k.Parse(f)
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
err = k.SetOutput(f)
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
||||
// Invalid entries are ignored.
|
||||
func (k *KnownHostsFile) Parse(r io.Reader) {
|
||||
func (k *KnownHostsFile) Parse(r io.Reader) error {
|
||||
k.mu.Lock()
|
||||
defer k.mu.Unlock()
|
||||
if k.KnownHosts == nil {
|
||||
k.KnownHosts = map[string]Fingerprint{}
|
||||
|
||||
if k.hosts == nil {
|
||||
k.hosts = map[string]KnownHost{}
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
var line int
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
parts := strings.Split(text, " ")
|
||||
if len(parts) < 4 {
|
||||
line++
|
||||
|
||||
text := scanner.Bytes()
|
||||
if len(text) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hostname := parts[0]
|
||||
algorithm := parts[1]
|
||||
if algorithm != "SHA-512" {
|
||||
continue
|
||||
}
|
||||
hex := parts[2]
|
||||
var h KnownHost
|
||||
|
||||
unix, err := strconv.ParseInt(parts[3], 10, 0)
|
||||
err := h.UnmarshalText(text)
|
||||
if err != nil {
|
||||
continue
|
||||
return fmt.Errorf("error when parsing line %d: %w",
|
||||
line, err)
|
||||
}
|
||||
expires := time.Unix(unix, 0)
|
||||
|
||||
k.KnownHosts[hostname] = Fingerprint{
|
||||
Algorithm: algorithm,
|
||||
Hex: hex,
|
||||
Expires: expires,
|
||||
}
|
||||
k.hosts[h.Hostname] = h
|
||||
}
|
||||
|
||||
err := scanner.Err()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read lines: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fingerprint represents a fingerprint using a certain algorithm.
|
||||
type Fingerprint struct {
|
||||
Algorithm string // fingerprint algorithm e.g. SHA-512
|
||||
Hex string // fingerprint in hexadecimal, with ':' between each octet
|
||||
Expires time.Time // unix time of the fingerprint expiration date
|
||||
func (k *KnownHostsFile) TOFU(hostname string, cert *x509.Certificate) error {
|
||||
host := NewKnownHost(hostname, cert.Raw, cert.NotAfter)
|
||||
|
||||
knownHost, ok := k.Lookup(hostname)
|
||||
if !ok || time.Now().After(knownHost.Expires) {
|
||||
k.Add(host)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check fingerprint
|
||||
if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) {
|
||||
return fmt.Errorf("fingerprint for %q does not match", hostname)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewFingerprint returns the SHA-512 fingerprint of the provided raw data.
|
||||
func NewFingerprint(raw []byte, expires time.Time) Fingerprint {
|
||||
sum512 := sha512.Sum512(raw)
|
||||
var b strings.Builder
|
||||
for i, f := range sum512 {
|
||||
type Fingerprint []byte
|
||||
|
||||
func (f Fingerprint) String() string {
|
||||
var sb strings.Builder
|
||||
|
||||
for i, b := range f {
|
||||
if i > 0 {
|
||||
b.WriteByte(':')
|
||||
sb.WriteByte(':')
|
||||
}
|
||||
fmt.Fprintf(&b, "%02X", f)
|
||||
|
||||
fmt.Fprintf(&sb, "%02X", b)
|
||||
}
|
||||
return Fingerprint{
|
||||
Algorithm: "SHA-512",
|
||||
Hex: b.String(),
|
||||
Expires: expires,
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// KnownHost represents a known host entry for a fingerprint using a certain algorithm.
|
||||
type KnownHost struct {
|
||||
Hostname string // hostname
|
||||
Algorithm string // fingerprint algorithm e.g. SHA-512
|
||||
Fingerprint Fingerprint // fingerprint
|
||||
Expires time.Time // unix time of the fingerprint expiration date
|
||||
}
|
||||
|
||||
func (f *KnownHost) UnmarshalText(text []byte) error {
|
||||
const format = "hostname algorithm hex-fingerprint expiry-unix-ts"
|
||||
|
||||
parts := bytes.Split(text, []byte(" "))
|
||||
if len(parts) != 4 {
|
||||
return fmt.Errorf(
|
||||
"expected the format %q", format)
|
||||
}
|
||||
|
||||
if len(parts[0]) == 0 {
|
||||
return errors.New("empty hostname")
|
||||
}
|
||||
|
||||
f.Hostname = string(parts[0])
|
||||
|
||||
algorithm := string(parts[1])
|
||||
if algorithm != "SHA-512" {
|
||||
return fmt.Errorf(
|
||||
"unsupported algorithm %q", algorithm)
|
||||
}
|
||||
|
||||
f.Algorithm = algorithm
|
||||
|
||||
fingerprint := make([]byte, 0, sha512.Size)
|
||||
scan := bufio.NewScanner(bytes.NewReader(parts[2]))
|
||||
scan.Split(scanFingerprint)
|
||||
|
||||
for scan.Scan() {
|
||||
b, err := strconv.ParseUint(scan.Text(), 16, 8)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse fingerprint hash: %w", err)
|
||||
}
|
||||
fingerprint = append(fingerprint, byte(b))
|
||||
}
|
||||
|
||||
if len(fingerprint) != sha512.Size {
|
||||
return fmt.Errorf("invalid fingerprint size %d, expected %d",
|
||||
len(fingerprint), sha512.Size)
|
||||
}
|
||||
|
||||
f.Fingerprint = fingerprint
|
||||
|
||||
unix, err := strconv.ParseInt(string(parts[3]), 10, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"invalid unix timestamp: %w", err)
|
||||
}
|
||||
|
||||
f.Expires = time.Unix(unix, 0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *KnownHost) WriteTo(w io.Writer) (int64, error) {
|
||||
bw := bufio.NewWriter(w)
|
||||
|
||||
var written, n int
|
||||
|
||||
n, _ = bw.WriteString(h.Hostname)
|
||||
bw.WriteByte(' ')
|
||||
written += n + 1
|
||||
|
||||
n, _ = bw.WriteString(h.Algorithm)
|
||||
bw.WriteByte(' ')
|
||||
written += n + 1
|
||||
|
||||
n, _ = bw.WriteString(h.Fingerprint.String())
|
||||
bw.WriteByte(' ')
|
||||
written += n + 1
|
||||
|
||||
n, _ = bw.WriteString(strconv.FormatInt(h.Expires.Unix(), 10))
|
||||
written += n
|
||||
|
||||
return int64(written), bw.Flush()
|
||||
}
|
||||
|
||||
func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := bytes.IndexByte(data, ':'); i >= 0 {
|
||||
// We have a full newline-terminated line.
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
|
||||
// If we're at EOF, we have a final, non-terminated hex byte
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
|
||||
// Request more data.
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
// NewKnownHost returns the known host entry with a SHA-512
|
||||
// fingerprint of the provided raw data.
|
||||
func NewKnownHost(hostname string, raw []byte, expires time.Time) KnownHost {
|
||||
sum := sha512.Sum512(raw)
|
||||
|
||||
return KnownHost{
|
||||
Hostname: hostname,
|
||||
Algorithm: "SHA-512",
|
||||
Fingerprint: sum[:],
|
||||
Expires: expires,
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user