diff --git a/examples/client.go b/examples/client.go index 251c6cd..60c8351 100644 --- a/examples/client.go +++ b/examples/client.go @@ -6,6 +6,7 @@ package main import ( "bufio" + "bytes" "crypto/x509" "errors" "fmt" @@ -29,7 +30,7 @@ var ( func init() { // Load known hosts file path := filepath.Join(xdg.DataHome(), "gemini", "known_hosts") - err := hosts.Load(path) + err := hosts.Open(path) if err != nil { log.Println(err) } @@ -47,25 +48,24 @@ Otherwise, this should be safe to trust. => ` func trustCertificate(hostname string, cert *x509.Certificate) error { - fingerprint := tofu.NewFingerprint(cert.Raw, cert.NotAfter) + host := tofu.NewKnownHost(hostname, cert.Raw, cert.NotAfter) + knownHost, ok := hosts.Lookup(hostname) if ok && time.Now().Before(knownHost.Expires) { // Check fingerprint - if knownHost.Hex == fingerprint.Hex { + if bytes.Equal(knownHost.Fingerprint, host.Fingerprint) { return nil } return errors.New("error: fingerprint does not match!") } - fmt.Printf(trustPrompt, hostname, fingerprint.Hex) + fmt.Printf(trustPrompt, hostname, host.Fingerprint) scanner.Scan() switch scanner.Text() { case "t": - hosts.Add(hostname, fingerprint) - hosts.Write(hostname, fingerprint) + hosts.Add(host) return nil case "o": - hosts.Add(hostname, fingerprint) return nil default: return errors.New("certificate not trusted") diff --git a/tofu/tofu.go b/tofu/tofu.go index 547894a..bc4fb9d 100644 --- a/tofu/tofu.go +++ b/tofu/tofu.go @@ -3,7 +3,10 @@ package tofu import ( "bufio" + "bytes" "crypto/sha512" + "crypto/x509" + "errors" "fmt" "io" "os" @@ -13,52 +16,70 @@ import ( "time" ) -// KnownHosts maps hosts to fingerprints. -type KnownHosts map[string]Fingerprint - // KnownHostsFile represents a list of known hosts optionally loaded from a file. // The zero value for KnownHostsFile represents an empty list ready to use. // // KnownHostsFile is safe for concurrent use by multiple goroutines. type KnownHostsFile struct { - KnownHosts - out io.Writer - mu sync.RWMutex + hosts map[string]KnownHost + out *bufio.Writer + closer io.Closer + mu sync.RWMutex } // SetOutput sets the output to which new known hosts will be written to. -func (k *KnownHostsFile) SetOutput(w io.Writer) { +func (k *KnownHostsFile) SetOutput(w io.WriteCloser) error { k.mu.Lock() defer k.mu.Unlock() - k.out = w + + if k.out != nil { + err := k.closer.Close() + if err != nil { + return fmt.Errorf("failed to close previous output: %w", err) + } + } + + k.out = bufio.NewWriter(w) + k.closer = w + + return nil +} + +// Close closes the output. +func (k *KnownHostsFile) Close() error { + k.mu.Lock() + defer k.mu.Unlock() + + if k.out == nil { + return nil + } + + err := k.closer.Close() + if err != nil { + return err + } + + k.out = nil + k.closer = nil + + return nil } // Add adds a known host to the list of known hosts. -func (k *KnownHostsFile) Add(hostname string, fingerprint Fingerprint) { +func (k *KnownHostsFile) Add(h KnownHost) error { k.mu.Lock() defer k.mu.Unlock() - if k.KnownHosts == nil { - k.KnownHosts = KnownHosts{} + if k.hosts == nil { + k.hosts = map[string]KnownHost{} } - k.KnownHosts[hostname] = fingerprint -} -// Lookup returns the fingerprint of the certificate corresponding to -// the given hostname. -func (k *KnownHostsFile) Lookup(hostname string) (Fingerprint, bool) { - k.mu.RLock() - defer k.mu.RUnlock() - c, ok := k.KnownHosts[hostname] - return c, ok -} + k.hosts[h.Hostname] = h -// Write writes a known hosts entry to the configured output. -func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) error { - k.mu.RLock() - defer k.mu.RUnlock() if k.out != nil { - _, err := k.writeKnownHost(k.out, hostname, fingerprint) - if err != nil { + h.WriteTo(k.out) + k.out.WriteRune('\n') + + if err := k.out.Flush(); err != nil { return fmt.Errorf("failed to write to known host file: %w", err) } } @@ -66,93 +87,244 @@ func (k *KnownHostsFile) Write(hostname string, fingerprint Fingerprint) error { return nil } -// WriteAll writes all of the known hosts to the provided io.Writer. -func (k *KnownHostsFile) WriteAll(w io.Writer) error { +// Lookup returns the fingerprint of the certificate corresponding to +// the given hostname. +func (k *KnownHostsFile) Lookup(hostname string) (KnownHost, bool) { k.mu.RLock() defer k.mu.RUnlock() - for h, c := range k.KnownHosts { - if _, err := k.writeKnownHost(w, h, c); err != nil { - return err + c, ok := k.hosts[hostname] + return c, ok +} + +// WriteAll writes all of the known hosts to the provided io.Writer. +func (k *KnownHostsFile) WriteTo(w io.Writer) (int64, error) { + k.mu.RLock() + defer k.mu.RUnlock() + + bw := bufio.NewWriter(w) + + var written int64 + + for _, h := range k.hosts { + n, err := h.WriteTo(bw) + if err != nil { + return written, err } + + bw.WriteByte('\n') + written += n + 1 } - return nil + + return written, bw.Flush() } -// writeKnownHost writes a known host to the provided io.Writer. -func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { - return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires.Unix()) -} - -// Load loads the known hosts from the provided path. +// Open loads the known hosts from the provided path. // It creates the file if it does not exist. // New known hosts will be appended to the file. -func (k *KnownHostsFile) Load(path string) error { +func (k *KnownHostsFile) Open(path string) error { f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0644) if err != nil { return err } - k.Parse(f) - k.SetOutput(f) + + err = k.Parse(f) + if err != nil { + _ = f.Close() + return err + } + + err = k.SetOutput(f) + if err != nil { + _ = f.Close() + return err + } + return nil } // Parse parses the provided reader and adds the parsed known hosts to the list. // Invalid entries are ignored. -func (k *KnownHostsFile) Parse(r io.Reader) { +func (k *KnownHostsFile) Parse(r io.Reader) error { k.mu.Lock() defer k.mu.Unlock() - if k.KnownHosts == nil { - k.KnownHosts = map[string]Fingerprint{} + + if k.hosts == nil { + k.hosts = map[string]KnownHost{} } + scanner := bufio.NewScanner(r) + var line int for scanner.Scan() { - text := scanner.Text() - parts := strings.Split(text, " ") - if len(parts) < 4 { + line++ + + text := scanner.Bytes() + if len(text) == 0 { continue } - hostname := parts[0] - algorithm := parts[1] - if algorithm != "SHA-512" { - continue - } - hex := parts[2] + var h KnownHost - unix, err := strconv.ParseInt(parts[3], 10, 0) + err := h.UnmarshalText(text) if err != nil { - continue + return fmt.Errorf("error when parsing line %d: %w", + line, err) } - expires := time.Unix(unix, 0) - k.KnownHosts[hostname] = Fingerprint{ - Algorithm: algorithm, - Hex: hex, - Expires: expires, - } + k.hosts[h.Hostname] = h } + + err := scanner.Err() + if err != nil { + return fmt.Errorf("failed to read lines: %w", err) + } + + return nil } -// Fingerprint represents a fingerprint using a certain algorithm. -type Fingerprint struct { - Algorithm string // fingerprint algorithm e.g. SHA-512 - Hex string // fingerprint in hexadecimal, with ':' between each octet - Expires time.Time // unix time of the fingerprint expiration date +func (k *KnownHostsFile) TOFU(hostname string, cert *x509.Certificate) error { + host := NewKnownHost(hostname, cert.Raw, cert.NotAfter) + + knownHost, ok := k.Lookup(hostname) + if !ok || time.Now().After(knownHost.Expires) { + k.Add(host) + return nil + } + + // Check fingerprint + if !bytes.Equal(knownHost.Fingerprint, host.Fingerprint) { + return fmt.Errorf("fingerprint for %q does not match", hostname) + } + + return nil } -// NewFingerprint returns the SHA-512 fingerprint of the provided raw data. -func NewFingerprint(raw []byte, expires time.Time) Fingerprint { - sum512 := sha512.Sum512(raw) - var b strings.Builder - for i, f := range sum512 { +type Fingerprint []byte + +func (f Fingerprint) String() string { + var sb strings.Builder + + for i, b := range f { if i > 0 { - b.WriteByte(':') + sb.WriteByte(':') } - fmt.Fprintf(&b, "%02X", f) + + fmt.Fprintf(&sb, "%02X", b) } - return Fingerprint{ - Algorithm: "SHA-512", - Hex: b.String(), - Expires: expires, + + return sb.String() +} + +// KnownHost represents a known host entry for a fingerprint using a certain algorithm. +type KnownHost struct { + Hostname string // hostname + Algorithm string // fingerprint algorithm e.g. SHA-512 + Fingerprint Fingerprint // fingerprint + Expires time.Time // unix time of the fingerprint expiration date +} + +func (f *KnownHost) UnmarshalText(text []byte) error { + const format = "hostname algorithm hex-fingerprint expiry-unix-ts" + + parts := bytes.Split(text, []byte(" ")) + if len(parts) != 4 { + return fmt.Errorf( + "expected the format %q", format) + } + + if len(parts[0]) == 0 { + return errors.New("empty hostname") + } + + f.Hostname = string(parts[0]) + + algorithm := string(parts[1]) + if algorithm != "SHA-512" { + return fmt.Errorf( + "unsupported algorithm %q", algorithm) + } + + f.Algorithm = algorithm + + fingerprint := make([]byte, 0, sha512.Size) + scan := bufio.NewScanner(bytes.NewReader(parts[2])) + scan.Split(scanFingerprint) + + for scan.Scan() { + b, err := strconv.ParseUint(scan.Text(), 16, 8) + if err != nil { + return fmt.Errorf("failed to parse fingerprint hash: %w", err) + } + fingerprint = append(fingerprint, byte(b)) + } + + if len(fingerprint) != sha512.Size { + return fmt.Errorf("invalid fingerprint size %d, expected %d", + len(fingerprint), sha512.Size) + } + + f.Fingerprint = fingerprint + + unix, err := strconv.ParseInt(string(parts[3]), 10, 0) + if err != nil { + return fmt.Errorf( + "invalid unix timestamp: %w", err) + } + + f.Expires = time.Unix(unix, 0) + + return nil +} + +func (h *KnownHost) WriteTo(w io.Writer) (int64, error) { + bw := bufio.NewWriter(w) + + var written, n int + + n, _ = bw.WriteString(h.Hostname) + bw.WriteByte(' ') + written += n + 1 + + n, _ = bw.WriteString(h.Algorithm) + bw.WriteByte(' ') + written += n + 1 + + n, _ = bw.WriteString(h.Fingerprint.String()) + bw.WriteByte(' ') + written += n + 1 + + n, _ = bw.WriteString(strconv.FormatInt(h.Expires.Unix(), 10)) + written += n + + return int64(written), bw.Flush() +} + +func scanFingerprint(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := bytes.IndexByte(data, ':'); i >= 0 { + // We have a full newline-terminated line. + return i + 1, data[0:i], nil + } + + // If we're at EOF, we have a final, non-terminated hex byte + if atEOF { + return len(data), data, nil + } + + // Request more data. + return 0, nil, nil +} + +// NewKnownHost returns the known host entry with a SHA-512 +// fingerprint of the provided raw data. +func NewKnownHost(hostname string, raw []byte, expires time.Time) KnownHost { + sum := sha512.Sum512(raw) + + return KnownHost{ + Hostname: hostname, + Algorithm: "SHA-512", + Fingerprint: sum[:], + Expires: expires, } }