Refactor client.TrustCertificate workflow

This commit is contained in:
Adnan Maolood 2020-10-31 22:34:51 -04:00
parent b9cb7fe71d
commit 65a5065250
6 changed files with 81 additions and 108 deletions

View File

@ -16,7 +16,7 @@ import (
"time" "time"
) )
// CertificateStore maps hostnames to certificates. // CertificateStore maps certificate scopes to certificates.
// The zero value of CertificateStore is an empty store ready to use. // The zero value of CertificateStore is an empty store ready to use.
type CertificateStore struct { type CertificateStore struct {
store map[string]tls.Certificate store map[string]tls.Certificate

View File

@ -43,12 +43,10 @@ type Client struct {
// the request will not be sent again and the response will be returned. // the request will not be sent again and the response will be returned.
CreateCertificate func(hostname, path string) (tls.Certificate, error) CreateCertificate func(hostname, path string) (tls.Certificate, error)
// TrustCertificate determines whether the client should trust // TrustCertificate is called to determine whether the client
// the provided certificate. // should trust a certificate it has not seen before.
// If the returned error is not nil, the connection will be aborted. // If TrustCertificate is nil, the certificate will not be trusted.
// If TrustCertificate is nil, the client will check KnownHosts TrustCertificate func(hostname string, cert *x509.Certificate) Trust
// for the certificate.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
} }
// Get performs a Gemini request for the given url. // Get performs a Gemini request for the given url.
@ -198,10 +196,25 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
if err := verifyHostname(cert, hostname); err != nil { if err := verifyHostname(cert, hostname); err != nil {
return err return err
} }
// Check that the client trusts the certificate // Check the known hosts
var err error err := c.KnownHosts.Lookup(hostname, cert)
switch err {
case ErrCertificateExpired, ErrCertificateNotFound:
default:
return err
}
// See if the client trusts the certificate
if c.TrustCertificate != nil { if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert, &c.KnownHosts) switch c.TrustCertificate(hostname, cert) {
case TrustOnce:
c.KnownHosts.AddTemporary(hostname, cert)
return nil
case TrustAlways:
c.KnownHosts.Add(hostname, cert)
return nil
default:
return ErrCertificateNotTrusted
}
} else { } else {
err = c.KnownHosts.Lookup(hostname, cert) err = c.KnownHosts.Lookup(hostname, cert)
} }

View File

