Refactor client certificates
This commit is contained in:
151
client.go
151
client.go
@@ -6,37 +6,38 @@ import (
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Client represents a Gemini client.
|
||||
// Client is a Gemini client.
|
||||
type Client struct {
|
||||
// KnownHosts is a list of known hosts that the client trusts.
|
||||
// KnownHosts is a list of known hosts.
|
||||
KnownHosts KnownHosts
|
||||
|
||||
// CertificateStore maps hostnames to certificates.
|
||||
// It is used to determine which certificate to use when the server requests
|
||||
// a certificate.
|
||||
CertificateStore ClientCertificateStore
|
||||
// Certificates stores client-side certificates.
|
||||
Certificates CertificateStore
|
||||
|
||||
// CheckRedirect, if not nil, will be called to determine whether
|
||||
// to follow a redirect.
|
||||
// GetInput is called to retrieve input when the server requests it.
|
||||
// If GetInput is nil or returns false, no input will be sent and
|
||||
// the response will be returned.
|
||||
GetInput func(prompt string, sensitive bool) (input string, ok bool)
|
||||
|
||||
// CheckRedirect determines whether to follow a redirect.
|
||||
// If CheckRedirect is nil, a default policy of no more than 5 consecutive
|
||||
// redirects will be enforced.
|
||||
CheckRedirect func(req *Request, via []*Request) error
|
||||
|
||||
// GetInput, if not nil, will be called to retrieve input when the server
|
||||
// requests it.
|
||||
GetInput func(prompt string, sensitive bool) (string, bool)
|
||||
// CreateCertificate is called to generate a certificate upon
|
||||
// the request of a server.
|
||||
// If CreateCertificate is nil or the returned error is not nil,
|
||||
// the request will not be sent again and the response will be returned.
|
||||
CreateCertificate func(hostname, path string) (tls.Certificate, error)
|
||||
|
||||
// GetCertificate, if not nil, will be called when a server requests a certificate.
|
||||
// The returned certificate will be used when sending the request again.
|
||||
// If the certificate is nil, the request will not be sent again and
|
||||
// the response will be returned.
|
||||
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
|
||||
|
||||
// TrustCertificate, if not nil, will be called to determine whether the
|
||||
// client should trust the given certificate.
|
||||
// If error is not nil, the connection will be aborted.
|
||||
// TrustCertificate determines whether the client should trust
|
||||
// the provided certificate.
|
||||
// If the returned error is not nil, the connection will be aborted.
|
||||
// If TrustCertificate is nil, the client will check KnownHosts
|
||||
// for the certificate.
|
||||
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
|
||||
}
|
||||
|
||||
@@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
// Request certificates take precedence over client certificates
|
||||
if req.Certificate != nil {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
// If we have already stored the certificate, return it
|
||||
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
|
||||
return cert, nil
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return c.getClientCertificate(req)
|
||||
},
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
cert := cs.PeerCertificates[0]
|
||||
// Verify the hostname
|
||||
if err := verifyHostname(cert, hostname(req.Host)); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the client trusts the certificate
|
||||
if c.TrustCertificate == nil {
|
||||
if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return c.verifyConnection(req, cs)
|
||||
},
|
||||
}
|
||||
conn, err := tls.Dial("tcp", req.Host, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: Set connection deadline
|
||||
|
||||
// Write the request
|
||||
w := bufio.NewWriter(conn)
|
||||
@@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
if err := resp.read(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Store connection information
|
||||
// Store connection state
|
||||
resp.TLS = conn.ConnectionState()
|
||||
|
||||
// Resend the request with a certificate if the server responded
|
||||
// with CertificateRequired
|
||||
if resp.Status == StatusCertificateRequired {
|
||||
switch {
|
||||
case resp.Status == StatusCertificateRequired:
|
||||
// Check to see if a certificate was already provided to prevent an infinite loop
|
||||
if req.Certificate != nil {
|
||||
return resp, nil
|
||||
}
|
||||
if c.GetCertificate != nil {
|
||||
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
|
||||
req.Certificate = cert
|
||||
return c.Do(req)
|
||||
|
||||
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
||||
if c.CreateCertificate != nil {
|
||||
cert, err := c.CreateCertificate(hostname, path)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
c.Certificates.Add(hostname+path, cert)
|
||||
return c.do(req, via)
|
||||
}
|
||||
return resp, ErrCertificateRequired
|
||||
} else if resp.Status.Class() == StatusClassRedirect {
|
||||
|
||||
case resp.Status.Class() == StatusClassInput:
|
||||
if c.GetInput != nil {
|
||||
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
|
||||
if ok {
|
||||
req.URL.ForceQuery = true
|
||||
req.URL.RawQuery = url.QueryEscape(input)
|
||||
return c.do(req, via)
|
||||
}
|
||||
}
|
||||
return resp, ErrInputRequired
|
||||
|
||||
case resp.Status.Class() == StatusClassRedirect:
|
||||
if via == nil {
|
||||
via = []*Request{}
|
||||
}
|
||||
@@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
return resp, ErrTooManyRedirects
|
||||
}
|
||||
return c.do(redirect, via)
|
||||
} else if resp.Status.Class() == StatusClassInput {
|
||||
if c.GetInput != nil {
|
||||
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
|
||||
if ok {
|
||||
req.URL.ForceQuery = true
|
||||
req.URL.RawQuery = url.QueryEscape(input)
|
||||
return c.do(req, via)
|
||||
}
|
||||
}
|
||||
return resp, ErrInputRequired
|
||||
}
|
||||
|
||||
resp.Request = req
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// hostname returns the host without the port.
|
||||
func hostname(host string) string {
|
||||
hostname, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return host
|
||||
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
|
||||
// Request certificates have the highest precedence
|
||||
if req.Certificate != nil {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
return hostname
|
||||
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
||||
if cert, err := c.Certificates.lookup(hostname + path); err == nil {
|
||||
// Remember the certificate used
|
||||
req.Certificate = cert
|
||||
return cert, nil
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
}
|
||||
|
||||
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||
// Verify the hostname
|
||||
var hostname string
|
||||
if host, _, err := net.SplitHostPort(req.Host); err == nil {
|
||||
hostname = host
|
||||
} else {
|
||||
hostname = req.Host
|
||||
}
|
||||
cert := cs.PeerCertificates[0]
|
||||
if err := verifyHostname(cert, hostname); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the client trusts the certificate
|
||||
var err error
|
||||
if c.TrustCertificate != nil {
|
||||
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
|
||||
} else {
|
||||
err = c.KnownHosts.Lookup(hostname, cert)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user