Move certificate verification code to VerifyConnection
This commit is contained in:
parent
db89a34935
commit
3d1e9e5519
4
cert.go
4
cert.go
@ -41,11 +41,11 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
|
||||
func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
|
||||
cert, ok := c.store[hostname]
|
||||
if !ok {
|
||||
return nil, ErrUnknownCertificate
|
||||
return nil, ErrCertificateUnknown
|
||||
}
|
||||
// Ensure that the certificate is not expired
|
||||
if cert.Leaf != nil && cert.Leaf.NotAfter.Before(time.Now()) {
|
||||
return &cert, ErrInvalidCertificate
|
||||
return &cert, ErrCertificateExpired
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
33
client.go
33
client.go
@ -9,15 +9,13 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Client errors.
|
||||
var (
|
||||
ErrInvalidURL = errors.New("gemini: invalid URL")
|
||||
ErrInvalidResponse = errors.New("gemini: invalid response")
|
||||
ErrInvalidCertificate = errors.New("gemini: invalid certificate")
|
||||
ErrUnknownCertificate = errors.New("gemini: unknown certificate")
|
||||
ErrCertificateUnknown = errors.New("gemini: unknown certificate")
|
||||
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
|
||||
)
|
||||
|
||||
@ -221,18 +219,9 @@ func (c *Client) Send(req *Request) (*Response, error) {
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
},
|
||||
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
// Parse the certificate
|
||||
cert, err := x509.ParseCertificate(rawCerts[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Validate the certificate
|
||||
if !validCertificate(cert) {
|
||||
return ErrInvalidCertificate
|
||||
}
|
||||
// Check that the certificate is valid for the hostname
|
||||
// Use our own implementation of verifyHostname
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
cert := cs.PeerCertificates[0]
|
||||
// Verify the hostname
|
||||
if err := verifyHostname(cert, hostname(req.Host)); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -285,17 +274,3 @@ func (c *Client) Send(req *Request) (*Response, error) {
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// validCertificate determines whether cert is a valid certificate
|
||||
func validCertificate(cert *x509.Certificate) bool {
|
||||
// Check notBefore and notAfter
|
||||
now := time.Now()
|
||||
if cert.NotBefore.After(now) {
|
||||
return false
|
||||
}
|
||||
if cert.NotAfter.Before(now) {
|
||||
return false
|
||||
}
|
||||
// No need to check hash algorithms, hopefully tls has checked for us already
|
||||
return true
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ func init() {
|
||||
// Alert the user that the certificate is not trusted
|
||||
fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname)
|
||||
fmt.Println("This could indicate a Man-in-the-Middle attack.")
|
||||
case gmi.ErrUnknownCertificate:
|
||||
case gmi.ErrCertificateUnknown:
|
||||
// Prompt the user to trust the certificate
|
||||
trust := trustCertificate(cert)
|
||||
switch trust {
|
||||
|
@ -19,7 +19,7 @@ func main() {
|
||||
cert, err := store.Lookup(hostname)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gmi.ErrInvalidCertificate:
|
||||
case gmi.ErrCertificateExpired:
|
||||
log.Print("Old certificate expired, creating new one")
|
||||
// Generate a new certificate if the old one is expired.
|
||||
crt, key, err := gmi.NewRawCertificate(hostname, time.Minute)
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
// Server errors.
|
||||
var (
|
||||
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
|
||||
ErrCertificateExpired = errors.New("gemini: certificate expired")
|
||||
)
|
||||
|
||||
// Server is a Gemini server.
|
||||
@ -92,7 +93,7 @@ func (s *Server) ListenAndServe() error {
|
||||
defer ln.Close()
|
||||
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetCertificate: func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.GetCertificate != nil {
|
||||
@ -100,7 +101,6 @@ func (s *Server) ListenAndServe() error {
|
||||
}
|
||||
return s.CertificateStore.Lookup(h.ServerName)
|
||||
},
|
||||
ClientAuth: tls.RequestClientCert,
|
||||
}
|
||||
tlsListener := tls.NewListener(ln, config)
|
||||
return s.Serve(tlsListener)
|
||||
|
4
tofu.go
4
tofu.go
@ -7,7 +7,6 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@ -93,7 +92,6 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
|
||||
// Certificate is expired
|
||||
continue
|
||||
}
|
||||
log.Print(k.hosts[i].Expires, now)
|
||||
if k.hosts[i].Fingerprint == fingerprint {
|
||||
// Fingerprint matches
|
||||
return nil
|
||||
@ -101,7 +99,7 @@ func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
|
||||
// Fingerprint does not match
|
||||
return ErrCertificateNotTrusted
|
||||
}
|
||||
return ErrUnknownCertificate
|
||||
return ErrCertificateUnknown
|
||||
}
|
||||
|
||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
||||
|
25
vendor.go
25
vendor.go
@ -10,37 +10,22 @@ package gmi
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"net"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var (
|
||||
oidExtensionSubjectAltName = []int{2, 5, 29, 17}
|
||||
)
|
||||
var oidExtensionSubjectAltName = []int{2, 5, 29, 17}
|
||||
|
||||
// oidNotInExtensions reports whether an extension with the given oid exists in
|
||||
// extensions.
|
||||
func oidInExtensions(oid asn1.ObjectIdentifier, extensions []pkix.Extension) bool {
|
||||
for _, e := range extensions {
|
||||
if e.Id.Equal(oid) {
|
||||
func hasSANExtension(c *x509.Certificate) bool {
|
||||
for _, e := range c.Extensions {
|
||||
if e.Id.Equal(oidExtensionSubjectAltName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasSANExtension(c *x509.Certificate) bool {
|
||||
return oidInExtensions(oidExtensionSubjectAltName, c.Extensions)
|
||||
}
|
||||
|
||||
// ignoreCN disables interpreting Common Name as a hostname. See issue 24151.
|
||||
// NOTE: This is set to false so that certificates with common names will still
|
||||
// be supported.
|
||||
var ignoreCN = false
|
||||
|
||||
func validHostnamePattern(host string) bool { return validHostname(host, true) }
|
||||
func validHostnameInput(host string) bool { return validHostname(host, false) }
|
||||
|
||||
@ -100,7 +85,7 @@ func validHostname(host string, isPattern bool) bool {
|
||||
// constraints if there is no risk the CN would be matched as a hostname.
|
||||
// See NameConstraintsWithoutSANs and issue 24151.
|
||||
func commonNameAsHostname(c *x509.Certificate) bool {
|
||||
return !ignoreCN && !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName)
|
||||
return !hasSANExtension(c) && validHostnamePattern(c.Subject.CommonName)
|
||||
}
|
||||
|
||||
func matchExactly(hostA, hostB string) bool {
|
||||
|
Loading…
Reference in New Issue
Block a user