client: Add GetCertificate callback
This commit is contained in:
parent
611a7d54c0
commit
846fa2ac41
55
client.go
55
client.go
@ -7,9 +7,7 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -20,9 +18,6 @@ type Client struct {
|
||||
// KnownHosts is a list of known hosts.
|
||||
KnownHosts KnownHostsFile
|
||||
|
||||
// Certificates stores client-side certificates.
|
||||
Certificates CertificateDir
|
||||
|
||||
// Timeout specifies a time limit for requests made by this
|
||||
// Client. The timeout includes connection time and reading
|
||||
// the response body. The timer remains running after
|
||||
@ -46,11 +41,11 @@ type Client struct {
|
||||
// If CheckRedirect is nil, redirects will not be followed.
|
||||
CheckRedirect func(req *Request, via []*Request) error
|
||||
|
||||
// CreateCertificate is called to generate a certificate upon
|
||||
// GetCertificate is called to retrieve a certificate upon
|
||||
// the request of a server.
|
||||
// If CreateCertificate is nil or the returned error is not nil,
|
||||
// If GetCertificate is nil or the returned error is not nil,
|
||||
// the request will not be sent again and the response will be returned.
|
||||
CreateCertificate func(scope, path string) (tls.Certificate, error)
|
||||
GetCertificate func(scope, path string) (tls.Certificate, error)
|
||||
|
||||
// TrustCertificate is called to determine whether the client
|
||||
// should trust a certificate it has not seen before.
|
||||
@ -62,8 +57,6 @@ type Client struct {
|
||||
// If TrustCertificate returns TrustAlways, the certificate will also be
|
||||
// written to the known hosts file.
|
||||
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Get performs a Gemini request for the given url.
|
||||
@ -77,9 +70,6 @@ func (c *Client) Get(url string) (*Response, error) {
|
||||
|
||||
// Do performs a Gemini request and returns a Gemini response.
|
||||
func (c *Client) Do(req *Request) (*Response, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
return c.do(req, nil)
|
||||
}
|
||||
|
||||
@ -96,7 +86,10 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
InsecureSkipVerify: true,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||
return c.getClientCertificate(req)
|
||||
if req.Certificate != nil {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
return &tls.Certificate{}, nil
|
||||
},
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
return c.verifyConnection(req, cs)
|
||||
@ -137,13 +130,11 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
}
|
||||
|
||||
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
||||
if c.CreateCertificate != nil {
|
||||
cert, err := c.CreateCertificate(hostname, path)
|
||||
if c.GetCertificate != nil {
|
||||
cert, err := c.GetCertificate(hostname, path)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
c.Certificates.Add(hostname+path, cert)
|
||||
c.Certificates.Write(hostname+path, cert)
|
||||
req.Certificate = &cert
|
||||
return c.do(req, via)
|
||||
}
|
||||
@ -185,34 +176,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
|
||||
// Request certificates have the highest precedence
|
||||
if req.Certificate != nil {
|
||||
return req.Certificate, nil
|
||||
}
|
||||
|
||||
// Search recursively for the certificate
|
||||
scope := req.URL.Hostname() + strings.TrimSuffix(req.URL.Path, "/")
|
||||
for {
|
||||
cert, ok := c.Certificates.Lookup(scope)
|
||||
if ok {
|
||||
// Ensure that the certificate is not expired
|
||||
if cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) {
|
||||
// Store the certificate
|
||||
req.Certificate = &cert
|
||||
return &cert, nil
|
||||
}
|
||||
break
|
||||
}
|
||||
scope = path.Dir(scope)
|
||||
if scope == "." {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &tls.Certificate{}, nil
|
||||
}
|
||||
|
||||
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||
// Verify the hostname
|
||||
var hostname string
|
||||
|
@ -4,7 +4,6 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@ -47,12 +46,6 @@ func init() {
|
||||
return gemini.TrustNone
|
||||
}
|
||||
}
|
||||
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
||||
fmt.Println("Generating client certificate for", hostname, path)
|
||||
return gemini.CreateCertificate(gemini.CertificateOptions{
|
||||
Duration: time.Hour,
|
||||
})
|
||||
}
|
||||
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
|
||||
fmt.Printf("%s: ", prompt)
|
||||
scanner.Scan()
|
||||
|
Loading…
Reference in New Issue
Block a user