diff --git a/client.go b/client.go index f1e4127..5a7f682 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package gemini import ( "bufio" + "bytes" "crypto/tls" "crypto/x509" "errors" @@ -221,7 +222,7 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { // Check the known hosts knownHost, ok := c.KnownHosts.Lookup(hostname) - if !ok || time.Now().Unix() >= knownHost.Expires { + if !ok || !time.Now().Before(knownHost.Expires) { // See if the client trusts the certificate if c.TrustCertificate != nil { switch c.TrustCertificate(hostname, cert) { @@ -240,7 +241,7 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { } fingerprint := NewFingerprint(cert.Raw, cert.NotAfter) - if knownHost.Hex == fingerprint.Hex { + if !bytes.Equal(knownHost.Raw, fingerprint.Raw) { return nil } return errors.New("gemini: fingerprint does not match") diff --git a/tofu.go b/tofu.go index de971b6..7e1067c 100644 --- a/tofu.go +++ b/tofu.go @@ -3,6 +3,7 @@ package gemini import ( "bufio" "crypto/sha512" + "encoding/base64" "fmt" "io" "os" @@ -20,7 +21,7 @@ const ( TrustAlways // The certificate is trusted always. ) -// KnownHosts represents a list of known hosts. +// KnownHosts maps hosts to fingerprints. type KnownHosts map[string]Fingerprint // KnownHostsFile represents a list of known hosts optionally loaded from a file. @@ -69,7 +70,8 @@ func (k *KnownHostsFile) WriteAll(w io.Writer) error { // writeKnownHost writes a known host to the provided io.Writer. func (k *KnownHostsFile) writeKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) { - return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, f.Hex, f.Expires) + s := base64.StdEncoding.EncodeToString([]byte(f.Raw)) + return fmt.Fprintf(w, "%s %s %s %d\n", hostname, f.Algorithm, s, f.Expires) } // Load loads the known hosts from the provided path. @@ -111,15 +113,20 @@ func (k *KnownHostsFile) Parse(r io.Reader) { continue } fingerprint := parts[2] - - expires, err := strconv.ParseInt(parts[3], 10, 0) + raw, err := base64.StdEncoding.DecodeString(fingerprint) if err != nil { continue } + unix, err := strconv.ParseInt(parts[3], 10, 0) + if err != nil { + continue + } + expires := time.Unix(unix, 0) + k.KnownHosts[hostname] = Fingerprint{ + Raw: raw, Algorithm: algorithm, - Hex: fingerprint, Expires: expires, } } @@ -127,24 +134,17 @@ func (k *KnownHostsFile) Parse(r io.Reader) { // Fingerprint represents a fingerprint using a certain algorithm. type Fingerprint struct { - Algorithm string // fingerprint algorithm e.g. SHA-512 - Hex string // fingerprint in hexadecimal, with ':' between each octet - Expires int64 // unix time of the fingerprint expiration date + Raw []byte // raw fingerprint data + Algorithm string // fingerprint algorithm e.g. SHA-512 + Expires time.Time // unix time of the fingerprint expiration date } // NewFingerprint returns the SHA-512 fingerprint of the provided raw data. func NewFingerprint(raw []byte, expires time.Time) Fingerprint { sum512 := sha512.Sum512(raw) - var b strings.Builder - for i, f := range sum512 { - if i > 0 { - b.WriteByte(':') - } - fmt.Fprintf(&b, "%02X", f) - } return Fingerprint{ + Raw: sum512[:], Algorithm: "SHA-512", - Hex: b.String(), - Expires: expires.Unix(), + Expires: expires, } }