Decouple Client from KnownHostsFile

This commit is contained in:
Adnan Maolood
2020-12-17 19:50:26 -05:00
parent 176b260468
commit 28c5c857dc
3 changed files with 81 additions and 115 deletions

132
client.go
View File

@@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"net"
"net/url"
"strings"
@@ -13,25 +12,19 @@ import (
)
// Client is a Gemini client.
//
// Clients are safe for concurrent use by multiple goroutines.
type Client struct {
// KnownHosts is a list of known hosts.
KnownHosts KnownHostsFile
// TrustCertificate is called to determine whether the client
// should trust the certificate provided by the server.
// If TrustCertificate is nil, the client will accept any certificate.
// If the returned error is not nil, the certificate will not be trusted
// and the request will be aborted.
TrustCertificate func(hostname string, cert *x509.Certificate) error
// Timeout specifies a time limit for requests made by this
// Client. The timeout includes connection time and reading
// the response body. The timer remains running after
// Get and Do return and will interrupt reading of the Response.Body.
//
// A Timeout of zero means no timeout.
Timeout time.Duration
// InsecureSkipTrust specifies whether the client should trust
// any certificate it receives without checking KnownHosts
// or calling TrustCertificate.
// Use with caution.
InsecureSkipTrust bool
// GetCertificate is called to retrieve a certificate upon
// the request of a server.
// If GetCertificate is nil or the returned error is not nil,
// the request will not be sent again and the response will be returned.
GetCertificate func(scope, path string) (tls.Certificate, error)
// GetInput is called to retrieve input when the server requests it.
// If GetInput is nil or returns false, no input will be sent and
@@ -42,25 +35,16 @@ type Client struct {
// If CheckRedirect is nil, redirects will not be followed.
CheckRedirect func(req *Request, via []*Request) error
// GetCertificate is called to retrieve a certificate upon
// the request of a server.
// If GetCertificate is nil or the returned error is not nil,
// the request will not be sent again and the response will be returned.
GetCertificate func(scope, path string) (tls.Certificate, error)
// TrustCertificate is called to determine whether the client
// should trust a certificate it has not seen before.
// If TrustCertificate is nil, the certificate will not be trusted
// and the connection will be aborted.
// Timeout specifies a time limit for requests made by this
// Client. The timeout includes connection time and reading
// the response body. The timer remains running after
// Get and Do return and will interrupt reading of the Response.Body.
//
// If TrustCertificate returns TrustOnce, the certificate will be added
// to the client's list of known hosts.
// If TrustCertificate returns TrustAlways, the certificate will also be
// written to the known hosts file.
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
// A Timeout of zero means no timeout.
Timeout time.Duration
}
// Get performs a Gemini request for the given url.
// Get performs a Gemini request for the given URL.
func (c *Client) Get(url string) (*Response, error) {
req, err := NewRequest(url)
if err != nil {
@@ -130,34 +114,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
switch {
case resp.Status == StatusCertificateRequired:
if c.GetCertificate == nil {
break
}
// Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil {
return resp, nil
break
}
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if c.GetCertificate != nil {
cert, err := c.GetCertificate(hostname, path)
if err != nil {
return resp, err
}
req.Certificate = &cert
return c.do(req, via)
cert, err := c.GetCertificate(hostname, path)
if err != nil {
return resp, err
}
return resp, nil
req.Certificate = &cert
return c.do(req, via)
case resp.Status.Class() == StatusClassInput:
if c.GetInput != nil {
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = QueryEscape(input)
return c.do(req, via)
}
if c.GetInput == nil {
break
}
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok {
req.URL.ForceQuery = true
req.URL.RawQuery = QueryEscape(input)
return c.do(req, via)
}
return resp, nil
case resp.Status.Class() == StatusClassRedirect:
if c.CheckRedirect == nil {
break
}
if via == nil {
via = []*Request{}
}
@@ -171,12 +160,10 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
redirect := NewRequestFromURL(target)
redirect.Context = req.Context
if c.CheckRedirect != nil {
if err := c.CheckRedirect(redirect, via); err != nil {
return resp, err
}
return c.do(redirect, via)
if err := c.CheckRedirect(redirect, via); err != nil {
return resp, err
}
return c.do(redirect, via)
}
return resp, nil
@@ -194,33 +181,10 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
if err := verifyHostname(cert, hostname); err != nil {
return err
}
if c.InsecureSkipTrust {
return nil
}
// Check the known hosts
knownHost, ok := c.KnownHosts.Lookup(hostname)
if !ok || !time.Now().Before(knownHost.Expires) {
// See if the client trusts the certificate
if c.TrustCertificate != nil {
switch c.TrustCertificate(hostname, cert) {
case TrustOnce:
fingerprint := NewFingerprint(cert.Raw, cert.NotAfter)
c.KnownHosts.Add(hostname, fingerprint)
return nil
case TrustAlways:
fingerprint := NewFingerprint(cert.Raw, cert.NotAfter)
c.KnownHosts.Add(hostname, fingerprint)
c.KnownHosts.Write(hostname, fingerprint)
return nil
}
}
return errors.New("gemini: certificate not trusted")
// See if the client trusts the certificate
if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert)
}
fingerprint := NewFingerprint(cert.Raw, cert.NotAfter)
if knownHost.Hex == fingerprint.Hex {
return nil
}
return errors.New("gemini: fingerprint does not match")
return nil
}