client: Add GetCertificate callback
This commit is contained in:
parent
611a7d54c0
commit
846fa2ac41
55
client.go
55
client.go
@ -7,9 +7,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,9 +18,6 @@ type Client struct {
|
|||||||
// KnownHosts is a list of known hosts.
|
// KnownHosts is a list of known hosts.
|
||||||
KnownHosts KnownHostsFile
|
KnownHosts KnownHostsFile
|
||||||
|
|
||||||
// Certificates stores client-side certificates.
|
|
||||||
Certificates CertificateDir
|
|
||||||
|
|
||||||
// Timeout specifies a time limit for requests made by this
|
// Timeout specifies a time limit for requests made by this
|
||||||
// Client. The timeout includes connection time and reading
|
// Client. The timeout includes connection time and reading
|
||||||
// the response body. The timer remains running after
|
// the response body. The timer remains running after
|
||||||
@ -46,11 +41,11 @@ type Client struct {
|
|||||||
// If CheckRedirect is nil, redirects will not be followed.
|
// If CheckRedirect is nil, redirects will not be followed.
|
||||||
CheckRedirect func(req *Request, via []*Request) error
|
CheckRedirect func(req *Request, via []*Request) error
|
||||||
|
|
||||||
// CreateCertificate is called to generate a certificate upon
|
// GetCertificate is called to retrieve a certificate upon
|
||||||
// the request of a server.
|
// the request of a server.
|
||||||
// If CreateCertificate is nil or the returned error is not nil,
|
// If GetCertificate is nil or the returned error is not nil,
|
||||||
// the request will not be sent again and the response will be returned.
|
// the request will not be sent again and the response will be returned.
|
||||||
CreateCertificate func(scope, path string) (tls.Certificate, error)
|
GetCertificate func(scope, path string) (tls.Certificate, error)
|
||||||
|
|
||||||
// TrustCertificate is called to determine whether the client
|
// TrustCertificate is called to determine whether the client
|
||||||
// should trust a certificate it has not seen before.
|
// should trust a certificate it has not seen before.
|
||||||
@ -62,8 +57,6 @@ type Client struct {
|
|||||||
// If TrustCertificate returns TrustAlways, the certificate will also be
|
// If TrustCertificate returns TrustAlways, the certificate will also be
|
||||||
// written to the known hosts file.
|
// written to the known hosts file.
|
||||||
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
|
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get performs a Gemini request for the given url.
|
// Get performs a Gemini request for the given url.
|
||||||
@ -77,9 +70,6 @@ func (c *Client) Get(url string) (*Response, error) {
|
|||||||
|
|
||||||
// Do performs a Gemini request and returns a Gemini response.
|
// Do performs a Gemini request and returns a Gemini response.
|
||||||
func (c *Client) Do(req *Request) (*Response, error) {
|
func (c *Client) Do(req *Request) (*Response, error) {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
return c.do(req, nil)
|
return c.do(req, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,7 +86,10 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
|||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
return c.getClientCertificate(req)
|
if req.Certificate != nil {
|
||||||
|
return req.Certificate, nil
|
||||||
|
}
|
||||||
|
return &tls.Certificate{}, nil
|
||||||
},
|
},
|
||||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||||
return c.verifyConnection(req, cs)
|
return c.verifyConnection(req, cs)
|
||||||
@ -137,13 +130,11 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
|
||||||
if c.CreateCertificate != nil {
|
if c.GetCertificate != nil {
|
||||||
cert, err := c.CreateCertificate(hostname, path)
|
cert, err := c.GetCertificate(hostname, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
c.Certificates.Add(hostname+path, cert)
|
|
||||||
c.Certificates.Write(hostname+path, cert)
|
|
||||||
req.Certificate = &cert
|
req.Certificate = &cert
|
||||||
return c.do(req, via)
|
return c.do(req, via)
|
||||||
}
|
}
|
||||||
@ -185,34 +176,6 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getClientCertificate(req *Request) (*tls.Certificate, error) {
|
|
||||||
// Request certificates have the highest precedence
|
|
||||||
if req.Certificate != nil {
|
|
||||||
return req.Certificate, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Search recursively for the certificate
|
|
||||||
scope := req.URL.Hostname() + strings.TrimSuffix(req.URL.Path, "/")
|
|
||||||
for {
|
|
||||||
cert, ok := c.Certificates.Lookup(scope)
|
|
||||||
if ok {
|
|
||||||
// Ensure that the certificate is not expired
|
|
||||||
if cert.Leaf != nil && !time.Now().After(cert.Leaf.NotAfter) {
|
|
||||||
// Store the certificate
|
|
||||||
req.Certificate = &cert
|
|
||||||
return &cert, nil
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
scope = path.Dir(scope)
|
|
||||||
if scope == "." {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tls.Certificate{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||||
// Verify the hostname
|
// Verify the hostname
|
||||||
var hostname string
|
var hostname string
|
||||||
|
@ -4,7 +4,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -47,12 +46,6 @@ func init() {
|
|||||||
return gemini.TrustNone
|
return gemini.TrustNone
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
|
||||||
fmt.Println("Generating client certificate for", hostname, path)
|
|
||||||
return gemini.CreateCertificate(gemini.CertificateOptions{
|
|
||||||
Duration: time.Hour,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
|
client.GetInput = func(prompt string, sensitive bool) (string, bool) {
|
||||||
fmt.Printf("%s: ", prompt)
|
fmt.Printf("%s: ", prompt)
|
||||||
scanner.Scan()
|
scanner.Scan()
|
||||||
|
Loading…
Reference in New Issue
Block a user