Restrict client certificates to certain paths
This commit is contained in:
parent
fc72224ce9
commit
d1dcf070ff
36
cert.go
36
cert.go
@ -8,6 +8,7 @@ import (
|
||||
"crypto/x509"
|
||||
"math/big"
|
||||
"net"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
@ -70,6 +71,27 @@ func (c *CertificateStore) Load(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientCertificateStore struct {
|
||||
CertificateStore
|
||||
}
|
||||
|
||||
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
|
||||
urlPath = path.Clean(urlPath)
|
||||
for {
|
||||
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
|
||||
switch err {
|
||||
case ErrCertificateExpired, nil:
|
||||
return cert, err
|
||||
}
|
||||
slash := urlPath == "/"
|
||||
urlPath = path.Dir(urlPath)
|
||||
if slash && urlPath == "/" {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil, ErrCertificateUnknown
|
||||
}
|
||||
|
||||
// NewCertificate creates and returns a new parsed certificate.
|
||||
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
|
||||
crt, priv, err := newX509KeyPair(host, duration)
|
||||
@ -114,12 +136,14 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
hosts := strings.Split(host, ",")
|
||||
for _, h := range hosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, h)
|
||||
if host != "" {
|
||||
hosts := strings.Split(host, ",")
|
||||
for _, h := range hosts {
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
template.IPAddresses = append(template.IPAddresses, ip)
|
||||
} else {
|
||||
template.DNSNames = append(template.DNSNames, h)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -16,7 +16,7 @@ type Client struct {
|
||||
// CertificateStore maps hostnames to certificates.
|
||||
// It is used to determine which certificate to use when the server requests
|
||||
// a certificate.
|
||||
CertificateStore CertificateStore
|
||||
CertificateStore ClientCertificateStore
|
||||
|
||||
// CheckRedirect, if not nil, will be called to determine whether
|
||||
// to follow a redirect.
|
||||
@ -28,7 +28,7 @@ type Client struct {
|
||||
// The returned certificate will be used when sending the request again.
|
||||
// If the certificate is nil, the request will not be sent again and
|
||||
// the response will be returned.
|
||||
GetCertificate func(hostname string, store *CertificateStore) *tls.Certificate
|
||||
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
|
||||
|
||||
// TrustCertificate, if not nil, will be called to determine whether the
|
||||
// client should trust the given certificate.
|
||||
@ -61,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
// If we have already stored the certificate, return it
|
||||
if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil {
|
||||
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
|
||||
return cert, nil
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
@ -111,7 +111,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
return resp, nil
|
||||
}
|
||||
if c.GetCertificate != nil {
|
||||
if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil {
|
||||
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
|
||||
req.Certificate = cert
|
||||
return c.Do(req)
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@ -47,22 +46,27 @@ func init() {
|
||||
}
|
||||
return err
|
||||
}
|
||||
client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate {
|
||||
client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
|
||||
// If the certificate is in the store, return it
|
||||
if cert, err := store.Lookup(hostname); err == nil {
|
||||
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
|
||||
return cert
|
||||
}
|
||||
// Otherwise, generate a certificate
|
||||
fmt.Println("Generating client certificate for", hostname)
|
||||
fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
|
||||
duration := time.Hour
|
||||
cert, err := gemini.NewCertificate(hostname, duration)
|
||||
cert, err := gemini.NewCertificate("", duration)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// Store and return the certificate
|
||||
store.Add(hostname, cert)
|
||||
store.Add(req.URL.Hostname()+req.URL.Path, cert)
|
||||
return &cert
|
||||
}
|
||||
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
|
||||
fmt.Printf("%s: ", prompt)
|
||||
scanner.Scan()
|
||||
return scanner.Text(), true
|
||||
}
|
||||
}
|
||||
|
||||
// sendRequest sends a request to the given URL.
|
||||
|
@ -49,16 +49,16 @@ func init() {
|
||||
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
|
||||
return knownHosts.Lookup(hostname, cert)
|
||||
}
|
||||
DefaultClient.GetCertificate = func(hostname string, store *CertificateStore) *tls.Certificate {
|
||||
if cert, err := store.Lookup(hostname); err == nil {
|
||||
DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
|
||||
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
|
||||
return cert
|
||||
}
|
||||
duration := time.Hour
|
||||
cert, err := NewCertificate(hostname, duration)
|
||||
cert, err := NewCertificate("", duration)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
store.Add(hostname, cert)
|
||||
store.Add(req.URL.Hostname()+req.URL.Path, cert)
|
||||
return &cert
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user