Refactor client certificates

This commit is contained in:
Adnan Maolood 2020-10-28 13:40:25 -04:00
parent 768664e0c5
commit fbd97a62de
6 changed files with 140 additions and 155 deletions

75
cert.go
View File

@ -20,9 +20,9 @@ type CertificateStore struct {
store map[string]tls.Certificate store map[string]tls.Certificate
} }
// Add adds a certificate for the given hostname to the store. // Add adds a certificate for the given scope to the store.
// It tries to parse the certificate if it is not already parsed. // It tries to parse the certificate if it is not already parsed.
func (c *CertificateStore) Add(hostname string, cert tls.Certificate) { func (c *CertificateStore) Add(scope string, cert tls.Certificate) {
if c.store == nil { if c.store == nil {
c.store = map[string]tls.Certificate{} c.store = map[string]tls.Certificate{}
} }
@ -33,7 +33,7 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
cert.Leaf = parsed cert.Leaf = parsed
} }
} }
c.store[hostname] = cert c.store[scope] = cert
} }
// Lookup returns the certificate for the given hostname. // Lookup returns the certificate for the given hostname.
@ -49,6 +49,22 @@ func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
return &cert, nil return &cert, nil
} }
// lookup returns the certificate for the given hostname + path.
func (c *CertificateStore) lookup(scope string) (*tls.Certificate, error) {
for {
cert, err := c.Lookup(scope)
switch err {
case ErrCertificateExpired, nil:
return cert, err
}
scope = path.Dir(scope)
if scope == "." {
break
}
}
return nil, ErrCertificateUnknown
}
// Load loads certificates from the given path. // Load loads certificates from the given path.
// The path should lead to a directory containing certificates and private keys // The path should lead to a directory containing certificates and private keys
// in the form hostname.crt and hostname.key. // in the form hostname.crt and hostname.key.
@ -71,36 +87,16 @@ func (c *CertificateStore) Load(path string) error {
return nil return nil
} }
type ClientCertificateStore struct { // CertificateOptions configures how a certificate is created.
CertificateStore type CertificateOptions struct {
IPAddresses []net.IP
DNSNames []string
Duration time.Duration
} }
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) { // CreateCertificate creates a new TLS certificate.
urlPath = path.Clean(urlPath) func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
if urlPath == "." { crt, priv, err := newX509KeyPair(options)
urlPath = "/"
}
if urlPath[0] != '/' {
urlPath = "/" + urlPath
}
for {
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
switch err {
case ErrCertificateExpired, nil:
return cert, err
}
slash := urlPath == "/"
urlPath = path.Dir(urlPath)
if slash && urlPath == "/" {
break
}
}
return nil, ErrCertificateUnknown
}
// NewCertificate creates and returns a new parsed certificate.
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(host, duration)
if err != nil { if err != nil {
return tls.Certificate{}, err return tls.Certificate{}, err
} }
@ -112,7 +108,7 @@ func NewCertificate(host string, duration time.Duration) (tls.Certificate, error
} }
// newX509KeyPair creates and returns a new certificate and private key. // newX509KeyPair creates and returns a new certificate and private key.
func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, crypto.PrivateKey, error) { func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
// Generate an ED25519 private key // Generate an ED25519 private key
_, priv, err := ed25519.GenerateKey(rand.Reader) _, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil { if err != nil {
@ -131,7 +127,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
} }
notBefore := time.Now() notBefore := time.Now()
notAfter := notBefore.Add(duration) notAfter := notBefore.Add(options.Duration)
template := x509.Certificate{ template := x509.Certificate{
SerialNumber: serialNumber, SerialNumber: serialNumber,
@ -140,17 +136,8 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
KeyUsage: keyUsage, KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true, BasicConstraintsValid: true,
} IPAddresses: options.IPAddresses,
DNSNames: options.DNSNames,
if host != "" {
hosts := strings.Split(host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
} }
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv) crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)

151
client.go
View File

