Restrict client certificates to certain paths

This commit is contained in:
Adnan Maolood 2020-10-27 23:34:06 -04:00
parent fc72224ce9
commit d1dcf070ff
4 changed files with 48 additions and 20 deletions

24
cert.go
View File

@ -8,6 +8,7 @@ import (
"crypto/x509" "crypto/x509"
"math/big" "math/big"
"net" "net"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@ -70,6 +71,27 @@ func (c *CertificateStore) Load(path string) error {
return nil return nil
} }
type ClientCertificateStore struct {
CertificateStore
}
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
urlPath = path.Clean(urlPath)
for {
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
switch err {
case ErrCertificateExpired, nil:
return cert, err
}
slash := urlPath == "/"
urlPath = path.Dir(urlPath)
if slash && urlPath == "/" {
break
}
}
return nil, ErrCertificateUnknown
}
// NewCertificate creates and returns a new parsed certificate. // NewCertificate creates and returns a new parsed certificate.
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) { func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(host, duration) crt, priv, err := newX509KeyPair(host, duration)
@ -114,6 +136,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
BasicConstraintsValid: true, BasicConstraintsValid: true,
} }
if host != "" {
hosts := strings.Split(host, ",") hosts := strings.Split(host, ",")
for _, h := range hosts { for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil { if ip := net.ParseIP(h); ip != nil {
@ -122,6 +145,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
template.DNSNames = append(template.DNSNames, h) template.DNSNames = append(template.DNSNames, h)
} }
} }
}
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv) crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)
if err != nil { if err != nil {

View File

@ -16,7 +16,7 @@ type Client struct {
// CertificateStore maps hostnames to certificates. // CertificateStore maps hostnames to certificates.
// It is used to determine which certificate to use when the server requests // It is used to determine which certificate to use when the server requests
// a certificate. // a certificate.
CertificateStore CertificateStore CertificateStore ClientCertificateStore
// CheckRedirect, if not nil, will be called to determine whether // CheckRedirect, if not nil, will be called to determine whether
// to follow a redirect. // to follow a redirect.
@ -28,7 +28,7 @@ type Client struct {
// The returned certificate will be used when sending the request again. // The returned certificate will be used when sending the request again.
// If the certificate is nil, the request will not be sent again and // If the certificate is nil, the request will not be sent again and
// the response will be returned. // the response will be returned.
GetCertificate func(hostname string, store *CertificateStore) *tls.Certificate GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
// TrustCertificate, if not nil, will be called to determine whether the // TrustCertificate, if not nil, will be called to determine whether the
// client should trust the given certificate. // client should trust the given certificate.
@ -61,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return req.Certificate, nil return req.Certificate, nil
} }
// If we have already stored the certificate, return it // If we have already stored the certificate, return it
if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil { if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
return cert, nil return cert, nil
} }
return &tls.Certificate{}, nil return &tls.Certificate{}, nil
@ -111,7 +111,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, nil return resp, nil
} }
if c.GetCertificate != nil { if c.GetCertificate != nil {
if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil { if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
req.Certificate = cert req.Certificate = cert
return c.Do(req) return c.Do(req)
} }

View File

@ -8,7 +8,6 @@ import (
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net/url"
"os" "os"
"time" "time"
@ -47,22 +46,27 @@ func init() {
} }
return err return err
} }
client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate { client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
// If the certificate is in the store, return it // If the certificate is in the store, return it
if cert, err := store.Lookup(hostname); err == nil { if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert return cert
} }
// Otherwise, generate a certificate // Otherwise, generate a certificate
fmt.Println("Generating client certificate for", hostname) fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
duration := time.Hour duration := time.Hour
cert, err := gemini.NewCertificate(hostname, duration) cert, err := gemini.NewCertificate("", duration)
if err != nil { if err != nil {
return nil return nil
} }
// Store and return the certificate // Store and return the certificate
store.Add(hostname, cert) store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert return &cert
} }
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt)
scanner.Scan()
return scanner.Text(), true
}
} }
// sendRequest sends a request to the given URL. // sendRequest sends a request to the given URL.

View File

@ -49,16 +49,16 @@ func init() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() }) defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert) return knownHosts.Lookup(hostname, cert)
} }
DefaultClient.GetCertificate = func(hostname string, store *CertificateStore) *tls.Certificate { DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
if cert, err := store.Lookup(hostname); err == nil { if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert return cert
} }
duration := time.Hour duration := time.Hour
cert, err := NewCertificate(hostname, duration) cert, err := NewCertificate("", duration)
if err != nil { if err != nil {
return nil return nil
} }
store.Add(hostname, cert) store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert return &cert
} }
} }