Refactor client.TrustCertificate workflow

This commit is contained in:
Adnan Maolood 2020-10-31 22:34:51 -04:00
parent b9cb7fe71d
commit 65a5065250
6 changed files with 81 additions and 108 deletions

View File

@ -16,7 +16,7 @@ import (
"time"
)
// CertificateStore maps hostnames to certificates.
// CertificateStore maps certificate scopes to certificates.
// The zero value of CertificateStore is an empty store ready to use.
type CertificateStore struct {
store map[string]tls.Certificate

View File

@ -43,12 +43,10 @@ type Client struct {
// the request will not be sent again and the response will be returned.
CreateCertificate func(hostname, path string) (tls.Certificate, error)
// TrustCertificate determines whether the client should trust
// the provided certificate.
// If the returned error is not nil, the connection will be aborted.
// If TrustCertificate is nil, the client will check KnownHosts
// for the certificate.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
// TrustCertificate is called to determine whether the client
// should trust a certificate it has not seen before.
// If TrustCertificate is nil, the certificate will not be trusted.
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
}
// Get performs a Gemini request for the given url.
@ -198,10 +196,25 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
if err := verifyHostname(cert, hostname); err != nil {
return err
}
// Check that the client trusts the certificate
var err error
// Check the known hosts
err := c.KnownHosts.Lookup(hostname, cert)
switch err {
case ErrCertificateExpired, ErrCertificateNotFound:
default:
return err
}
// See if the client trusts the certificate
if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
switch c.TrustCertificate(hostname, cert) {
case TrustOnce:
c.KnownHosts.AddTemporary(hostname, cert)
return nil
case TrustAlways:
c.KnownHosts.Add(hostname, cert)
return nil
default:
return ErrCertificateNotTrusted
}
} else {
err = c.KnownHosts.Lookup(hostname, cert)
}

View File