@ -6,37 +6,38 @@ import (
"crypto/x509" "crypto/x509"
"net" "net"
"net/url" "net/url"
"strings"
) )
// Client represents a Gemini client. // Client is a Gemini client.
type Client struct { type Client struct {
// KnownHosts is a list of known hosts that the client trusts. // KnownHosts is a list of known hosts.
KnownHosts KnownHosts KnownHosts KnownHosts
// CertificateStore maps hostnames to certificates. // Certificates stores client-side certificates.
// It is used to determine which certificate to use when the server requests Certificates CertificateStore
// a certificate.
CertificateStore ClientCertificateStore
// CheckRedirect, if not nil, will be called to determine whether // GetInput is called to retrieve input when the server requests it.
// to follow a redirect. // If GetInput is nil or returns false, no input will be sent and
// the response will be returned.
GetInput func(prompt string, sensitive bool) (input string, ok bool)
// CheckRedirect determines whether to follow a redirect.
// If CheckRedirect is nil, a default policy of no more than 5 consecutive // If CheckRedirect is nil, a default policy of no more than 5 consecutive
// redirects will be enforced. // redirects will be enforced.
CheckRedirect func(req *Request, via []*Request) error CheckRedirect func(req *Request, via []*Request) error
// GetInput, if not nil, will be called to retrieve input when the server // CreateCertificate is called to generate a certificate upon
// requests it. // the request of a server.
GetInput func(prompt string, sensitive bool) (string, bool) // If CreateCertificate is nil or the returned error is not nil,
// the request will not be sent again and the response will be returned.
CreateCertificate func(hostname, path string) (tls.Certificate, error)
// GetCertificate, if not nil, will be called when a server requests a certificate. // TrustCertificate determines whether the client should trust
// The returned certificate will be used when sending the request again. // the provided certificate.
// If the certificate is nil, the request will not be sent again and // If the returned error is not nil, the connection will be aborted.
// the response will be returned. // If TrustCertificate is nil, the client will check KnownHosts
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate // for the certificate.
// TrustCertificate, if not nil, will be called to determine whether the
// client should trust the given certificate.
// If error is not nil, the connection will be aborted.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
} }
@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
config := &tls.Config{ config := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Request certificates take precedence over client certificates return c.getClientCertificate(req)
if req.Certificate != nil {
return req.Certificate, nil
}
// If we have already stored the certificate, return it
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
return cert, nil
}
return &tls.Certificate{}, nil
}, },
VerifyConnection: func(cs tls.ConnectionState) error { VerifyConnection: func(cs tls.ConnectionState) error {
cert := cs.PeerCertificates[0] return c.verifyConnection(req, cs)
// Verify the hostname
if err := verifyHostname(cert, hostname(req.Host)); err != nil {
return err
}
// Check that the client trusts the certificate
if c.TrustCertificate == nil {
if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
return err
}
} else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
return err
}
return nil
}, },
} }
conn, err := tls.Dial("tcp", req.Host, config) conn, err := tls.Dial("tcp", req.Host, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO: Set connection deadline
// Write the request // Write the request
w := bufio.NewWriter(conn) w := bufio.NewWriter(conn)
@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
if err := resp.read(conn); err != nil { if err := resp.read(conn); err != nil {
return nil, err return nil, err
} }
// Store connection information // Store connection state
resp.TLS = conn.ConnectionState() resp.TLS = conn.ConnectionState()
// Resend the request with a certificate if the server responded switch {
// with CertificateRequired case resp.Status == StatusCertificateRequired:
if resp.Status == StatusCertificateRequired {
// Check to see if a certificate was already provided to prevent an infinite loop // Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil { if req.Certificate != nil {
return resp, nil return resp, nil
} }
if c.GetCertificate != nil {
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil { hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
req.Certificate = cert if c.CreateCertificate != nil {
return c.Do(req) cert, err := c.CreateCertificate(hostname, path)
if err != nil {
return resp, err
} }
c.Certificates.Add(hostname+path, cert)
return c.do(req, via)
} }
return resp, ErrCertificateRequired return resp, ErrCertificateRequired
} else if resp.Status.Class() == StatusClassRedirect {
case resp.Status.Class() == StatusClassInput:
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
return resp, ErrInputRequired
case resp.Status.Class() == StatusClassRedirect:
if via == nil { if via == nil {
via = []*Request{} via = []*Request{}
} }
@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, ErrTooManyRedirects return resp, ErrTooManyRedirects
} }
return c.do(redirect, via) return c.do(redirect, via)
} else if resp.Status.Class() == StatusClassInput {
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
return resp, ErrInputRequired
} }
resp.Request = req resp.Request = req
return resp, nil return resp, nil
} }
// hostname returns the host without the port. func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
func hostname(host string) string { // Request certificates have the highest precedence
hostname, _, err := net.SplitHostPort(host) if req.Certificate != nil {
if err != nil { return req.Certificate, nil
return host
} }
return hostname hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if cert, err := c.Certificates.lookup(hostname + path); err == nil {
// Remember the certificate used
req.Certificate = cert
return cert, nil
}
return &tls.Certificate{}, nil
}
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
// Verify the hostname
var hostname string
if host, _, err := net.SplitHostPort(req.Host); err == nil {
hostname = host
} else {
hostname = req.Host
}
cert := cs.PeerCertificates[0]
if err := verifyHostname(cert, hostname); err != nil {
return err
}
// Check that the client trusts the certificate
var err error
if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
} else {
err = c.KnownHosts.Lookup(hostname, cert)
}
return err
} }

