Move certificate verification code to VerifyConnection
This commit is contained in:
		
							parent
							
								
									db89a34935
								
							
						
					
					
						commit
						3d1e9e5519
					
				
							
								
								
									
										4
									
								
								cert.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								cert.go
									
									
									
									
									
								
							@ -41,11 +41,11 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
 | 
			
		||||
func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
 | 
			
		||||
	cert, ok := c.store[hostname]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil, ErrUnknownCertificate
 | 
			
		||||
		return nil, ErrCertificateUnknown
 | 
			
		||||
	}
 | 
			
		||||
	// Ensure that the certificate is not expired
 | 
			
		||||
	if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) {
 | 
			
		||||
		return &cert, ErrInvalidCertificate
 | 
			
		||||
		return &cert, ErrCertificateExpired
 | 
			
		||||
	}
 | 
			
		||||
	return &cert, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										33
									
								
								client.go
									
									
									
									
									
								
							
							
						
						
									
										33
									
								
								client.go
									
									
									
									
									
								
							@ -9,15 +9,13 @@ import (
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Client errors.
 | 
			
		||||
var (
 | 
			
		||||
	ErrInvalidURL            = errors.New("gemini: invalid URL")
 | 
			
		||||
	ErrInvalidResponse       = errors.New("gemini: invalid response")
 | 
			
		||||
	ErrInvalidCertificate    = errors.New("gemini: invalid certificate")
 | 
			
		||||
	ErrUnknownCertificate    = errors.New("gemini: unknown certificate")
 | 
			
		||||
	ErrCertificateUnknown    = errors.New("gemini: unknown certificate")
 | 
			
		||||
	ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -221,18 +219,9 @@ func (c *Client) Send(req *Request) (*Response, error) {
 | 
			
		||||
			}
 | 
			
		||||
			return &tls.Certificate{}, nil
 | 
			
		||||
		},
 | 
			
		||||
		VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
 | 
			
		||||
			// Parse the certificate
 | 
			
		||||
			cert, err := x509.ParseCertificate(rawCerts[0])
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			// Validate the certificate
 | 
			
		||||
			if !validCertificate(cert) {
 | 
			
		||||
				return ErrInvalidCertificate
 | 
			
		||||
			}
 | 
			
		||||
			// Check that the certificate is valid for the hostname
 | 
			
		||||
			// Use our own implementation of verifyHostname
 | 
			
		||||
		VerifyConnection: func(cs tls.ConnectionState) error {
 | 
			
		||||
			cert := cs.PeerCertificates[0]
 | 
			
		||||
			// Verify the hostname
 | 
			
		||||
			if err := verifyHostname(cert, hostname(req.Host)); err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
@ -285,17 +274,3 @@ func (c *Client) Send(req *Request) (*Response, error) {
 | 
			
		||||
	}
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// validCertificate determines whether cert is a valid certificate
 | 
			
		||||
func validCertificate(cert *x509.Certificate) bool {
 | 
			
		||||
	// Check notBefore and notAfter
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	if cert.NotBefore.After(now) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if cert.NotAfter.Before(now) {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	// No need to check hash algorithms, hopefully tls has checked for us already
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -29,7 +29,7 @@ func init() {
 | 
			
		||||
				// Alert the user that the certificate is not trusted
 | 
			
		||||
				fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname)
 | 
			
		||||
				fmt.Println("This could indicate a Man-in-the-Middle attack.")
 | 
			
		||||
			case gmi.ErrUnknownCertificate:
 | 
			
		||||
			case gmi.ErrCertificateUnknown:
 | 
			
		||||
				// Prompt the user to trust the certificate
 | 
			
		||||
				trust := trustCertificate(cert)
 | 
			
		||||
				switch trust {
 | 
			
		||||
 | 
			
		||||
@ -19,7 +19,7 @@ func main() {
 | 
			
		||||
		cert, err := store.Lookup(hostname)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			switch err {
 | 
			
		||||
			case gmi.ErrInvalidCertificate:
 | 
			
		||||
			case gmi.ErrCertificateExpired:
 | 
			
		||||
				log.Print("Old certificate expired, creating new one")
 | 
			
		||||
				// Generate a new certificate if the old one is expired.
 | 
			
		||||
				crt, key, err := gmi.NewRawCertificate(hostname, time.Minute)
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,8 @@ import (
 | 
			
		||||
 | 
			
		||||
// Server errors.
 | 
			
		||||
var (
 | 
			
		||||
	ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
 | 
			
		||||
	ErrBodyNotAllowed     = errors.New("gemini: response status code does not allow for body")
 | 
			
		||||
	ErrCertificateExpired = errors.New("gemini: certificate expired")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Server is a Gemini server.
 | 
			
		||||
@ -92,15 +93,14 @@ func (s *Server) ListenAndServe() error {
 | 
			
		||||
	defer ln.Close()
 | 
			
		||||
 | 
			
		||||
	config := &tls.Config{
 | 
			
		||||
		InsecureSkipVerify: true,
 | 
			
		||||
		MinVersion:         tls.VersionTLS12,
 | 
			
		||||
		ClientAuth: tls.RequestClientCert,
 | 
			
		||||
		MinVersion: tls.VersionTLS12,
 | 
			
		||||
		GetCertificate: func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
 | 
			
		||||
			if s.GetCertificate != nil {
 | 
			
		||||
				return s.GetCertificate(h.ServerName, &s.CertificateStore), nil
 | 
			
		||||
			}
 | 
			
		||||
			return s.CertificateStore.Lookup(h.ServerName)
 | 
			
		||||
		},
 | 
			
		||||
		ClientAuth: tls.RequestClientCert,
 | 
			
		||||
	}
 | 
			
		||||
	tlsListener := tls.NewListener(ln, config)
 | 
			
		||||
	return s.Serve(tlsListener)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										4
									
								
								tofu.go
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								tofu.go
									
									
									
									
									
								
							@ -7,7 +7,6 @@ import (
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@ -93,7 +92,6 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
 | 
			
		||||
			// Certificate is expired
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		log.Print(k.hosts[i].Expires, now)
 | 
			
		||||
		if k.hosts[i].Fingerprint == fingerprint {
 | 
			
		||||
			// Fingerprint matches
 | 
			
		||||
			return nil
 | 
			
		||||
@ -101,7 +99,7 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
 | 
			
		||||
		// Fingerprint does not match
 | 
			
		||||
		return ErrCertificateNotTrusted
 | 
			
		||||
	}
 | 
			
		||||
	return ErrUnknownCertificate
 | 
			
		||||
	return ErrCertificateUnknown
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										25
									
								
								vendor.go
									
									
									
									
									
								
							
							
						
						
									
										25
									
								
								vendor.go
									
									
									
									
									
								
							@ -10,37 +10,22 @@ package gmi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"crypto/x509/pkix"
 | 
			
		||||
	"encoding/asn1"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	oidExtensionSubjectAltName = []int{2, 5, 29, 17}
 | 
			
		||||
)
 | 
			
		||||
var oidExtensionSubjectAltName = []int{2, 5, 29, 17}
 | 
			
		||||
 | 
			
		||||
// oidNotInExtensions reports whether an extension with the given oid exists in
 | 
			
		||||
// extensions.
 | 
			
		||||
func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool {
 | 
			
		||||
	for _, e := range extensions {
 | 
			
		||||
		if e.Id.Equal(oid) {
 | 
			
		||||
func hasSANExtension(c *x509.Certificate) bool {
 | 
			
		||||
	for _, e := range c.Extensions {
 | 
			
		||||
		if e.Id.Equal(oidExtensionSubjectAltName) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func hasSANExtension(c *x509.Certificate) bool {
 | 
			
		||||
	return oidInExtensions(oidExtensionSubjectAltName, c.Extensions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ignoreCN disables interpreting Common Name as a hostname. See issue 24151.
 | 
			
		||||
// NOTE: This is set to false so that certificates with common names will still
 | 
			
		||||
// be supported.
 | 
			
		||||
var ignoreCN = false
 | 
			
		||||
 | 
			
		||||
func validHostnamePattern(host string) bool { return validHostname(host, true) }
 | 
			
		||||
func validHostnameInput(host string) bool   { return validHostname(host, false) }
 | 
			
		||||
 | 
			
		||||
@ -100,7 +85,7 @@ func validHostname(host string, isPattern bool) bool {
 | 
			
		||||
// constraints if there is no risk the CN would be matched as a hostname.
 | 
			
		||||
// See NameConstraintsWithoutSANs and issue 24151.
 | 
			
		||||
func commonNameAsHostname(c *x509.Certificate) bool {
 | 
			
		||||
	return !ignoreCN && !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName)
 | 
			
		||||
	return !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func matchExactly(hostA, hostB string) bool {
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user