@ -8,12 +8,22 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"os" "os"
"time" "time"
"git.sr.ht/~adnano/go-gemini" "git.sr.ht/~adnano/go-gemini"
) )
const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is:
%s
If you knew the fingerprint to expect in advance, verify that this matches.
Otherwise, this should be safe to trust.
[t]rust always; trust [o]nce; [a]bort
=> `
var ( var (
scanner = bufio.NewScanner(os.Stdin) scanner = bufio.NewScanner(os.Stdin)
client = &gemini.Client{} client = &gemini.Client{}
@ -22,30 +32,17 @@ var (
func init() { func init() {
client.Timeout = 2 * time.Minute client.Timeout = 2 * time.Minute
client.KnownHosts.LoadDefault() client.KnownHosts.LoadDefault()
client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust {
err := knownHosts.Lookup(hostname, cert) fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert))
if err != nil { scanner.Scan()
switch err { switch scanner.Text() {
case gemini.ErrCertificateNotTrusted: case "t":
// Alert the user that the certificate is not trusted return gemini.TrustAlways
fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname) case "o":
fmt.Println("This could indicate a Man-in-the-Middle attack.") return gemini.TrustOnce
case gemini.ErrCertificateUnknown: default:
// Prompt the user to trust the certificate return gemini.TrustNone
trust := trustCertificate(cert)
switch trust {
case trustOnce:
// Temporarily trust the certificate
knownHosts.AddTemporary(hostname, cert)
return nil
case trustAlways:
// Add the certificate to the known hosts file
knownHosts.Add(hostname, cert)
return nil
}
}
} }
return err
} }
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) { client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
fmt.Println("Generating client certificate for", hostname, path) fmt.Println("Generating client certificate for", hostname, path)
@ -60,54 +57,6 @@ func init() {
} }
} }
func doRequest(req *gemini.Request) error {
resp, err := client.Do(req)
if err != nil {
return err
}
if resp.Status.Class() == gemini.StatusClassSuccess {
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
fmt.Print(string(body))
return nil
}
return fmt.Errorf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
}
type trust int
const (
trustAbort trust = iota
trustOnce
trustAlways
)
const trustPrompt = `The certificate offered by this server is of unknown trust. Its fingerprint is:
%s
If you knew the fingerprint to expect in advance, verify that this matches.
Otherwise, this should be safe to trust.
[t]rust always; trust [o]nce; [a]bort
=> `
func trustCertificate(cert *x509.Certificate) trust {
fmt.Printf(trustPrompt, gemini.Fingerprint(cert))
scanner.Scan()
switch scanner.Text() {
case "t":
return trustAlways
case "o":
return trustOnce
default:
return trustAbort
}
}
func main() { func main() {
if len(os.Args) < 2 { if len(os.Args) < 2 {
fmt.Printf("usage: %s gemini://... [host]", os.Args[0]) fmt.Printf("usage: %s gemini://... [host]", os.Args[0])
@ -125,8 +74,19 @@ func main() {
req.Host = os.Args[2] req.Host = os.Args[2]
} }
if err := doRequest(req); err != nil { resp, err := client.Do(req)
fmt.Println(err) if err != nil {
os.Exit(1) log.Fatal(err)
}
if resp.Status.Class() == gemini.StatusClassSuccess {
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
fmt.Print(string(body))
} else {
log.Fatalf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
} }
} }

View File

@ -1,11 +1,8 @@
package gemini package gemini
import ( import (
"crypto/tls"
"crypto/x509"
"errors" "errors"
"sync" "sync"
"time"
) )
var crlf = []byte("\r\n") var crlf = []byte("\r\n")
@ -14,16 +11,15 @@ var crlf = []byte("\r\n")
var ( var (
ErrInvalidURL = errors.New("gemini: invalid URL") ErrInvalidURL = errors.New("gemini: invalid URL")
ErrInvalidResponse = errors.New("gemini: invalid response") ErrInvalidResponse = errors.New("gemini: invalid response")
ErrCertificateUnknown = errors.New("gemini: unknown certificate")
ErrCertificateExpired = errors.New("gemini: certificate expired") ErrCertificateExpired = errors.New("gemini: certificate expired")
ErrCertificateNotFound = errors.New("gemini: certificate not found")
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted") ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
ErrCertificateRequired = errors.New("gemini: certificate required")
ErrNotAFile = errors.New("gemini: not a file") ErrNotAFile = errors.New("gemini: not a file")
ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL") ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL")
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
ErrTooManyRedirects = errors.New("gemini: too many redirects") ErrTooManyRedirects = errors.New("gemini: too many redirects")
ErrInputRequired = errors.New("gemini: input required") ErrInputRequired = errors.New("gemini: input required")
ErrCertificateRequired = errors.New("gemini: certificate required")
ErrCertificateNotFound = errors.New("gemini: certificate not found")
) )
// DefaultClient is the default client. It is used by Get and Do. // DefaultClient is the default client. It is used by Get and Do.
@ -35,6 +31,7 @@ var DefaultClient Client
// //
// Get is a wrapper around DefaultClient.Get. // Get is a wrapper around DefaultClient.Get.
func Get(url string) (*Response, error) { func Get(url string) (*Response, error) {
setupDefaultClientOnce()
return DefaultClient.Get(url) return DefaultClient.Get(url)
} }
@ -42,19 +39,14 @@ func Get(url string) (*Response, error) {
// //
// Do is a wrapper around DefaultClient.Do. // Do is a wrapper around DefaultClient.Do.
func Do(req *Request) (*Response, error) { func Do(req *Request) (*Response, error) {
setupDefaultClientOnce()
return DefaultClient.Do(req) return DefaultClient.Do(req)
} }
var defaultClientOnce sync.Once var defaultClientOnce sync.Once
func init() { func setupDefaultClientOnce() {
DefaultClient.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error { defaultClientOnce.Do(func() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() }) DefaultClient.KnownHosts.LoadDefault()
return knownHosts.Lookup(hostname, cert) })
}
DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
return CreateCertificate(CertificateOptions{
Duration: time.Hour,
})
}
} }

View File

@ -146,7 +146,7 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
if _, ok := s.hosts[hostname]; !ok { if _, ok := s.hosts[hostname]; !ok {
return nil, ErrCertificateUnknown return nil, ErrCertificateNotFound
} }
cert, err := s.Certificates.Lookup(hostname) cert, err := s.Certificates.Lookup(hostname)

22
tofu.go
View File

@ -13,6 +13,15 @@ import (
"time" "time"
) )
// Trust represents the trustworthiness of a certificate.
type Trust int
const (
TrustNone Trust = iota // The certificate is not trusted.
TrustOnce // The certificate is trusted once.
TrustAlways // The certificate is trusted always.
)
// KnownHosts represents a list of known hosts. // KnownHosts represents a list of known hosts.
// The zero value for KnownHosts is an empty list ready to use. // The zero value for KnownHosts is an empty list ready to use.
type KnownHosts struct { type KnownHosts struct {
@ -86,26 +95,25 @@ func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) {
} }
// Lookup looks for the provided certificate in the list of known hosts. // Lookup looks for the provided certificate in the list of known hosts.
// If the hostname is in the list, but the fingerprint differs, // If the hostname is not in the list, Lookup returns ErrCertificateNotFound.
// Lookup returns ErrCertificateNotTrusted. // If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted.
// If the hostname is not in the list, Lookup returns ErrCertificateUnknown. // Otherwise, Lookup returns nil.
// If the certificate is found and the fingerprint matches, error will be nil.
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
now := time.Now().Unix() now := time.Now().Unix()
fingerprint := Fingerprint(cert) fingerprint := Fingerprint(cert)
if c, ok := k.hosts[hostname]; ok { if c, ok := k.hosts[hostname]; ok {
if c.Expires <= now { if c.Expires <= now {
// Certificate is expired // Certificate is expired
return ErrCertificateUnknown return ErrCertificateExpired
} }
if c.Fingerprint != fingerprint { if c.Fingerprint != fingerprint {
// Fingerprint does not match // Fingerprint does not match
return ErrCertificateNotTrusted return ErrCertificateNotTrusted
} }
// Certificate is trusted // Certificate is found
return nil return nil
} }
return ErrCertificateUnknown return ErrCertificateNotFound
} }
// Parse parses the provided reader and adds the parsed known hosts to the list. // Parse parses the provided reader and adds the parsed known hosts to the list.