View File

@ -46,21 +46,11 @@ func init() {
} }
return err return err
} }
client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate { client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
// If the certificate is in the store, return it fmt.Println("Generating client certificate for", hostname, path)
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil { return gemini.CreateCertificate(gemini.CertificateOptions{
return cert Duration: time.Hour,
} })
// Otherwise, generate a certificate
fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
duration := time.Hour
cert, err := gemini.NewCertificate("", duration)
if err != nil {
return nil
}
// Store and return the certificate
store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
} }
client.GetInput = func(prompt string, sensitive bool) (string, bool) { client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt) fmt.Printf("%s: ", prompt)
@ -69,8 +59,7 @@ func init() {
} }
} }
// sendRequest sends a request to the given URL. func doRequest(req *gemini.Request) error {
func sendRequest(req *gemini.Request) error {
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return err return err
@ -131,7 +120,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if err := sendRequest(req); err != nil { if err := doRequest(req); err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }

View File

@ -29,7 +29,10 @@ func main() {
fallthrough fallthrough
case gmi.ErrCertificateUnknown: case gmi.ErrCertificateUnknown:
// Generate a certificate if one does not exist. // Generate a certificate if one does not exist.
cert, err := gmi.NewCertificate(hostname, time.Minute) cert, err := gmi.CreateCertificate(gmi.CertificateOptions{
DNSNames: []string{hostname},
Duration: time.Hour,
})
if err != nil { if err != nil {
// Failed to generate new certificate, abort // Failed to generate new certificate, abort
return nil return nil

View File

@ -51,16 +51,9 @@ func init() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() }) defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert) return knownHosts.Lookup(hostname, cert)
} }
DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate { DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil { return CreateCertificate(CertificateOptions{
return cert Duration: time.Hour,
} })
duration := time.Hour
cert, err := NewCertificate("", duration)
if err != nil {
return nil
}
store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
} }
} }

View File

@ -24,6 +24,18 @@ const (
StatusCertificateNotValid Status = 62 StatusCertificateNotValid Status = 62
) )
// Status code categories.
type StatusClass int
const (
StatusClassInput StatusClass = 1
StatusClassSuccess StatusClass = 2
StatusClassRedirect StatusClass = 3
StatusClassTemporaryFailure StatusClass = 4
StatusClassPermanentFailure StatusClass = 5
StatusClassCertificateRequired StatusClass = 6
)
// Class returns the status class for this status code. // Class returns the status class for this status code.
func (s Status) Class() StatusClass { func (s Status) Class() StatusClass {
return StatusClass(s / 10) return StatusClass(s / 10)
@ -71,15 +83,3 @@ func (s Status) Message() string {
} }
return "" return ""
} }
// Status code categories.
type StatusClass int
const (
StatusClassInput StatusClass = 1
StatusClassSuccess StatusClass = 2
StatusClassRedirect StatusClass = 3
StatusClassTemporaryFailure StatusClass = 4
StatusClassPermanentFailure StatusClass = 5
StatusClassCertificateRequired StatusClass = 6
)