diff --git a/cert.go b/cert.go index c7cf04b..25ca342 100644 --- a/cert.go +++ b/cert.go @@ -16,7 +16,7 @@ import ( "time" ) -// CertificateStore maps hostnames to certificates. +// CertificateStore maps certificate scopes to certificates. // The zero value of CertificateStore is an empty store ready to use. type CertificateStore struct { store map[string]tls.Certificate diff --git a/client.go b/client.go index 4e6b926..4efd95f 100644 --- a/client.go +++ b/client.go @@ -43,12 +43,10 @@ type Client struct { // the request will not be sent again and the response will be returned. CreateCertificate func(hostname, path string) (tls.Certificate, error) - // TrustCertificate determines whether the client should trust - // the provided certificate. - // If the returned error is not nil, the connection will be aborted. - // If TrustCertificate is nil, the client will check KnownHosts - // for the certificate. - TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error + // TrustCertificate is called to determine whether the client + // should trust a certificate it has not seen before. + // If TrustCertificate is nil, the certificate will not be trusted. + TrustCertificate func(hostname string, cert *x509.Certificate) Trust } // Get performs a Gemini request for the given url. @@ -198,10 +196,25 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { if err := verifyHostname(cert, hostname); err != nil { return err } - // Check that the client trusts the certificate - var err error + // Check the known hosts + err := c.KnownHosts.Lookup(hostname, cert) + switch err { + case ErrCertificateExpired, ErrCertificateNotFound: + default: + return err + } + // See if the client trusts the certificate if c.TrustCertificate != nil { - return c.TrustCertificate(hostname, cert, &c.KnownHosts) + switch c.TrustCertificate(hostname, cert) { + case TrustOnce: + c.KnownHosts.AddTemporary(hostname, cert) + return nil + case TrustAlways: + c.KnownHosts.Add(hostname, cert) + return nil + default: + return ErrCertificateNotTrusted + } } else { err = c.KnownHosts.Lookup(hostname, cert) } diff --git a/examples/client.go b/examples/client.go index 1fea082..4c9b172 100644 --- a/examples/client.go +++ b/examples/client.go @@ -8,12 +8,22 @@ import ( "crypto/x509" "fmt" "io/ioutil" + "log" "os" "time" "git.sr.ht/~adnano/go-gemini" ) +const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is: +%s + +If you knew the fingerprint to expect in advance, verify that this matches. +Otherwise, this should be safe to trust. + +[t]rust always; trust [o]nce; [a]bort +=> ` + var ( scanner = bufio.NewScanner(os.Stdin) client = &gemini.Client{} @@ -22,30 +32,17 @@ var ( func init() { client.Timeout = 2 * time.Minute client.KnownHosts.LoadDefault() - client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { - err := knownHosts.Lookup(hostname, cert) - if err != nil { - switch err { - case gemini.ErrCertificateNotTrusted: - // Alert the user that the certificate is not trusted - fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname) - fmt.Println("This could indicate a Man-in-the-Middle attack.") - case gemini.ErrCertificateUnknown: - // Prompt the user to trust the certificate - trust := trustCertificate(cert) - switch trust { - case trustOnce: - // Temporarily trust the certificate - knownHosts.AddTemporary(hostname, cert) - return nil - case trustAlways: - // Add the certificate to the known hosts file - knownHosts.Add(hostname, cert) - return nil - } - } + client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust { + fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert)) + scanner.Scan() + switch scanner.Text() { + case "t": + return gemini.TrustAlways + case "o": + return gemini.TrustOnce + default: + return gemini.TrustNone } - return err } client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) { fmt.Println("Generating client certificate for", hostname, path) @@ -60,54 +57,6 @@ func init() { } } -func doRequest(req *gemini.Request) error { - resp, err := client.Do(req) - if err != nil { - return err - } - - if resp.Status.Class() == gemini.StatusClassSuccess { - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return err - } - fmt.Print(string(body)) - return nil - } - return fmt.Errorf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta) -} - -type trust int - -const ( - trustAbort trust = iota - trustOnce - trustAlways -) - -const trustPrompt = `The certificate offered by this server is of unknown trust. Its fingerprint is: -%s - -If you knew the fingerprint to expect in advance, verify that this matches. -Otherwise, this should be safe to trust. - -[t]rust always; trust [o]nce; [a]bort -=> ` - -func trustCertificate(cert *x509.Certificate) trust { - fmt.Printf(trustPrompt, gemini.Fingerprint(cert)) - scanner.Scan() - switch scanner.Text() { - case "t": - return trustAlways - case "o": - return trustOnce - default: - return trustAbort - } -} - func main() { if len(os.Args) < 2 { fmt.Printf("usage: %s gemini://... [host]", os.Args[0]) @@ -125,8 +74,19 @@ func main() { req.Host = os.Args[2] } - if err := doRequest(req); err != nil { - fmt.Println(err) - os.Exit(1) + resp, err := client.Do(req) + if err != nil { + log.Fatal(err) + } + + if resp.Status.Class() == gemini.StatusClassSuccess { + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatal(err) + } + fmt.Print(string(body)) + } else { + log.Fatalf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta) } } diff --git a/gemini.go b/gemini.go index 7015536..b03400d 100644 --- a/gemini.go +++ b/gemini.go @@ -1,11 +1,8 @@ package gemini import ( - "crypto/tls" - "crypto/x509" "errors" "sync" - "time" ) var crlf = []byte("\r\n") @@ -14,16 +11,15 @@ var crlf = []byte("\r\n") var ( ErrInvalidURL = errors.New("gemini: invalid URL") ErrInvalidResponse = errors.New("gemini: invalid response") - ErrCertificateUnknown = errors.New("gemini: unknown certificate") ErrCertificateExpired = errors.New("gemini: certificate expired") + ErrCertificateNotFound = errors.New("gemini: certificate not found") ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted") + ErrCertificateRequired = errors.New("gemini: certificate required") ErrNotAFile = errors.New("gemini: not a file") ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL") ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") ErrTooManyRedirects = errors.New("gemini: too many redirects") ErrInputRequired = errors.New("gemini: input required") - ErrCertificateRequired = errors.New("gemini: certificate required") - ErrCertificateNotFound = errors.New("gemini: certificate not found") ) // DefaultClient is the default client. It is used by Get and Do. @@ -35,6 +31,7 @@ var DefaultClient Client // // Get is a wrapper around DefaultClient.Get. func Get(url string) (*Response, error) { + setupDefaultClientOnce() return DefaultClient.Get(url) } @@ -42,19 +39,14 @@ func Get(url string) (*Response, error) { // // Do is a wrapper around DefaultClient.Do. func Do(req *Request) (*Response, error) { + setupDefaultClientOnce() return DefaultClient.Do(req) } var defaultClientOnce sync.Once -func init() { - DefaultClient.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error { - defaultClientOnce.Do(func() { knownHosts.LoadDefault() }) - return knownHosts.Lookup(hostname, cert) - } - DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) { - return CreateCertificate(CertificateOptions{ - Duration: time.Hour, - }) - } +func setupDefaultClientOnce() { + defaultClientOnce.Do(func() { + DefaultClient.KnownHosts.LoadDefault() + }) } diff --git a/server.go b/server.go index 4fbd73a..9a46bac 100644 --- a/server.go +++ b/server.go @@ -146,7 +146,7 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { if _, ok := s.hosts[hostname]; !ok { - return nil, ErrCertificateUnknown + return nil, ErrCertificateNotFound } cert, err := s.Certificates.Lookup(hostname) diff --git a/tofu.go b/tofu.go index 73aead0..05a8624 100644 --- a/tofu.go +++ b/tofu.go @@ -13,6 +13,15 @@ import ( "time" ) +// Trust represents the trustworthiness of a certificate. +type Trust int + +const ( + TrustNone Trust = iota // The certificate is not trusted. + TrustOnce // The certificate is trusted once. + TrustAlways // The certificate is trusted always. +) + // KnownHosts represents a list of known hosts. // The zero value for KnownHosts is an empty list ready to use. type KnownHosts struct { @@ -86,26 +95,25 @@ func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) { } // Lookup looks for the provided certificate in the list of known hosts. -// If the hostname is in the list, but the fingerprint differs, -// Lookup returns ErrCertificateNotTrusted. -// If the hostname is not in the list, Lookup returns ErrCertificateUnknown. -// If the certificate is found and the fingerprint matches, error will be nil. +// If the hostname is not in the list, Lookup returns ErrCertificateNotFound. +// If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted. +// Otherwise, Lookup returns nil. func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { now := time.Now().Unix() fingerprint := Fingerprint(cert) if c, ok := k.hosts[hostname]; ok { if c.Expires <= now { // Certificate is expired - return ErrCertificateUnknown + return ErrCertificateExpired } if c.Fingerprint != fingerprint { // Fingerprint does not match return ErrCertificateNotTrusted } - // Certificate is trusted + // Certificate is found return nil } - return ErrCertificateUnknown + return ErrCertificateNotFound } // Parse parses the provided reader and adds the parsed known hosts to the list.