Refactor KnownHosts
This commit is contained in:
parent
53390dad6b
commit
b76080c863
11
cert.go
11
cert.go
@ -52,16 +52,9 @@ func (c *CertificateStore) Add(scope string, cert tls.Certificate) error {
|
||||
}
|
||||
|
||||
// Lookup returns the certificate for the given scope.
|
||||
func (c *CertificateStore) Lookup(scope string) (*tls.Certificate, error) {
|
||||
func (c *CertificateStore) Lookup(scope string) (tls.Certificate, bool) {
|
||||
cert, ok := c.store[scope]
|
||||
if !ok {
|
||||
return nil, ErrCertificateNotFound
|
||||
}
|
||||
// Ensure that the certificate is not expired
|
||||
if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) {
|
||||
return &cert, ErrCertificateExpired
|
||||
}
|
||||
return &cert, nil
|
||||
return cert, ok
|
||||
}
|
||||
|
||||
// Load loads certificates from the given path.
|
||||
|
32
client.go
32
client.go
@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"path"
|
||||
@ -164,7 +165,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
}
|
||||
} else if len(via) > 5 {
|
||||
// Default policy of no more than 5 redirects
|
||||
return resp, ErrTooManyRedirects
|
||||
return resp, errors.New("gemini: too many redirects")
|
||||
}
|
||||
return c.do(redirect, via)
|
||||
}
|
||||
@ -182,13 +183,14 @@ func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
|
||||
// Search recursively for the certificate
|
||||
scope := req.URL.Hostname() + strings.TrimSuffix(req.URL.Path, "/")
|
||||
for {
|
||||
cert, err := c.Certificates.Lookup(scope)
|
||||
if err == nil {
|
||||
cert, ok := c.Certificates.Lookup(scope)
|
||||
if ok {
|
||||
// Ensure that the certificate is not expired
|
||||
if cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) {
|
||||
// Store the certificate
|
||||
req.Certificate = cert
|
||||
return cert, err
|
||||
req.Certificate = &cert
|
||||
return &cert, nil
|
||||
}
|
||||
if err == ErrCertificateExpired {
|
||||
break
|
||||
}
|
||||
scope = path.Dir(scope)
|
||||
@ -216,9 +218,17 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||
return nil
|
||||
}
|
||||
// Check the known hosts
|
||||
err := c.KnownHosts.Lookup(hostname, cert)
|
||||
switch err {
|
||||
case ErrCertificateExpired, ErrCertificateNotFound:
|
||||
knownHost, ok := c.KnownHosts.Lookup(hostname)
|
||||
if ok && time.Now().After(cert.NotAfter) {
|
||||
// Not expired
|
||||
fingerprint := NewFingerprint(cert)
|
||||
if knownHost.Hex != fingerprint.Hex {
|
||||
return errors.New("gemini: fingerprint does not match")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unknown certificate
|
||||
// See if the client trusts the certificate
|
||||
if c.TrustCertificate != nil {
|
||||
switch c.TrustCertificate(hostname, cert) {
|
||||
@ -230,7 +240,5 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrCertificateNotTrusted
|
||||
}
|
||||
return err
|
||||
return errors.New("gemini: certificate not trusted")
|
||||
}
|
||||
|
@ -33,7 +33,8 @@ func init() {
|
||||
client.Timeout = 30 * time.Second
|
||||
client.KnownHosts.LoadDefault()
|
||||
client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust {
|
||||
fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert))
|
||||
fingerprint := gemini.NewFingerprint(cert)
|
||||
fmt.Printf(trustPrompt, hostname, fingerprint.Hex)
|
||||
scanner.Scan()
|
||||
switch scanner.Text() {
|
||||
case "t":
|
||||
|
@ -11,11 +11,7 @@ var crlf = []byte("\r\n")
|
||||
var (
|
||||
ErrInvalidURL = errors.New("gemini: invalid URL")
|
||||
ErrInvalidResponse = errors.New("gemini: invalid response")
|
||||
ErrCertificateExpired = errors.New("gemini: certificate expired")
|
||||
ErrCertificateNotFound = errors.New("gemini: certificate not found")
|
||||
ErrCertificateNotTrusted = errors.New("gemini: certificate not trusted")
|
||||
ErrBodyNotAllowed = errors.New("gemini: response body not allowed")
|
||||
ErrTooManyRedirects = errors.New("gemini: too many redirects")
|
||||
)
|
||||
|
||||
// defaultClient is the default client. It is used by Get and Do.
|
||||
|
13
server.go
13
server.go
@ -3,6 +3,7 @@ package gemini
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
@ -150,12 +151,12 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
|
||||
|
||||
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
|
||||
if _, ok := s.hosts[hostname]; !ok {
|
||||
return nil, ErrCertificateNotFound
|
||||
return nil, errors.New("hostname not registered")
|
||||
}
|
||||
cert, err := s.Certificates.Lookup(hostname)
|
||||
|
||||
switch err {
|
||||
case ErrCertificateNotFound, ErrCertificateExpired:
|
||||
// Generate a new certificate if it is missing or expired
|
||||
cert, ok := s.Certificates.Lookup(hostname)
|
||||
if !ok || cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) {
|
||||
if s.CreateCertificate != nil {
|
||||
cert, err := s.CreateCertificate(hostname)
|
||||
if err == nil {
|
||||
@ -165,9 +166,9 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
|
||||
}
|
||||
return &cert, err
|
||||
}
|
||||
return nil, errors.New("no certificate")
|
||||
}
|
||||
|
||||
return cert, err
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// respond responds to a connection.
|
||||
|
77
tofu.go
77
tofu.go
@ -8,9 +8,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Trust represents the trustworthiness of a certificate.
|
||||
@ -25,7 +23,7 @@ const (
|
||||
// KnownHosts represents a list of known hosts.
|
||||
// The zero value for KnownHosts is an empty list ready to use.
|
||||
type KnownHosts struct {
|
||||
hosts map[string]certInfo
|
||||
hosts map[string]Fingerprint
|
||||
file *os.File
|
||||
}
|
||||
|
||||
@ -80,53 +78,34 @@ func (k *KnownHosts) AddTemporary(hostname string, cert *x509.Certificate) {
|
||||
|
||||
func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) {
|
||||
if k.hosts == nil {
|
||||
k.hosts = map[string]certInfo{}
|
||||
k.hosts = map[string]Fingerprint{}
|
||||
}
|
||||
info := certInfo{
|
||||
Algorithm: "SHA-512",
|
||||
Fingerprint: Fingerprint(cert),
|
||||
Expires: cert.NotAfter.Unix(),
|
||||
}
|
||||
k.hosts[hostname] = info
|
||||
fingerprint := NewFingerprint(cert)
|
||||
k.hosts[hostname] = fingerprint
|
||||
// Append to the file
|
||||
if write && k.file != nil {
|
||||
appendKnownHost(k.file, hostname, info)
|
||||
appendKnownHost(k.file, hostname, fingerprint)
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup looks for the provided certificate in the list of known hosts.
|
||||
// If the hostname is not in the list, Lookup returns ErrCertificateNotFound.
|
||||
// If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted.
|
||||
// Otherwise, Lookup returns nil.
|
||||
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
|
||||
now := time.Now().Unix()
|
||||
fingerprint := Fingerprint(cert)
|
||||
if c, ok := k.hosts[hostname]; ok {
|
||||
if c.Expires <= now {
|
||||
// Certificate is expired
|
||||
return ErrCertificateExpired
|
||||
}
|
||||
if c.Fingerprint != fingerprint {
|
||||
// Fingerprint does not match
|
||||
return ErrCertificateNotTrusted
|
||||
}
|
||||
// Certificate is found
|
||||
return nil
|
||||
}
|
||||
return ErrCertificateNotFound
|
||||
// Lookup returns the fingerprint of the certificate corresponding to
|
||||
// the given hostname.
|
||||
func (k *KnownHosts) Lookup(hostname string) (Fingerprint, bool) {
|
||||
c, ok := k.hosts[hostname]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
||||
// Invalid lines are ignored.
|
||||
func (k *KnownHosts) Parse(r io.Reader) {
|
||||
if k.hosts == nil {
|
||||
k.hosts = map[string]certInfo{}
|
||||
k.hosts = map[string]Fingerprint{}
|
||||
}
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
parts := strings.Split(text, " ")
|
||||
if len(parts) < 4 {
|
||||
if len(parts) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -136,15 +115,10 @@ func (k *KnownHosts) Parse(r io.Reader) {
|
||||
continue
|
||||
}
|
||||
fingerprint := parts[2]
|
||||
expires, err := strconv.ParseInt(parts[3], 10, 0)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
k.hosts[hostname] = certInfo{
|
||||
k.hosts[hostname] = Fingerprint{
|
||||
Algorithm: algorithm,
|
||||
Fingerprint: fingerprint,
|
||||
Expires: expires,
|
||||
Hex: fingerprint,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -156,18 +130,18 @@ func (k *KnownHosts) Write(w io.Writer) {
|
||||
}
|
||||
}
|
||||
|
||||
type certInfo struct {
|
||||
func appendKnownHost(w io.Writer, hostname string, f Fingerprint) (int, error) {
|
||||
return fmt.Fprintf(w, "%s %s %s\n", hostname, f.Algorithm, f.Hex)
|
||||
}
|
||||
|
||||
// Fingerprint represents a fingerprint using a certain algorithm.
|
||||
type Fingerprint struct {
|
||||
Algorithm string // fingerprint algorithm e.g. SHA-512
|
||||
Fingerprint string // fingerprint in hexadecimal, with ':' between each octet
|
||||
Expires int64 // unix time of certificate notAfter date
|
||||
Hex string // fingerprint in hexadecimal, with ':' between each octet
|
||||
}
|
||||
|
||||
func appendKnownHost(w io.Writer, hostname string, c certInfo) (int, error) {
|
||||
return fmt.Fprintf(w, "%s %s %s %d\n", hostname, c.Algorithm, c.Fingerprint, c.Expires)
|
||||
}
|
||||
|
||||
// Fingerprint returns the SHA-512 fingerprint of the provided certificate.
|
||||
func Fingerprint(cert *x509.Certificate) string {
|
||||
// NewFingerprint returns the SHA-512 fingerprint of the provided certificate.
|
||||
func NewFingerprint(cert *x509.Certificate) Fingerprint {
|
||||
sum512 := sha512.Sum512(cert.Raw)
|
||||
var b strings.Builder
|
||||
for i, f := range sum512 {
|
||||
@ -176,7 +150,10 @@ func Fingerprint(cert *x509.Certificate) string {
|
||||
}
|
||||
fmt.Fprintf(&b, "%02X", f)
|
||||
}
|
||||
return b.String()
|
||||
return Fingerprint{
|
||||
Algorithm: "SHA-512",
|
||||
Hex: b.String(),
|
||||
}
|
||||
}
|
||||
|
||||
// defaultKnownHostsPath returns the default known_hosts path.
|
||||
|
Loading…
Reference in New Issue
Block a user