diff --git a/cert.go b/cert.go index da77478..6ecf119 100644 --- a/cert.go +++ b/cert.go @@ -20,9 +20,9 @@ type CertificateStore struct { store map[string]tls.Certificate } -// Add adds a certificate for the given hostname to the store. +// Add adds a certificate for the given scope to the store. // It tries to parse the certificate if it is not already parsed. -func (c *CertificateStore) Add(hostname string, cert tls.Certificate) { +func (c *CertificateStore) Add(scope string, cert tls.Certificate) { if c.store == nil { c.store = map[string]tls.Certificate{} } @@ -33,7 +33,7 @@ func (c *CertificateStore) Add(hostname string, cert tls.Certificate) { cert.Leaf = parsed } } - c.store[hostname] = cert + c.store[scope] = cert } // Lookup returns the certificate for the given hostname. @@ -49,6 +49,22 @@ func (c *CertificateStore) Lookup(hostname string) (*tls.Certificate, error) { return &cert, nil } +// lookup returns the certificate for the given hostname + path. +func (c *CertificateStore) lookup(scope string) (*tls.Certificate, error) { + for { + cert, err := c.Lookup(scope) + switch err { + case ErrCertificateExpired, nil: + return cert, err + } + scope = path.Dir(scope) + if scope == "." { + break + } + } + return nil, ErrCertificateUnknown +} + // Load loads certificates from the given path. // The path should lead to a directory containing certificates and private keys // in the form hostname.crt and hostname.key. @@ -71,36 +87,16 @@ func (c *CertificateStore) Load(path string) error { return nil } -type ClientCertificateStore struct { - CertificateStore +// CertificateOptions configures how a certificate is created. +type CertificateOptions struct { + IPAddresses []net.IP + DNSNames []string + Duration time.Duration } -func (c *ClientCertificateStore) Lookup(hostname, urlPath string) (*tls.Certificate, error) { - urlPath = path.Clean(urlPath) - if urlPath == "." { - urlPath = "/" - } - if urlPath[0] != '/' { - urlPath = "/" + urlPath - } - for { - cert, err := c.CertificateStore.Lookup(hostname + urlPath) - switch err { - case ErrCertificateExpired, nil: - return cert, err - } - slash := urlPath == "/" - urlPath = path.Dir(urlPath) - if slash && urlPath == "/" { - break - } - } - return nil, ErrCertificateUnknown -} - -// NewCertificate creates and returns a new parsed certificate. -func NewCertificate(host string, duration time.Duration) (tls.Certificate, error) { - crt, priv, err := newX509KeyPair(host, duration) +// CreateCertificate creates a new TLS certificate. +func CreateCertificate(options CertificateOptions) (tls.Certificate, error) { + crt, priv, err := newX509KeyPair(options) if err != nil { return tls.Certificate{}, err } @@ -112,7 +108,7 @@ func NewCertificate(host string, duration time.Duration) (tls.Certificate, error } // newX509KeyPair creates and returns a new certificate and private key. -func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, crypto.PrivateKey, error) { +func newX509KeyPair(options CertificateOptions) (*x509.Certificate, crypto.PrivateKey, error) { // Generate an ED25519 private key _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { @@ -131,7 +127,7 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry } notBefore := time.Now() - notAfter := notBefore.Add(duration) + notAfter := notBefore.Add(options.Duration) template := x509.Certificate{ SerialNumber: serialNumber, @@ -140,17 +136,8 @@ func newX509KeyPair(host string, duration time.Duration) (*x509.Certificate, cry KeyUsage: keyUsage, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, - } - - if host != "" { - hosts := strings.Split(host, ",") - for _, h := range hosts { - if ip := net.ParseIP(h); ip != nil { - template.IPAddresses = append(template.IPAddresses, ip) - } else { - template.DNSNames = append(template.DNSNames, h) - } - } + IPAddresses: options.IPAddresses, + DNSNames: options.DNSNames, } crt, err := x509.CreateCertificate(rand.Reader, &template, &template, public, priv) diff --git a/client.go b/client.go index c4ffb59..5ab3ded 100644 --- a/client.go +++ b/client.go @@ -6,37 +6,38 @@ import ( "crypto/x509" "net" "net/url" + "strings" ) -// Client represents a Gemini client. +// Client is a Gemini client. type Client struct { - // KnownHosts is a list of known hosts that the client trusts. + // KnownHosts is a list of known hosts. KnownHosts KnownHosts - // CertificateStore maps hostnames to certificates. - // It is used to determine which certificate to use when the server requests - // a certificate. - CertificateStore ClientCertificateStore + // Certificates stores client-side certificates. + Certificates CertificateStore - // CheckRedirect, if not nil, will be called to determine whether - // to follow a redirect. + // GetInput is called to retrieve input when the server requests it. + // If GetInput is nil or returns false, no input will be sent and + // the response will be returned. + GetInput func(prompt string, sensitive bool) (input string, ok bool) + + // CheckRedirect determines whether to follow a redirect. // If CheckRedirect is nil, a default policy of no more than 5 consecutive // redirects will be enforced. CheckRedirect func(req *Request, via []*Request) error - // GetInput, if not nil, will be called to retrieve input when the server - // requests it. - GetInput func(prompt string, sensitive bool) (string, bool) + // CreateCertificate is called to generate a certificate upon + // the request of a server. + // If CreateCertificate is nil or the returned error is not nil, + // the request will not be sent again and the response will be returned. + CreateCertificate func(hostname, path string) (tls.Certificate, error) - // GetCertificate, if not nil, will be called when a server requests a certificate. - // The returned certificate will be used when sending the request again. - // If the certificate is nil, the request will not be sent again and - // the response will be returned. - GetCertificate func(req *Request, store *ClientCertificateStore) *tls.Certificate - - // TrustCertificate, if not nil, will be called to determine whether the - // client should trust the given certificate. - // If error is not nil, the connection will be aborted. + // TrustCertificate determines whether the client should trust + // the provided certificate. + // If the returned error is not nil, the connection will be aborted. + // If TrustCertificate is nil, the client will check KnownHosts + // for the certificate. TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error } @@ -59,38 +60,18 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { config := &tls.Config{ InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, - GetClientCertificate: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { - // Request certificates take precedence over client certificates - if req.Certificate != nil { - return req.Certificate, nil - } - // If we have already stored the certificate, return it - if cert, err := c.CertificateStore.Lookup(hostname(req.Host), req.URL.Path); err == nil { - return cert, nil - } - return &tls.Certificate{}, nil + GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { + return c.getClientCertificate(req) }, VerifyConnection: func(cs tls.ConnectionState) error { - cert := cs.PeerCertificates[0] - // Verify the hostname - if err := verifyHostname(cert, hostname(req.Host)); err != nil { - return err - } - // Check that the client trusts the certificate - if c.TrustCertificate == nil { - if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil { - return err - } - } else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil { - return err - } - return nil + return c.verifyConnection(req, cs) }, } conn, err := tls.Dial("tcp", req.Host, config) if err != nil { return nil, err } + // TODO: Set connection deadline // Write the request w := bufio.NewWriter(conn) @@ -104,24 +85,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { if err := resp.read(conn); err != nil { return nil, err } - // Store connection information + // Store connection state resp.TLS = conn.ConnectionState() - // Resend the request with a certificate if the server responded - // with CertificateRequired - if resp.Status == StatusCertificateRequired { + switch { + case resp.Status == StatusCertificateRequired: // Check to see if a certificate was already provided to prevent an infinite loop if req.Certificate != nil { return resp, nil } - if c.GetCertificate != nil { - if cert := c.GetCertificate(req, &c.CertificateStore); cert != nil { - req.Certificate = cert - return c.Do(req) + + hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/") + if c.CreateCertificate != nil { + cert, err := c.CreateCertificate(hostname, path) + if err != nil { + return resp, err } + c.Certificates.Add(hostname+path, cert) + return c.do(req, via) } return resp, ErrCertificateRequired - } else if resp.Status.Class() == StatusClassRedirect { + + case resp.Status.Class() == StatusClassInput: + if c.GetInput != nil { + input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) + if ok { + req.URL.ForceQuery = true + req.URL.RawQuery = url.QueryEscape(input) + return c.do(req, via) + } + } + return resp, ErrInputRequired + + case resp.Status.Class() == StatusClassRedirect: if via == nil { via = []*Request{} } @@ -146,27 +142,44 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { return resp, ErrTooManyRedirects } return c.do(redirect, via) - } else if resp.Status.Class() == StatusClassInput { - if c.GetInput != nil { - input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) - if ok { - req.URL.ForceQuery = true - req.URL.RawQuery = url.QueryEscape(input) - return c.do(req, via) - } - } - return resp, ErrInputRequired } resp.Request = req return resp, nil } -// hostname returns the host without the port. -func hostname(host string) string { - hostname, _, err := net.SplitHostPort(host) - if err != nil { - return host +func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) { + // Request certificates have the highest precedence + if req.Certificate != nil { + return req.Certificate, nil } - return hostname + hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/") + if cert, err := c.Certificates.lookup(hostname + path); err == nil { + // Remember the certificate used + req.Certificate = cert + return cert, nil + } + return &tls.Certificate{}, nil +} + +func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { + // Verify the hostname + var hostname string + if host, _, err := net.SplitHostPort(req.Host); err == nil { + hostname = host + } else { + hostname = req.Host + } + cert := cs.PeerCertificates[0] + if err := verifyHostname(cert, hostname); err != nil { + return err + } + // Check that the client trusts the certificate + var err error + if c.TrustCertificate != nil { + return c.TrustCertificate(hostname, cert, &c.KnownHosts) + } else { + err = c.KnownHosts.Lookup(hostname, cert) + } + return err } diff --git a/examples/client.go b/examples/client.go index 169f726..71e7915 100644 --- a/examples/client.go +++ b/examples/client.go @@ -46,21 +46,11 @@ func init() { } return err } - client.GetCertificate = func(req *gemini.Request, store *gemini.ClientCertificateStore) *tls.Certificate { - // If the certificate is in the store, return it - if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil { - return cert - } - // Otherwise, generate a certificate - fmt.Println("Generating client certificate for", req.URL.Hostname()+req.URL.Path) - duration := time.Hour - cert, err := gemini.NewCertificate("", duration) - if err != nil { - return nil - } - // Store and return the certificate - store.Add(req.URL.Hostname()+req.URL.Path, cert) - return &cert + client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) { + fmt.Println("Generating client certificate for", hostname, path) + return gemini.CreateCertificate(gemini.CertificateOptions{ + Duration: time.Hour, + }) } client.GetInput = func(prompt string, sensitive bool) (string, bool) { fmt.Printf("%s: ", prompt) @@ -69,8 +59,7 @@ func init() { } } -// sendRequest sends a request to the given URL. -func sendRequest(req *gemini.Request) error { +func doRequest(req *gemini.Request) error { resp, err := client.Do(req) if err != nil { return err @@ -131,7 +120,7 @@ func main() { os.Exit(1) } - if err := sendRequest(req); err != nil { + if err := doRequest(req); err != nil { fmt.Println(err) os.Exit(1) } diff --git a/examples/server.go b/examples/server.go index 9ef4b8d..11ea9ed 100644 --- a/examples/server.go +++ b/examples/server.go @@ -29,7 +29,10 @@ func main() { fallthrough case gmi.ErrCertificateUnknown: // Generate a certificate if one does not exist. - cert, err := gmi.NewCertificate(hostname, time.Minute) + cert, err := gmi.CreateCertificate(gmi.CertificateOptions{ + DNSNames: []string{hostname}, + Duration: time.Hour, + }) if err != nil { // Failed to generate new certificate, abort return nil diff --git a/gemini.go b/gemini.go index a381712..8a74422 100644 --- a/gemini.go +++ b/gemini.go @@ -51,16 +51,9 @@ func init() { defaultClientOnce.Do(func() { knownHosts.LoadDefault() }) return knownHosts.Lookup(hostname, cert) } - DefaultClient.GetCertificate = func(req *Request, store *ClientCertificateStore) *tls.Certificate { - if cert, err := store.Lookup(req.URL.Hostname(), req.URL.Path); err == nil { - return cert - } - duration := time.Hour - cert, err := NewCertificate("", duration) - if err != nil { - return nil - } - store.Add(req.URL.Hostname()+req.URL.Path, cert) - return &cert + DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) { + return CreateCertificate(CertificateOptions{ + Duration: time.Hour, + }) } } diff --git a/status.go b/status.go index 2a4d0d3..4c50f89 100644 --- a/status.go +++ b/status.go @@ -24,6 +24,18 @@ const ( StatusCertificateNotValid Status = 62 ) +// Status code categories. +type StatusClass int + +const ( + StatusClassInput StatusClass = 1 + StatusClassSuccess StatusClass = 2 + StatusClassRedirect StatusClass = 3 + StatusClassTemporaryFailure StatusClass = 4 + StatusClassPermanentFailure StatusClass = 5 + StatusClassCertificateRequired StatusClass = 6 +) + // Class returns the status class for this status code. func (s Status) Class() StatusClass { return StatusClass(s / 10) @@ -71,15 +83,3 @@ func (s Status) Message() string { } return "" } - -// Status code categories. -type StatusClass int - -const ( - StatusClassInput StatusClass = 1 - StatusClassSuccess StatusClass = 2 - StatusClassRedirect StatusClass = 3 - StatusClassTemporaryFailure StatusClass = 4 - StatusClassPermanentFailure StatusClass = 5 - StatusClassCertificateRequired StatusClass = 6 -)