Restrict client certificates to certain paths

This commit is contained in:
Adnan Maolood 2020-10-27 23:34:06 -04:00
parent fc72224ce9
commit d1dcf070ff
4 changed files with 48 additions and 20 deletions

36
cert.go
View File

@ -8,6 +8,7 @@ import (
"crypto/x509"
"math/big"
"net"
"path"
"path/filepath"
"strings"
"time"
@ -70,6 +71,27 @@ func (c *CertificateStore) Load(path string) error {
return nil
}
type ClientCertificateStore struct {
CertificateStore
}
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
urlPath = path.Clean(urlPath)
for {
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
switch err {
case ErrCertificateExpired, nil:
return cert, err
}
slash := urlPath == "/"
urlPath = path.Dir(urlPath)
if slash && urlPath == "/" {
break
}
}
return nil, ErrCertificateUnknown
}
// NewCertificate creates and returns a new parsed certificate.
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
crt, priv, err := newX509KeyPair(host, duration)
@ -114,12 +136,14 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
BasicConstraintsValid: true,
}
hosts := strings.Split(host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
if host != "" {
hosts := strings.Split(host, ",")
for _, h := range hosts {
if ip := net.ParseIP(h); ip != nil {
template.IPAddresses = append(template.IPAddresses, ip)
} else {
template.DNSNames = append(template.DNSNames, h)
}
}
}

View File

@ -16,7 +16,7 @@ type Client struct {
// CertificateStore maps hostnames to certificates.
// It is used to determine which certificate to use when the server requests
// a certificate.
CertificateStore CertificateStore
CertificateStore ClientCertificateStore
// CheckRedirect, if not nil, will be called to determine whether
// to follow a redirect.
@ -28,7 +28,7 @@ type Client struct {
// The returned certificate will be used when sending the request again.
// If the certificate is nil, the request will not be sent again and
// the response will be returned.
GetCertificate func(hostname string, store *CertificateStore) *tls.Certificate
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
// TrustCertificate, if not nil, will be called to determine whether the
// client should trust the given certificate.
@ -61,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return req.Certificate, nil
}
// If we have already stored the certificate, return it
if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil {
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
return cert, nil
}
return &tls.Certificate{}, nil
@ -111,7 +111,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, nil
}
if c.GetCertificate != nil {
if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil {
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
req.Certificate = cert
return c.Do(req)
}

View File

@ -8,7 +8,6 @@ import (
"crypto/x509"
"fmt"
"io/ioutil"
"net/url"
"os"
"time"
@ -47,22 +46,27 @@ func init() {
}
return err
}
client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate {
client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
// If the certificate is in the store, return it
if cert, err := store.Lookup(hostname); err == nil {
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert
}
// Otherwise, generate a certificate
fmt.Println("Generating client certificate for", hostname)
fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
duration := time.Hour
cert, err := gemini.NewCertificate(hostname, duration)
cert, err := gemini.NewCertificate("", duration)
if err != nil {
return nil
}
// Store and return the certificate
store.Add(hostname, cert)
store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
}
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt)
scanner.Scan()
return scanner.Text(), true
}
}
// sendRequest sends a request to the given URL.

View File

@ -49,16 +49,16 @@ func init() {
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
return knownHosts.Lookup(hostname, cert)
}
DefaultClient.GetCertificate = func(hostname string, store *CertificateStore) *tls.Certificate {
if cert, err := store.Lookup(hostname); err == nil {
DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
return cert
}
duration := time.Hour
cert, err := NewCertificate(hostname, duration)
cert, err := NewCertificate("", duration)
if err != nil {
return nil
}
store.Add(hostname, cert)
store.Add(req.URL.Hostname()+req.URL.Path, cert)
return &cert
}
}