Refactor client certificates

This commit is contained in:
Adnan Maolood
2020-10-28 13:40:25 -04:00
parent 768664e0c5
commit fbd97a62de
6 changed files with 140 additions and 155 deletions

151
client.go
View File

@@ -6,37 +6,38 @@ import (
"crypto/x509"
"net"
"net/url"
"strings"
)
// Client represents a Gemini client.
// Client is a Gemini client.
type Client struct {
// KnownHosts is a list of known hosts that the client trusts.
// KnownHosts is a list of known hosts.
KnownHosts KnownHosts
// CertificateStore maps hostnames to certificates.
// It is used to determine which certificate to use when the server requests
// a certificate.
CertificateStore ClientCertificateStore
// Certificates stores client-side certificates.
Certificates CertificateStore
// CheckRedirect, if not nil, will be called to determine whether
// to follow a redirect.
// GetInput is called to retrieve input when the server requests it.
// If GetInput is nil or returns false, no input will be sent and
// the response will be returned.
GetInput func(prompt string, sensitive bool) (input string, ok bool)
// CheckRedirect determines whether to follow a redirect.
// If CheckRedirect is nil, a default policy of no more than 5 consecutive
// redirects will be enforced.
CheckRedirect func(req *Request, via []*Request) error
// GetInput, if not nil, will be called to retrieve input when the server
// requests it.
GetInput func(prompt string, sensitive bool) (string, bool)
// CreateCertificate is called to generate a certificate upon
// the request of a server.
// If CreateCertificate is nil or the returned error is not nil,
// the request will not be sent again and the response will be returned.
CreateCertificate func(hostname, path string) (tls.Certificate, error)
// GetCertificate, if not nil, will be called when a server requests a certificate.
// The returned certificate will be used when sending the request again.
// If the certificate is nil, the request will not be sent again and
// the response will be returned.
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
// TrustCertificate, if not nil, will be called to determine whether the
// client should trust the given certificate.
// If error is not nil, the connection will be aborted.
// TrustCertificate determines whether the client should trust
// the provided certificate.
// If the returned error is not nil, the connection will be aborted.
// If TrustCertificate is nil, the client will check KnownHosts
// for the certificate.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
}
@@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
config := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Request certificates take precedence over client certificates
if req.Certificate != nil {
return req.Certificate, nil
}
// If we have already stored the certificate, return it
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
return cert, nil
}
return &tls.Certificate{}, nil
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return c.getClientCertificate(req)
},
VerifyConnection: func(cs tls.ConnectionState) error {
cert := cs.PeerCertificates[0]
// Verify the hostname
if err := verifyHostname(cert, hostname(req.Host)); err != nil {
return err
}
// Check that the client trusts the certificate
if c.TrustCertificate == nil {
if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
return err
}
} else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
return err
}
return nil
return c.verifyConnection(req, cs)
},
}
conn, err := tls.Dial("tcp", req.Host, config)
if err != nil {
return nil, err
}
// TODO: Set connection deadline
// Write the request
w := bufio.NewWriter(conn)
@@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
if err := resp.read(conn); err != nil {
return nil, err
}
// Store connection information
// Store connection state
resp.TLS = conn.ConnectionState()
// Resend the request with a certificate if the server responded
// with CertificateRequired
if resp.Status == StatusCertificateRequired {
switch {
case resp.Status == StatusCertificateRequired:
// Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil {
return resp, nil
}
if c.GetCertificate != nil {
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
req.Certificate = cert
return c.Do(req)
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if c.CreateCertificate != nil {
cert, err := c.CreateCertificate(hostname, path)
if err != nil {
return resp, err
}
c.Certificates.Add(hostname+path, cert)
return c.do(req, via)
}
return resp, ErrCertificateRequired
} else if resp.Status.Class() == StatusClassRedirect {
case resp.Status.Class() == StatusClassInput:
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
return resp, ErrInputRequired
case resp.Status.Class() == StatusClassRedirect:
if via == nil {
via = []*Request{}
}
@@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, ErrTooManyRedirects
}
return c.do(redirect, via)
} else if resp.Status.Class() == StatusClassInput {
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
return resp, ErrInputRequired
}
resp.Request = req
return resp, nil
}
// hostname returns the host without the port.
func hostname(host string) string {
hostname, _, err := net.SplitHostPort(host)
if err != nil {
return host
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
// Request certificates have the highest precedence
if req.Certificate != nil {
return req.Certificate, nil
}
return hostname
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if cert, err := c.Certificates.lookup(hostname + path); err == nil {
// Remember the certificate used
req.Certificate = cert
return cert, nil
}
return &tls.Certificate{}, nil
}
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
// Verify the hostname
var hostname string
if host, _, err := net.SplitHostPort(req.Host); err == nil {
hostname = host
} else {
hostname = req.Host
}
cert := cs.PeerCertificates[0]
if err := verifyHostname(cert, hostname); err != nil {
return err
}
// Check that the client trusts the certificate
var err error
if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
} else {
err = c.KnownHosts.Lookup(hostname, cert)
}
return err
}