Move certificate verification code to VerifyConnection

This commit is contained in:
adnano 2020-10-13 16:44:46 -04:00
parent db89a34935
commit 3d1e9e5519
7 changed files with 18 additions and 60 deletions

View File

@ -41,11 +41,11 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) { func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
cert, ok := c.store[hostname] cert, ok := c.store[hostname]
if !ok { if !ok {
return nil, ErrUnknownCertificate return nil, ErrCertificateUnknown
} }
// Ensure that the certificate is not expired // Ensure that the certificate is not expired
if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) { if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) {
return &cert, ErrInvalidCertificate return &cert, ErrCertificateExpired
} }
return &cert, nil return &cert, nil
} }

View File

@ -9,15 +9,13 @@ import (
"net" "net"
"net/url" "net/url"
"strconv" "strconv"
"time"
) )
// Client errors. // Client errors.
var ( var (
ErrInvalidURL = errors.New("gemini: invalid URL") ErrInvalidURL = errors.New("gemini: invalid URL")
ErrInvalidResponse = errors.New("gemini: invalid response") ErrInvalidResponse = errors.New("gemini: invalid response")
ErrInvalidCertificate = errors.New("gemini: invalid certificate") ErrCertificateUnknown = errors.New("gemini: unknown certificate")
ErrUnknownCertificate = errors.New("gemini: unknown certificate")
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted") ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
) )
@ -221,18 +219,9 @@ func (c *Client) Send(req *Request) (*Response, error) {
} }
return &tls.Certificate{}, nil return &tls.Certificate{}, nil
}, },
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error { VerifyConnection: func(cs tls.ConnectionState) error {
// Parse the certificate cert := cs.PeerCertificates[0]
cert, err := x509.ParseCertificate(rawCerts[0]) // Verify the hostname
if err != nil {
return err
}
// Validate the certificate
if !validCertificate(cert) {
return ErrInvalidCertificate
}
// Check that the certificate is valid for the hostname
// Use our own implementation of verifyHostname
if err := verifyHostname(cert, hostname(req.Host)); err != nil { if err := verifyHostname(cert, hostname(req.Host)); err != nil {
return err return err
} }
@ -285,17 +274,3 @@ func (c *Client) Send(req *Request) (*Response, error) {
} }
return resp, nil return resp, nil
} }
// validCertificate determines whether cert is a valid certificate
func validCertificate(cert *x509.Certificate) bool {
// Check notBefore and notAfter
now := time.Now()
if cert.NotBefore.After(now) {
return false
}
if cert.NotAfter.Before(now) {
return false
}
// No need to check hash algorithms, hopefully tls has checked for us already
return true
}

View File

@ -29,7 +29,7 @@ func init() {
// Alert the user that the certificate is not trusted // Alert the user that the certificate is not trusted
fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname) fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname)
fmt.Println("This could indicate a Man-in-the-Middle attack.") fmt.Println("This could indicate a Man-in-the-Middle attack.")
case gmi.ErrUnknownCertificate: case gmi.ErrCertificateUnknown:
// Prompt the user to trust the certificate // Prompt the user to trust the certificate
trust := trustCertificate(cert) trust := trustCertificate(cert)
switch trust { switch trust {

View File

@ -19,7 +19,7 @@ func main() {
cert, err := store.Lookup(hostname) cert, err := store.Lookup(hostname)
if err != nil { if err != nil {
switch err { switch err {
case gmi.ErrInvalidCertificate: case gmi.ErrCertificateExpired:
log.Print("Old certificate expired, creating new one") log.Print("Old certificate expired, creating new one")
// Generate a new certificate if the old one is expired. // Generate a new certificate if the old one is expired.
crt, key, err := gmi.NewRawCertificate(hostname, time.Minute) crt, key, err := gmi.NewRawCertificate(hostname, time.Minute)

View File

@ -19,6 +19,7 @@ import (
// Server errors. // Server errors.
var ( var (
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
ErrCertificateExpired = errors.New("gemini: certificate expired")
) )
// Server is a Gemini server. // Server is a Gemini server.
@ -92,7 +93,7 @@ func (s *Server) ListenAndServe() error {
defer ln.Close() defer ln.Close()
config := &tls.Config{ config := &tls.Config{
InsecureSkipVerify: true, ClientAuth: tls.RequestClientCert,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
GetCertificate: func(h *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.GetCertificate != nil { if s.GetCertificate != nil {
@ -100,7 +101,6 @@ func (s *Server) ListenAndServe() error {
} }
return s.CertificateStore.Lookup(h.ServerName) return s.CertificateStore.Lookup(h.ServerName)
}, },
ClientAuth: tls.RequestClientCert,
} }
tlsListener := tls.NewListener(ln, config) tlsListener := tls.NewListener(ln, config)
return s.Serve(tlsListener) return s.Serve(tlsListener)

View File

@ -7,7 +7,6 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io" "io"
"log"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@ -93,7 +92,6 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
// Certificate is expired // Certificate is expired
continue continue
} }
log.Print(k.hosts[i].Expires, now)
if k.hosts[i].Fingerprint == fingerprint { if k.hosts[i].Fingerprint == fingerprint {
// Fingerprint matches // Fingerprint matches
return nil return nil
@ -101,7 +99,7 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
// Fingerprint does not match // Fingerprint does not match
return ErrCertificateNotTrusted return ErrCertificateNotTrusted
} }
return ErrUnknownCertificate return ErrCertificateUnknown
} }
// Parse parses the provided reader and adds the parsed known hosts to the list. // Parse parses the provided reader and adds the parsed known hosts to the list.

View File

@ -10,37 +10,22 @@ package gmi
import ( import (
"crypto/x509" "crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"net" "net"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
) )
var ( var oidExtensionSubjectAltName = []int{2, 5, 29, 17}
oidExtensionSubjectAltName = []int{2, 5, 29, 17}
)
// oidNotInExtensions reports whether an extension with the given oid exists in func hasSANExtension(c *x509.Certificate) bool {
// extensions. for _, e := range c.Extensions {
func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool { if e.Id.Equal(oidExtensionSubjectAltName) {
for _, e := range extensions {
if e.Id.Equal(oid) {
return true return true
} }
} }
return false return false
} }
func hasSANExtension(c *x509.Certificate) bool {
return oidInExtensions(oidExtensionSubjectAltName, c.Extensions)
}
// ignoreCN disables interpreting Common Name as a hostname. See issue 24151.
// NOTE: This is set to false so that certificates with common names will still
// be supported.
var ignoreCN = false
func validHostnamePattern(host string) bool { return validHostname(host, true) } func validHostnamePattern(host string) bool { return validHostname(host, true) }
func validHostnameInput(host string) bool { return validHostname(host, false) } func validHostnameInput(host string) bool { return validHostname(host, false) }
@ -100,7 +85,7 @@ func validHostname(host string, isPattern bool) bool {
// constraints if there is no risk the CN would be matched as a hostname. // constraints if there is no risk the CN would be matched as a hostname.
// See NameConstraintsWithoutSANs and issue 24151. // See NameConstraintsWithoutSANs and issue 24151.
func commonNameAsHostname(c *x509.Certificate) bool { func commonNameAsHostname(c *x509.Certificate) bool {
return !ignoreCN && !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName) return !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName)
} }
func matchExactly(hostA, hostB string) bool { func matchExactly(hostA, hostB string) bool {