diff --git a/cert.go b/cert.go index a19a2bf..4f11fe6 100644 --- a/cert.go +++ b/cert.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "math/big" "net" + "path" "path/filepath" "strings" "time" @@ -70,6 +71,27 @@ func (c *CertificateStore) Load(path string) error { return nil } +type ClientCertificateStore struct { + CertificateStore +} + +func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) { + urlPath = path.Clean(urlPath) + for { + cert, err := c.CertificateStore.Lookup(hostname + urlPath) + switch err { + case ErrCertificateExpired, nil: + return cert, err + } + slash := urlPath == "/" + urlPath = path.Dir(urlPath) + if slash && urlPath == "/" { + break + } + } + return nil, ErrCertificateUnknown +} + // NewCertificate creates and returns a new parsed certificate. func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) { crt, priv, err := newX509KeyPair(host, duration) @@ -114,12 +136,14 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry BasicConstraintsValid: true, } - hosts := strings.Split(host, ",") - for _, h := range hosts { - if ip := net.ParseIP(h); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, h) + if host != "" { + hosts := strings.Split(host, ",") + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } } } diff --git a/client.go b/client.go index d8d88ae..91402b1 100644 --- a/client.go +++ b/client.go @@ -16,7 +16,7 @@ type Client struct { // CertificateStore maps hostnames to certificates. // It is used to determine which certificate to use when the server requests // a certificate. - CertificateStore CertificateStore + CertificateStore ClientCertificateStore // CheckRedirect, if not nil, will be called to determine whether // to follow a redirect. @@ -28,7 +28,7 @@ type Client struct { // The returned certificate will be used when sending the request again. // If the certificate is nil, the request will not be sent again and // the response will be returned. - GetCertificate func(hostname string, store *CertificateStore) *tls.Certificate + GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate // TrustCertificate, if not nil, will be called to determine whether the // client should trust the given certificate. @@ -61,7 +61,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { return req.Certificate, nil } // If we have already stored the certificate, return it - if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil { + if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil { return cert, nil } return &tls.Certificate{}, nil @@ -111,7 +111,7 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { return resp, nil } if c.GetCertificate != nil { - if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil { + if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil { req.Certificate = cert return c.Do(req) } diff --git a/examples/client.go b/examples/client.go index 5180f5c..e2e1669 100644 --- a/examples/client.go +++ b/examples/client.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "fmt" "io/ioutil" - "net/url" "os" "time" @@ -47,22 +46,27 @@ func init() { } return err } - client.GetCertificate = func(hostname string, store *gemini.CertificateStore) *tls.Certificate { + client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate { // If the certificate is in the store, return it - if cert, err := store.Lookup(hostname); err == nil { + if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil { return cert } // Otherwise, generate a certificate - fmt.Println("Generating client certificate for", hostname) + fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path) duration := time.Hour - cert, err := gemini.NewCertificate(hostname, duration) + cert, err := gemini.NewCertificate("", duration) if err != nil { return nil } // Store and return the certificate - store.Add(hostname, cert) + store.Add(req.URL.Hostname()+req.URL.Path, cert) return &cert } + client.GetInput = func(prompt string, sensitive bool) (string, bool) { + fmt.Printf("%s: ", prompt) + scanner.Scan() + return scanner.Text(), true + } } // sendRequest sends a request to the given URL. diff --git a/gemini.go b/gemini.go index f3e0c81..635622c 100644 --- a/gemini.go +++ b/gemini.go @@ -49,16 +49,16 @@ func init() { defaultClientOnce.Do(func() { knownHosts.LoadDefault() }) return knownHosts.Lookup(hostname, cert) } - DefaultClient.GetCertificate = func(hostname string, store *CertificateStore) *tls.Certificate { - if cert, err := store.Lookup(hostname); err == nil { + DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate { + if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil { return cert } duration := time.Hour - cert, err := NewCertificate(hostname, duration) + cert, err := NewCertificate("", duration) if err != nil { return nil } - store.Add(hostname, cert) + store.Add(req.URL.Hostname()+req.URL.Path, cert) return &cert } }