diff --git a/cert.go b/cert.go index dad45f3..c8d0321 100644 --- a/cert.go +++ b/cert.go @@ -41,11 +41,11 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) { func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) { cert, ok := c.store[hostname] if !ok { - return nil, ErrUnknownCertificate + return nil, ErrCertificateUnknown } // Ensure that the certificate is not expired if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) { - return &cert, ErrInvalidCertificate + return &cert, ErrCertificateExpired } return &cert, nil } diff --git a/client.go b/client.go index e4155c6..b6e227d 100644 --- a/client.go +++ b/client.go @@ -9,15 +9,13 @@ import ( "net" "net/url" "strconv" - "time" ) // Client errors. var ( ErrInvalidURL = errors.New("gemini: invalid URL") ErrInvalidResponse = errors.New("gemini: invalid response") - ErrInvalidCertificate = errors.New("gemini: invalid certificate") - ErrUnknownCertificate = errors.New("gemini: unknown certificate") + ErrCertificateUnknown = errors.New("gemini: unknown certificate") ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted") ) @@ -221,18 +219,9 @@ func (c *Client) Send(req *Request) (*Response, error) { } return &tls.Certificate{}, nil }, - VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { - // Parse the certificate - cert, err := x509.ParseCertificate(rawCerts[0]) - if err != nil { - return err - } - // Validate the certificate - if !validCertificate(cert) { - return ErrInvalidCertificate - } - // Check that the certificate is valid for the hostname - // Use our own implementation of verifyHostname + VerifyConnection: func(cs tls.ConnectionState) error { + cert := cs.PeerCertificates[0] + // Verify the hostname if err := verifyHostname(cert, hostname(req.Host)); err != nil { return err } @@ -285,17 +274,3 @@ func (c *Client) Send(req *Request) (*Response, error) { } return resp, nil } - -// validCertificate determines whether cert is a valid certificate -func validCertificate(cert *x509.Certificate) bool { - // Check notBefore and notAfter - now := time.Now() - if cert.NotBefore.After(now) { - return false - } - if cert.NotAfter.Before(now) { - return false - } - // No need to check hash algorithms, hopefully tls has checked for us already - return true -} diff --git a/examples/client.go b/examples/client.go index edea321..8d818c3 100644 --- a/examples/client.go +++ b/examples/client.go @@ -29,7 +29,7 @@ func init() { // Alert the user that the certificate is not trusted fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname) fmt.Println("This could indicate a Man-in-the-Middle attack.") - case gmi.ErrUnknownCertificate: + case gmi.ErrCertificateUnknown: // Prompt the user to trust the certificate trust := trustCertificate(cert) switch trust { diff --git a/examples/server.go b/examples/server.go index 2385c95..4feb090 100644 --- a/examples/server.go +++ b/examples/server.go @@ -19,7 +19,7 @@ func main() { cert, err := store.Lookup(hostname) if err != nil { switch err { - case gmi.ErrInvalidCertificate: + case gmi.ErrCertificateExpired: log.Print("Old certificate expired, creating new one") // Generate a new certificate if the old one is expired. crt, key, err := gmi.NewRawCertificate(hostname, time.Minute) diff --git a/server.go b/server.go index 35c3f4b..e6c67c6 100644 --- a/server.go +++ b/server.go @@ -18,7 +18,8 @@ import ( // Server errors. var ( - ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") + ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") + ErrCertificateExpired = errors.New("gemini: certificate expired") ) // Server is a Gemini server. @@ -92,15 +93,14 @@ func (s *Server) ListenAndServe() error { defer ln.Close() config := &tls.Config{ - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS12, + ClientAuth: tls.RequestClientCert, + MinVersion: tls.VersionTLS12, GetCertificate: func(h *tls.ClientHelloInfo) (*tls.Certificate, error) { if s.GetCertificate != nil { return s.GetCertificate(h.ServerName, &s.CertificateStore), nil } return s.CertificateStore.Lookup(h.ServerName) }, - ClientAuth: tls.RequestClientCert, } tlsListener := tls.NewListener(ln, config) return s.Serve(tlsListener) diff --git a/tofu.go b/tofu.go index bcbc41a..8932b2f 100644 --- a/tofu.go +++ b/tofu.go @@ -7,7 +7,6 @@ import ( "crypto/x509" "fmt" "io" - "log" "os" "path/filepath" "strconv" @@ -93,7 +92,6 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { // Certificate is expired continue } - log.Print(k.hosts[i].Expires, now) if k.hosts[i].Fingerprint == fingerprint { // Fingerprint matches return nil @@ -101,7 +99,7 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { // Fingerprint does not match return ErrCertificateNotTrusted } - return ErrUnknownCertificate + return ErrCertificateUnknown } // Parse parses the provided reader and adds the parsed known hosts to the list. diff --git a/vendor.go b/vendor.go index 8a0fd5a..207159a 100644 --- a/vendor.go +++ b/vendor.go @@ -10,37 +10,22 @@ package gmi import ( "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" "net" "strings" "unicode/utf8" ) -var ( - oidExtensionSubjectAltName = []int{2, 5, 29, 17} -) +var oidExtensionSubjectAltName = []int{2, 5, 29, 17} -// oidNotInExtensions reports whether an extension with the given oid exists in -// extensions. -func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool { - for _, e := range extensions { - if e.Id.Equal(oid) { +func hasSANExtension(c *x509.Certificate) bool { + for _, e := range c.Extensions { + if e.Id.Equal(oidExtensionSubjectAltName) { return true } } return false } -func hasSANExtension(c *x509.Certificate) bool { - return oidInExtensions(oidExtensionSubjectAltName, c.Extensions) -} - -// ignoreCN disables interpreting Common Name as a hostname. See issue 24151. -// NOTE: This is set to false so that certificates with common names will still -// be supported. -var ignoreCN = false - func validHostnamePattern(host string) bool { return validHostname(host, true) } func validHostnameInput(host string) bool { return validHostname(host, false) } @@ -100,7 +85,7 @@ func validHostname(host string, isPattern bool) bool { // constraints if there is no risk the CN would be matched as a hostname. // See NameConstraintsWithoutSANs and issue 24151. func commonNameAsHostname(c *x509.Certificate) bool { - return !ignoreCN && !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName) + return !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName) } func matchExactly(hostA, hostB string) bool {