Refactor client certificates

This commit is contained in:
Adnan Maolood 2020-10-28 13:40:25 -04:00
parent 768664e0c5
commit fbd97a62de
6 changed files with 140 additions and 155 deletions

75
cert.go
View File

@ -20,9 +20,9 @@ type CertificateStore struct {
store map[string]tls.Certificate
}
// Add adds a certificate for the given hostname to the store.
// Add adds a certificate for the given scope to the store.
// It tries to parse the certificate if it is not already parsed.
func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
func (c *CertificateStore) Add(scope string, cert tls.Certificate) {
if c.store == nil {
c.store = map[string]tls.Certificate{}
}
@ -33,7 +33,7 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) {
cert.Leaf = parsed
}
}
c.store[hostname] = cert
c.store[scope] = cert
}
// Lookup returns the certificate for the given hostname.
@ -49,6 +49,22 @@ func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) {
return &cert, nil
}
// lookup returns the certificate for the given hostname + path.
func (c *CertificateStore) lookup(scope string) (*tls.Certificate, error) {
for {
cert, err := c.Lookup(scope)
switch err {
case ErrCertificateExpired, nil:
return cert, err
}
scope = path.Dir(scope)
if scope == "." {
break
}
}
return nil, ErrCertificateUnknown
}
// Load loads certificates from the given path.
// The path should lead to a directory containing certificates and private keys
// in the form hostname.crt and hostname.key.
@ -71,36 +87,16 @@ func (c *CertificateStore) Load(path string) error {
return nil
}
type ClientCertificateStore struct {
CertificateStore
// CertificateOptions configures how a certificate is created.
type CertificateOptions struct {
IPAddresses []net.IP
DNSNames []string
Duration time.Duration
}
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
urlPath = path.Clean(urlPath)
if urlPath == "." {
urlPath = "/"
}
if urlPath[0] != '/' {
urlPath = "/" + urlPath
}
for {
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
switch err {
case ErrCertificateExpired, nil:
return cert, err
}
slash := urlPath == "/"
urlPath = path.Dir(urlPath)
if slash && urlPath == "/" {
break
}
}
return nil, ErrCertificateUnknown
}
// NewCertificate creates and returns a new parsed certificate.
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(host, duration)
// CreateCertificate creates a new TLS certificate.
func CreateCertificate(options CertificateOptions) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(options)
if err != nil {
return tls.Certificate{}, err
}
@ -112,7 +108,7 @@ func NewCertificate(host string, duration time.Duration) (tls.Certificate, error
}
// newX509KeyPair creates and returns a new certificate and private key.
func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, crypto.PrivateKey, error) {
func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) {
// Generate an ED25519 private key
_, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
@ -131,7 +127,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
}
notBefore := time.Now()
notAfter := notBefore.Add(duration)
notAfter := notBefore.Add(options.Duration)
template := x509.Certificate{
SerialNumber: serialNumber,
@ -140,17 +136,8 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
KeyUsage: keyUsage,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
if host != "" {
hosts := strings.Split(host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
IPAddresses: options.IPAddresses,
DNSNames: options.DNSNames,
}
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)

151
client.go
View File

@ -6,37 +6,38 @@ import (
"crypto/x509"
"net"
"net/url"
"strings"
)
// Client represents a Gemini client.
// Client is a Gemini client.
type Client struct {
// KnownHosts is a list of known hosts that the client trusts.
// KnownHosts is a list of known hosts.
KnownHosts KnownHosts
// CertificateStore maps hostnames to certificates.
// It is used to determine which certificate to use when the server requests
// a certificate.
CertificateStore ClientCertificateStore
// Certificates stores client-side certificates.
Certificates CertificateStore
// CheckRedirect, if not nil, will be called to determine whether
// to follow a redirect.
// GetInput is called to retrieve input when the server requests it.
// If GetInput is nil or returns false, no input will be sent and
// the response will be returned.
GetInput func(prompt string, sensitive bool) (input string, ok bool)
// CheckRedirect determines whether to follow a redirect.
// If CheckRedirect is nil, a default policy of no more than 5 consecutive
// redirects will be enforced.
CheckRedirect func(req *Request, via []*Request) error
// GetInput, if not nil, will be called to retrieve input when the server
// requests it.
GetInput func(prompt string, sensitive bool) (string, bool)
// CreateCertificate is called to generate a certificate upon
// the request of a server.
// If CreateCertificate is nil or the returned error is not nil,
// the request will not be sent again and the response will be returned.
CreateCertificate func(hostname, path string) (tls.Certificate, error)
// GetCertificate, if not nil, will be called when a server requests a certificate.
// The returned certificate will be used when sending the request again.
// If the certificate is nil, the request will not be sent again and
// the response will be returned.
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
// TrustCertificate, if not nil, will be called to determine whether the
// client should trust the given certificate.
// If error is not nil, the connection will be aborted.
// TrustCertificate determines whether the client should trust
// the provided certificate.
// If the returned error is not nil, the connection will be aborted.
// If TrustCertificate is nil, the client will check KnownHosts
// for the certificate.
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
}
@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
config := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12,
GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
// Request certificates take precedence over client certificates
if req.Certificate != nil {
return req.Certificate, nil
}
// If we have already stored the certificate, return it
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
return cert, nil
}
return &tls.Certificate{}, nil
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return c.getClientCertificate(req)
},
VerifyConnection: func(cs tls.ConnectionState) error {
cert := cs.PeerCertificates[0]
// Verify the hostname
if err := verifyHostname(cert, hostname(req.Host)); err != nil {
return err
}
// Check that the client trusts the certificate
if c.TrustCertificate == nil {
if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
return err
}
} else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
return err
}
return nil
return c.verifyConnection(req, cs)
},
}
conn, err := tls.Dial("tcp", req.Host, config)
if err != nil {
return nil, err
}
// TODO: Set connection deadline
// Write the request
w := bufio.NewWriter(conn)
@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
if err := resp.read(conn); err != nil {
return nil, err
}
// Store connection information
// Store connection state
resp.TLS = conn.ConnectionState()
// Resend the request with a certificate if the server responded
// with CertificateRequired
if resp.Status == StatusCertificateRequired {
switch {
case resp.Status == StatusCertificateRequired:
// Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil {
return resp, nil
}
if c.GetCertificate != nil {
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
req.Certificate = cert
return c.Do(req)
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if c.CreateCertificate != nil {
cert, err := c.CreateCertificate(hostname, path)
if err != nil {
return resp, err
}
c.Certificates.Add(hostname+path, cert)
return c.do(req, via)
}
return resp, ErrCertificateRequired
} else if resp.Status.Class() == StatusClassRedirect {
case resp.Status.Class() == StatusClassInput:
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
return resp, ErrInputRequired
case resp.Status.Class() == StatusClassRedirect:
if via == nil {
via = []*Request{}
}
@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, ErrTooManyRedirects
}
return c.do(redirect, via)
} else if resp.Status.Class() == StatusClassInput {
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = url.QueryEscape(input)
return c.do(req, via)
}
}
return resp, ErrInputRequired
}
resp.Request = req
return resp, nil
}
// hostname returns the host without the port.
func hostname(host string) string {
hostname, _, err := net.SplitHostPort(host)
if err != nil {
return host
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
// Request certificates have the highest precedence
if req.Certificate != nil {
return req.Certificate, nil
}
return hostname
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if cert, err := c.Certificates.lookup(hostname + path); err == nil {
// Remember the certificate used
req.Certificate = cert
return cert, nil
}
return &tls.Certificate{}, nil
}
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
// Verify the hostname
var hostname string
if host, _, err := net.SplitHostPort(req.Host); err == nil {
hostname = host
} else {
hostname = req.Host
}
cert := cs.PeerCertificates[0]
if err := verifyHostname(cert, hostname); err != nil {
return err
}
// Check that the client trusts the certificate
var err error
if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
} else {
err = c.KnownHosts.Lookup(hostname, cert)
}
return err
}

