Restrict client certificates to certain paths
This commit is contained in:
parent
fc72224ce9
commit
d1dcf070ff
24
cert.go
24
cert.go
@ -8,6 +8,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -70,6 +71,27 @@ func (c *CertificateStore) Load(path string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClientCertificateStore struct {
|
||||||
|
CertificateStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) {
|
||||||
|
urlPath = path.Clean(urlPath)
|
||||||
|
for {
|
||||||
|
cert, err := c.CertificateStore.Lookup(hostname + urlPath)
|
||||||
|
switch err {
|
||||||
|
case ErrCertificateExpired, nil:
|
||||||
|
return cert, err
|
||||||
|
}
|
||||||
|
slash := urlPath == "/"
|
||||||
|
urlPath = path.Dir(urlPath)
|
||||||
|
if slash && urlPath == "/" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrCertificateUnknown
|
||||||
|
}
|
||||||
|
|
||||||
// NewCertificate creates and returns a new parsed certificate.
|
// NewCertificate creates and returns a new parsed certificate.
|
||||||
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
|
func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) {
|
||||||
crt, priv, err := newX509KeyPair(host, duration)
|
crt, priv, err := newX509KeyPair(host, duration)
|
||||||
@ -114,6 +136,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
|
|||||||
BasicConstraintsValid: true,
|
BasicConstraintsValid: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if host != "" {
|
||||||
hosts := strings.Split(host, ",")
|
hosts := strings.Split(host, ",")
|
||||||
for _, h := range hosts {
|
for _, h := range hosts {
|
||||||
if ip := net.ParseIP(h); ip != nil {
|
if ip := net.ParseIP(h); ip != nil {
|
||||||
@ -122,6 +145,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry
|
|||||||
template.DNSNames = append(template.DNSNames, h)
|
template.DNSNames = append(template.DNSNames, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)
|
crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -16,7 +16,7 @@ type Client struct {
|
|||||||
// CertificateStore maps hostnames to certificates.
|
// CertificateStore maps hostnames to certificates.
|
||||||
// It is used to determine which certificate to use when the server requests
|
// It is used to determine which certificate to use when the server requests
|
||||||
// a certificate.
|
// a certificate.
|
||||||
CertificateStore CertificateStore
|
CertificateStore ClientCertificateStore
|
||||||
|
|
||||||
// CheckRedirect, if not nil, will be called to determine whether
|
// CheckRedirect, if not nil, will be called to determine whether
|
||||||
// to follow a redirect.
|
// to follow a redirect.
|
||||||
@ -28,7 +28,7 @@ type Client struct {
|
|||||||
// The returned certificate will be used when sending the request again.
|
// The returned certificate will be used when sending the request again.
|
||||||
// If the certificate is nil, the request will not be sent again and
|
// If the certificate is nil, the request will not be sent again and
|
||||||
// the response will be returned.
|
// the response will be returned.
|
||||||
GetCertificate func(hostname string, store *CertificateStore) *tls.Certificate
|
GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate
|
||||||
|
|
||||||
// TrustCertificate, if not nil, will be called to determine whether the
|
// TrustCertificate, if not nil, will be called to determine whether the
|
||||||
// client should trust the given certificate.
|
// client should trust the given certificate.
|
||||||
@ -61,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
|||||||
return req.Certificate, nil
|
return req.Certificate, nil
|
||||||
}
|
}
|
||||||
// If we have already stored the certificate, return it
|
// If we have already stored the certificate, return it
|
||||||
if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil {
|
if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil {
|
||||||
return cert, nil
|
return cert, nil
|
||||||
}
|
}
|
||||||
return &tls.Certificate{}, nil
|
return &tls.Certificate{}, nil
|
||||||
@ -111,7 +111,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
if c.GetCertificate != nil {
|
if c.GetCertificate != nil {
|
||||||
if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil {
|
if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil {
|
||||||
req.Certificate = cert
|
req.Certificate = cert
|
||||||
return c.Do(req)
|
return c.Do(req)
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -47,22 +46,27 @@ func init() {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate {
|
client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate {
|
||||||
// If the certificate is in the store, return it
|
// If the certificate is in the store, return it
|
||||||
if cert, err := store.Lookup(hostname); err == nil {
|
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
|
||||||
return cert
|
return cert
|
||||||
}
|
}
|
||||||
// Otherwise, generate a certificate
|
// Otherwise, generate a certificate
|
||||||
fmt.Println("Generating client certificate for", hostname)
|
fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path)
|
||||||
duration := time.Hour
|
duration := time.Hour
|
||||||
cert, err := gemini.NewCertificate(hostname, duration)
|
cert, err := gemini.NewCertificate("", duration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Store and return the certificate
|
// Store and return the certificate
|
||||||
store.Add(hostname, cert)
|
store.Add(req.URL.Hostname()+req.URL.Path, cert)
|
||||||
return &cert
|
return &cert
|
||||||
}
|
}
|
||||||
|
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
|
||||||
|
fmt.Printf("%s: ", prompt)
|
||||||
|
scanner.Scan()
|
||||||
|
return scanner.Text(), true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendRequest sends a request to the given URL.
|
// sendRequest sends a request to the given URL.
|
||||||
|
@ -49,16 +49,16 @@ func init() {
|
|||||||
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
|
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
|
||||||
return knownHosts.Lookup(hostname, cert)
|
return knownHosts.Lookup(hostname, cert)
|
||||||
}
|
}
|
||||||
DefaultClient.GetCertificate = func(hostname string, store *CertificateStore) *tls.Certificate {
|
DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate {
|
||||||
if cert, err := store.Lookup(hostname); err == nil {
|
if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil {
|
||||||
return cert
|
return cert
|
||||||
}
|
}
|
||||||
duration := time.Hour
|
duration := time.Hour
|
||||||
cert, err := NewCertificate(hostname, duration)
|
cert, err := NewCertificate("", duration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
store.Add(hostname, cert)
|
store.Add(req.URL.Hostname()+req.URL.Path, cert)
|
||||||
return &cert
|
return &cert
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user