@ -8,12 +8,22 @@ import (
"crypto/x509"
"fmt"
"io/ioutil"
"log"
"os"
"time"
"git.sr.ht/~adnano/go-gemini"
)
const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is:
%s
If you knew the fingerprint to expect in advance, verify that this matches.
Otherwise, this should be safe to trust.
[t]rust always; trust [o]nce; [a]bort
=> `
var (
scanner = bufio.NewScanner(os.Stdin)
client = &gemini.Client{}
@ -22,30 +32,17 @@ var (
func init() {
client.Timeout = 2 * time.Minute
client.KnownHosts.LoadDefault()
client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error {
err := knownHosts.Lookup(hostname, cert)
if err != nil {
switch err {
case gemini.ErrCertificateNotTrusted:
// Alert the user that the certificate is not trusted
fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname)
fmt.Println("This could indicate a Man-in-the-Middle attack.")
case gemini.ErrCertificateUnknown:
// Prompt the user to trust the certificate
trust := trustCertificate(cert)
switch trust {
case trustOnce:
// Temporarily trust the certificate
knownHosts.AddTemporary(hostname, cert)
return nil
case trustAlways:
// Add the certificate to the known hosts file
knownHosts.Add(hostname, cert)
return nil
}
}
client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust {
fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert))
scanner.Scan()
switch scanner.Text() {
case "t":
return gemini.TrustAlways
case "o":
return gemini.TrustOnce
default:
return gemini.TrustNone
}
return err
}
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
fmt.Println("Generating client certificate for", hostname, path)
@ -60,54 +57,6 @@ func init() {
}
}
func doRequest(req *gemini.Request) error {
resp, err := client.Do(req)
if err != nil {
return err
}
if resp.Status.Class() == gemini.StatusClassSuccess {
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return err
}
fmt.Print(string(body))
return nil
}
return fmt.Errorf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
}
type trust int
const (
trustAbort trust = iota
trustOnce
trustAlways
)
const trustPrompt = `The certificate offered by this server is of unknown trust. Its fingerprint is:
%s
If you knew the fingerprint to expect in advance, verify that this matches.
Otherwise, this should be safe to trust.
[t]rust always; trust [o]nce; [a]bort
=> `
func trustCertificate(cert *x509.Certificate) trust {
fmt.Printf(trustPrompt, gemini.Fingerprint(cert))
scanner.Scan()
switch scanner.Text() {
case "t":
return trustAlways
case "o":
return trustOnce
default:
return trustAbort
}
}
func main() {
if len(os.Args) < 2 {
fmt.Printf("usage: %s gemini://... [host]", os.Args[0])
@ -125,8 +74,19 @@ func main() {
req.Host = os.Args[2]
}
if err := doRequest(req); err != nil {
fmt.Println(err)
os.Exit(1)
resp, err := client.Do(req)
if err != nil {
log.Fatal(err)
}
if resp.Status.Class() == gemini.StatusClassSuccess {
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
log.Fatal(err)
}
fmt.Print(string(body))
} else {
log.Fatalf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
}
}

View File

@ -1,11 +1,8 @@
package gemini
import (
"crypto/tls"
"crypto/x509"
"errors"
"sync"
"time"
)
var crlf = []byte("\r\n")
@ -14,16 +11,15 @@ var crlf = []byte("\r\n")
var (
ErrInvalidURL = errors.New("gemini: invalid URL")
ErrInvalidResponse = errors.New("gemini: invalid response")
ErrCertificateUnknown = errors.New("gemini: unknown certificate")
ErrCertificateExpired = errors.New("gemini: certificate expired")
ErrCertificateNotFound = errors.New("gemini: certificate not found")
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
ErrCertificateRequired = errors.New("gemini: certificate required")
ErrNotAFile = errors.New("gemini: not a file")
ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL")
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
ErrTooManyRedirects = errors.New("gemini: too many redirects")
ErrInputRequired = errors.New("gemini: input required")
ErrCertificateRequired = errors.New("gemini: certificate required")
ErrCertificateNotFound = errors.New("gemini: certificate not found")
)
// DefaultClient is the default client. It is used by Get and Do.
@ -35,6 +31,7 @@ var DefaultClient Client
//
// Get is a wrapper around DefaultClient.Get.
func Get(url string) (*Response, error) {
setupDefaultClientOnce()
return DefaultClient.Get(url)
}
@ -42,19 +39,14 @@ func Get(url string) (*Response, error) {
//
// Do is a wrapper around DefaultClient.Do.
func Do(req *Request) (*Response, error) {
setupDefaultClientOnce()
return DefaultClient.Do(req)
}
var defaultClientOnce sync.Once
func init() {
DefaultClient.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert)
}
DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
return CreateCertificate(CertificateOptions{
Duration: time.Hour,
})
}
func setupDefaultClientOnce() {
defaultClientOnce.Do(func() {
DefaultClient.KnownHosts.LoadDefault()
})
}

View File

@ -146,7 +146,7 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
if _, ok := s.hosts[hostname]; !ok {
return nil, ErrCertificateUnknown
return nil, ErrCertificateNotFound
}
cert, err := s.Certificates.Lookup(hostname)

22
tofu.go
View File

@ -13,6 +13,15 @@ import (
"time"
)
// Trust represents the trustworthiness of a certificate.
type Trust int
const (
TrustNone Trust = iota // The certificate is not trusted.
TrustOnce // The certificate is trusted once.
TrustAlways // The certificate is trusted always.
)
// KnownHosts represents a list of known hosts.
// The zero value for KnownHosts is an empty list ready to use.
type KnownHosts struct {
@ -86,26 +95,25 @@ func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) {
}
// Lookup looks for the provided certificate in the list of known hosts.
// If the hostname is in the list, but the fingerprint differs,
// Lookup returns ErrCertificateNotTrusted.
// If the hostname is not in the list, Lookup returns ErrCertificateUnknown.
// If the certificate is found and the fingerprint matches, error will be nil.
// If the hostname is not in the list, Lookup returns ErrCertificateNotFound.
// If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted.
// Otherwise, Lookup returns nil.
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
now := time.Now().Unix()
fingerprint := Fingerprint(cert)
if c, ok := k.hosts[hostname]; ok {
if c.Expires <= now {
// Certificate is expired
return ErrCertificateUnknown
return ErrCertificateExpired
}
if c.Fingerprint != fingerprint {
// Fingerprint does not match
return ErrCertificateNotTrusted
}
// Certificate is trusted
// Certificate is found
return nil
}
return ErrCertificateUnknown
return ErrCertificateNotFound
}
// Parse parses the provided reader and adds the parsed known hosts to the list.