client: Add GetCertificate callback

This commit is contained in:
Adnan Maolood 2020-12-17 16:46:16 -05:00
parent 611a7d54c0
commit 846fa2ac41
2 changed files with 9 additions and 53 deletions

View File

@ -7,9 +7,7 @@ import (
"errors" "errors"
"net" "net"
"net/url" "net/url"
"path"
"strings" "strings"
"sync"
"time" "time"
) )
@ -20,9 +18,6 @@ type Client struct {
// KnownHosts is a list of known hosts. // KnownHosts is a list of known hosts.
KnownHosts KnownHostsFile KnownHosts KnownHostsFile
// Certificates stores client-side certificates.
Certificates CertificateDir
// Timeout specifies a time limit for requests made by this // Timeout specifies a time limit for requests made by this
// Client. The timeout includes connection time and reading // Client. The timeout includes connection time and reading
// the response body. The timer remains running after // the response body. The timer remains running after
@ -46,11 +41,11 @@ type Client struct {
// If CheckRedirect is nil, redirects will not be followed. // If CheckRedirect is nil, redirects will not be followed.
CheckRedirect func(req *Request, via []*Request) error CheckRedirect func(req *Request, via []*Request) error
// CreateCertificate is called to generate a certificate upon // GetCertificate is called to retrieve a certificate upon
// the request of a server. // the request of a server.
// If CreateCertificate is nil or the returned error is not nil, // If GetCertificate is nil or the returned error is not nil,
// the request will not be sent again and the response will be returned. // the request will not be sent again and the response will be returned.
CreateCertificate func(scope, path string) (tls.Certificate, error) GetCertificate func(scope, path string) (tls.Certificate, error)
// TrustCertificate is called to determine whether the client // TrustCertificate is called to determine whether the client
// should trust a certificate it has not seen before. // should trust a certificate it has not seen before.
@ -62,8 +57,6 @@ type Client struct {
// If TrustCertificate returns TrustAlways, the certificate will also be // If TrustCertificate returns TrustAlways, the certificate will also be
// written to the known hosts file. // written to the known hosts file.
TrustCertificate func(hostname string, cert *x509.Certificate) Trust TrustCertificate func(hostname string, cert *x509.Certificate) Trust
mu sync.Mutex
} }
// Get performs a Gemini request for the given url. // Get performs a Gemini request for the given url.
@ -77,9 +70,6 @@ func (c *Client) Get(url string) (*Response, error) {
// Do performs a Gemini request and returns a Gemini response. // Do performs a Gemini request and returns a Gemini response.
func (c *Client) Do(req *Request) (*Response, error) { func (c *Client) Do(req *Request) (*Response, error) {
c.mu.Lock()
defer c.mu.Unlock()
return c.do(req, nil) return c.do(req, nil)
} }
@ -96,7 +86,10 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
InsecureSkipVerify: true, InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return c.getClientCertificate(req) if req.Certificate != nil {
return req.Certificate, nil
}
return &tls.Certificate{}, nil
}, },
VerifyConnection: func(cs tls.ConnectionState) error { VerifyConnection: func(cs tls.ConnectionState) error {
return c.verifyConnection(req, cs) return c.verifyConnection(req, cs)
@ -137,13 +130,11 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
} }
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/") hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if c.CreateCertificate != nil { if c.GetCertificate != nil {
cert, err := c.CreateCertificate(hostname, path) cert, err := c.GetCertificate(hostname, path)
if err != nil { if err != nil {
return resp, err return resp, err
} }
c.Certificates.Add(hostname+path, cert)
c.Certificates.Write(hostname+path, cert)
req.Certificate = &cert req.Certificate = &cert
return c.do(req, via) return c.do(req, via)
} }
@ -185,34 +176,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
return resp, nil return resp, nil
} }
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
// Request certificates have the highest precedence
if req.Certificate != nil {
return req.Certificate, nil
}
// Search recursively for the certificate
scope := req.URL.Hostname() + strings.TrimSuffix(req.URL.Path, "/")
for {
cert, ok := c.Certificates.Lookup(scope)
if ok {
// Ensure that the certificate is not expired
if cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) {
// Store the certificate
req.Certificate = &cert
return &cert, nil
}
break
}
scope = path.Dir(scope)
if scope == "." {
break
}
}
return &tls.Certificate{}, nil
}
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
// Verify the hostname // Verify the hostname
var hostname string var hostname string

View File

@ -4,7 +4,6 @@ package main
import ( import (
"bufio" "bufio"
"crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -47,12 +46,6 @@ func init() {
return gemini.TrustNone return gemini.TrustNone
} }
} }
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
fmt.Println("Generating client certificate for", hostname, path)
return gemini.CreateCertificate(gemini.CertificateOptions{
Duration: time.Hour,
})
}
client.GetInput = func(prompt string, sensitive bool) (string, bool) { client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt) fmt.Printf("%s: ", prompt)
scanner.Scan() scanner.Scan()