View File

@ -46,21 +46,11 @@ func init() {
}
return err
}
client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
// If the certificate is in the store, return it
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert
}
// Otherwise, generate a certificate
fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
duration := time.Hour
cert, err := gemini.NewCertificate("", duration)
if err != nil {
return nil
}
// Store and return the certificate
store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
fmt.Println("Generating client certificate for", hostname, path)
return gemini.CreateCertificate(gemini.CertificateOptions{
Duration: time.Hour,
})
}
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt)
@ -69,8 +59,7 @@ func init() {
}
}
// sendRequest sends a request to the given URL.
func sendRequest(req *gemini.Request) error {
func doRequest(req *gemini.Request) error {
resp, err := client.Do(req)
if err != nil {
return err
@ -131,7 +120,7 @@ func main() {
os.Exit(1)
}
if err := sendRequest(req); err != nil {
if err := doRequest(req); err != nil {
fmt.Println(err)
os.Exit(1)
}

View File

@ -29,7 +29,10 @@ func main() {
fallthrough
case gmi.ErrCertificateUnknown:
// Generate a certificate if one does not exist.
cert, err := gmi.NewCertificate(hostname, time.Minute)
cert, err := gmi.CreateCertificate(gmi.CertificateOptions{
DNSNames: []string{hostname},
Duration: time.Hour,
})
if err != nil {
// Failed to generate new certificate, abort
return nil

View File

@ -51,16 +51,9 @@ func init() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert)
}
DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert
}
duration := time.Hour
cert, err := NewCertificate("", duration)
if err != nil {
return nil
}
store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
return CreateCertificate(CertificateOptions{
Duration: time.Hour,
})
}
}

View File

@ -24,6 +24,18 @@ const (
StatusCertificateNotValid Status = 62
)
// Status code categories.
type StatusClass int
const (
StatusClassInput StatusClass = 1
StatusClassSuccess StatusClass = 2
StatusClassRedirect StatusClass = 3
StatusClassTemporaryFailure StatusClass = 4
StatusClassPermanentFailure StatusClass = 5
StatusClassCertificateRequired StatusClass = 6
)
// Class returns the status class for this status code.
func (s Status) Class() StatusClass {
return StatusClass(s / 10)
@ -71,15 +83,3 @@ func (s Status) Message() string {
}
return ""
}
// Status code categories.
type StatusClass int
const (
StatusClassInput StatusClass = 1
StatusClassSuccess StatusClass = 2
StatusClassRedirect StatusClass = 3
StatusClassTemporaryFailure StatusClass = 4
StatusClassPermanentFailure StatusClass = 5
StatusClassCertificateRequired StatusClass = 6
)