client: Add GetCertificate callback
This commit is contained in:
		
							parent
							
								
									611a7d54c0
								
							
						
					
					
						commit
						846fa2ac41
					
				
							
								
								
									
										55
									
								
								client.go
									
									
									
									
									
								
							
							
						
						
									
										55
									
								
								client.go
									
									
									
									
									
								
							@ -7,9 +7,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -20,9 +18,6 @@ type Client struct {
 | 
			
		||||
	// KnownHosts is a list of known hosts.
 | 
			
		||||
	KnownHosts KnownHostsFile
 | 
			
		||||
 | 
			
		||||
	// Certificates stores client-side certificates.
 | 
			
		||||
	Certificates CertificateDir
 | 
			
		||||
 | 
			
		||||
	// Timeout specifies a time limit for requests made by this
 | 
			
		||||
	// Client. The timeout includes connection time and reading
 | 
			
		||||
	// the response body. The timer remains running after
 | 
			
		||||
@ -46,11 +41,11 @@ type Client struct {
 | 
			
		||||
	// If CheckRedirect is nil, redirects will not be followed.
 | 
			
		||||
	CheckRedirect func(req *Request, via []*Request) error
 | 
			
		||||
 | 
			
		||||
	// CreateCertificate is called to generate a certificate upon
 | 
			
		||||
	// GetCertificate is called to retrieve a certificate upon
 | 
			
		||||
	// the request of a server.
 | 
			
		||||
	// If CreateCertificate is nil or the returned error is not nil,
 | 
			
		||||
	// If GetCertificate is nil or the returned error is not nil,
 | 
			
		||||
	// the request will not be sent again and the response will be returned.
 | 
			
		||||
	CreateCertificate func(scope, path string) (tls.Certificate, error)
 | 
			
		||||
	GetCertificate func(scope, path string) (tls.Certificate, error)
 | 
			
		||||
 | 
			
		||||
	// TrustCertificate is called to determine whether the client
 | 
			
		||||
	// should trust a certificate it has not seen before.
 | 
			
		||||
@ -62,8 +57,6 @@ type Client struct {
 | 
			
		||||
	// If TrustCertificate returns TrustAlways, the certificate will also be
 | 
			
		||||
	// written to the known hosts file.
 | 
			
		||||
	TrustCertificate func(hostname string, cert *x509.Certificate) Trust
 | 
			
		||||
 | 
			
		||||
	mu sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get performs a Gemini request for the given url.
 | 
			
		||||
@ -77,9 +70,6 @@ func (c *Client) Get(url string) (*Response, error) {
 | 
			
		||||
 | 
			
		||||
// Do performs a Gemini request and returns a Gemini response.
 | 
			
		||||
func (c *Client) Do(req *Request) (*Response, error) {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	defer c.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	return c.do(req, nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -96,7 +86,10 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
 | 
			
		||||
		InsecureSkipVerify: true,
 | 
			
		||||
		MinVersion:         tls.VersionTLS12,
 | 
			
		||||
		GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
 | 
			
		||||
			return c.getClientCertificate(req)
 | 
			
		||||
			if req.Certificate != nil {
 | 
			
		||||
				return req.Certificate, nil
 | 
			
		||||
			}
 | 
			
		||||
			return &tls.Certificate{}, nil
 | 
			
		||||
		},
 | 
			
		||||
		VerifyConnection: func(cs tls.ConnectionState) error {
 | 
			
		||||
			return c.verifyConnection(req, cs)
 | 
			
		||||
@ -137,13 +130,11 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
 | 
			
		||||
		if c.CreateCertificate != nil {
 | 
			
		||||
			cert, err := c.CreateCertificate(hostname, path)
 | 
			
		||||
		if c.GetCertificate != nil {
 | 
			
		||||
			cert, err := c.GetCertificate(hostname, path)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return resp, err
 | 
			
		||||
			}
 | 
			
		||||
			c.Certificates.Add(hostname+path, cert)
 | 
			
		||||
			c.Certificates.Write(hostname+path, cert)
 | 
			
		||||
			req.Certificate = &cert
 | 
			
		||||
			return c.do(req, via)
 | 
			
		||||
		}
 | 
			
		||||
@ -185,34 +176,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
 | 
			
		||||
	return resp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
 | 
			
		||||
	// Request certificates have the highest precedence
 | 
			
		||||
	if req.Certificate != nil {
 | 
			
		||||
		return req.Certificate, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Search recursively for the certificate
 | 
			
		||||
	scope := req.URL.Hostname() + strings.TrimSuffix(req.URL.Path, "/")
 | 
			
		||||
	for {
 | 
			
		||||
		cert, ok := c.Certificates.Lookup(scope)
 | 
			
		||||
		if ok {
 | 
			
		||||
			// Ensure that the certificate is not expired
 | 
			
		||||
			if cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) {
 | 
			
		||||
				// Store the certificate
 | 
			
		||||
				req.Certificate = &cert
 | 
			
		||||
				return &cert, nil
 | 
			
		||||
			}
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		scope = path.Dir(scope)
 | 
			
		||||
		if scope == "." {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &tls.Certificate{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
 | 
			
		||||
	// Verify the hostname
 | 
			
		||||
	var hostname string
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,6 @@ package main
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"crypto/tls"
 | 
			
		||||
	"crypto/x509"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
@ -47,12 +46,6 @@ func init() {
 | 
			
		||||
			return gemini.TrustNone
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
 | 
			
		||||
		fmt.Println("Generating client certificate for", hostname, path)
 | 
			
		||||
		return gemini.CreateCertificate(gemini.CertificateOptions{
 | 
			
		||||
			Duration: time.Hour,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	client.GetInput = func(prompt string, sensitive bool) (string, bool) {
 | 
			
		||||
		fmt.Printf("%s: ", prompt)
 | 
			
		||||
		scanner.Scan()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user