Decouple Client from KnownHostsFile

This commit is contained in:
Adnan Maolood 2020-12-17 19:50:26 -05:00
parent 176b260468
commit 28c5c857dc
3 changed files with 81 additions and 115 deletions

104
client.go
View File

@ -5,7 +5,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"net" "net"
"net/url" "net/url"
"strings" "strings"
@ -13,25 +12,19 @@ import (
) )
// Client is a Gemini client. // Client is a Gemini client.
//
// Clients are safe for concurrent use by multiple goroutines.
type Client struct { type Client struct {
// KnownHosts is a list of known hosts. // TrustCertificate is called to determine whether the client
KnownHosts KnownHostsFile // should trust the certificate provided by the server.
// If TrustCertificate is nil, the client will accept any certificate.
// If the returned error is not nil, the certificate will not be trusted
// and the request will be aborted.
TrustCertificate func(hostname string, cert *x509.Certificate) error
// Timeout specifies a time limit for requests made by this // GetCertificate is called to retrieve a certificate upon
// Client. The timeout includes connection time and reading // the request of a server.
// the response body. The timer remains running after // If GetCertificate is nil or the returned error is not nil,
// Get and Do return and will interrupt reading of the Response.Body. // the request will not be sent again and the response will be returned.
// GetCertificate func(scope, path string) (tls.Certificate, error)
// A Timeout of zero means no timeout.
Timeout time.Duration
// InsecureSkipTrust specifies whether the client should trust
// any certificate it receives without checking KnownHosts
// or calling TrustCertificate.
// Use with caution.
InsecureSkipTrust bool
// GetInput is called to retrieve input when the server requests it. // GetInput is called to retrieve input when the server requests it.
// If GetInput is nil or returns false, no input will be sent and // If GetInput is nil or returns false, no input will be sent and
@ -42,25 +35,16 @@ type Client struct {
// If CheckRedirect is nil, redirects will not be followed. // If CheckRedirect is nil, redirects will not be followed.
CheckRedirect func(req *Request, via []*Request) error CheckRedirect func(req *Request, via []*Request) error
// GetCertificate is called to retrieve a certificate upon // Timeout specifies a time limit for requests made by this
// the request of a server. // Client. The timeout includes connection time and reading
// If GetCertificate is nil or the returned error is not nil, // the response body. The timer remains running after
// the request will not be sent again and the response will be returned. // Get and Do return and will interrupt reading of the Response.Body.
GetCertificate func(scope, path string) (tls.Certificate, error)
// TrustCertificate is called to determine whether the client
// should trust a certificate it has not seen before.
// If TrustCertificate is nil, the certificate will not be trusted
// and the connection will be aborted.
// //
// If TrustCertificate returns TrustOnce, the certificate will be added // A Timeout of zero means no timeout.
// to the client's list of known hosts. Timeout time.Duration
// If TrustCertificate returns TrustAlways, the certificate will also be
// written to the known hosts file.
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
} }
// Get performs a Gemini request for the given url. // Get performs a Gemini request for the given URL.
func (c *Client) Get(url string) (*Response, error) { func (c *Client) Get(url string) (*Response, error) {
req, err := NewRequest(url) req, err := NewRequest(url)
if err != nil { if err != nil {
@ -130,34 +114,39 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
switch { switch {
case resp.Status == StatusCertificateRequired: case resp.Status == StatusCertificateRequired:
// Check to see if a certificate was already provided to prevent an infinite loop if c.GetCertificate == nil {
if req.Certificate != nil { break
return resp, nil
} }
// Check to see if a certificate was already provided to prevent an infinite loop
if req.Certificate != nil {
break
}
hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/") hostname, path := req.URL.Hostname(), strings.TrimSuffix(req.URL.Path, "/")
if c.GetCertificate != nil {
cert, err := c.GetCertificate(hostname, path) cert, err := c.GetCertificate(hostname, path)
if err != nil { if err != nil {
return resp, err return resp, err
} }
req.Certificate = &cert req.Certificate = &cert
return c.do(req, via) return c.do(req, via)
}
return resp, nil
case resp.Status.Class() == StatusClassInput: case resp.Status.Class() == StatusClassInput:
if c.GetInput != nil { if c.GetInput == nil {
break
}
input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput) input, ok := c.GetInput(resp.Meta, resp.Status == StatusSensitiveInput)
if ok { if ok {
req.URL.ForceQuery = true req.URL.ForceQuery = true
req.URL.RawQuery = QueryEscape(input) req.URL.RawQuery = QueryEscape(input)
return c.do(req, via) return c.do(req, via)
} }
}
return resp, nil
case resp.Status.Class() == StatusClassRedirect: case resp.Status.Class() == StatusClassRedirect:
if c.CheckRedirect == nil {
break
}
if via == nil { if via == nil {
via = []*Request{} via = []*Request{}
} }
@ -171,13 +160,11 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) {
redirect := NewRequestFromURL(target) redirect := NewRequestFromURL(target)
redirect.Context = req.Context redirect.Context = req.Context
if c.CheckRedirect != nil {
if err := c.CheckRedirect(redirect, via); err != nil { if err := c.CheckRedirect(redirect, via); err != nil {
return resp, err return resp, err
} }
return c.do(redirect, via) return c.do(redirect, via)
} }
}
return resp, nil return resp, nil
} }
@ -194,33 +181,10 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
if err := verifyHostname(cert, hostname); err != nil { if err := verifyHostname(cert, hostname); err != nil {
return err return err
} }
if c.InsecureSkipTrust {
return nil
}
// Check the known hosts
knownHost, ok := c.KnownHosts.Lookup(hostname)
if !ok || !time.Now().Before(knownHost.Expires) {
// See if the client trusts the certificate // See if the client trusts the certificate
if c.TrustCertificate != nil { if c.TrustCertificate != nil {
switch c.TrustCertificate(hostname, cert) { return c.TrustCertificate(hostname, cert)
case TrustOnce: }
fingerprint := NewFingerprint(cert.Raw, cert.NotAfter)
c.KnownHosts.Add(hostname, fingerprint)
return nil return nil
case TrustAlways:
fingerprint := NewFingerprint(cert.Raw, cert.NotAfter)
c.KnownHosts.Add(hostname, fingerprint)
c.KnownHosts.Write(hostname, fingerprint)
return nil
}
}
return errors.New("gemini: certificate not trusted")
}
fingerprint := NewFingerprint(cert.Raw, cert.NotAfter)
if knownHost.Hex == fingerprint.Hex {
return nil
}
return errors.New("gemini: fingerprint does not match")
} }

View File

@ -5,6 +5,7 @@ package main
import ( import (
"bufio" "bufio"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -25,43 +26,52 @@ Otherwise, this should be safe to trust.
[t]rust always; trust [o]nce; [a]bort [t]rust always; trust [o]nce; [a]bort
=> ` => `
var ( func main() {
scanner = bufio.NewScanner(os.Stdin) if len(os.Args) < 2 {
client = &gemini.Client{} fmt.Printf("usage: %s <url> [host]", os.Args[0])
) os.Exit(1)
}
// Load known hosts file
var knownHosts gemini.KnownHostsFile
if err := knownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts")); err != nil {
log.Println(err)
}
scanner := bufio.NewScanner(os.Stdin)
var client gemini.Client
client.TrustCertificate = func(hostname string, cert *x509.Certificate) error {
knownHost, ok := knownHosts.Lookup(hostname)
if ok && time.Now().Before(knownHost.Expires) {
// Certificate is in known hosts file and is not expired
return nil
}
func init() {
client.Timeout = 30 * time.Second
client.KnownHosts.Load(filepath.Join(xdg.DataHome(), "gemini", "known_hosts"))
client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust {
fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter) fingerprint := gemini.NewFingerprint(cert.Raw, cert.NotAfter)
fmt.Printf(trustPrompt, hostname, fingerprint.Hex) fmt.Printf(trustPrompt, hostname, fingerprint.Hex)
scanner.Scan() scanner.Scan()
switch scanner.Text() { switch scanner.Text() {
case "t": case "t":
return gemini.TrustAlways knownHosts.Add(hostname, fingerprint)
knownHosts.Write(hostname, fingerprint)
return nil
case "o": case "o":
return gemini.TrustOnce knownHosts.Add(hostname, fingerprint)
return nil
default: default:
return gemini.TrustNone return errors.New("certificate not trusted")
} }
} }
client.GetInput = func(prompt string, sensitive bool) (string, bool) { client.GetInput = func(prompt string, sensitive bool) (string, bool) {
fmt.Printf("%s: ", prompt) fmt.Printf("%s ", prompt)
scanner.Scan() scanner.Scan()
return scanner.Text(), true return scanner.Text(), true
} }
}
func main() {
if len(os.Args) < 2 {
fmt.Printf("usage: %s gemini://... [host]", os.Args[0])
os.Exit(1)
}
// Do the request
url := os.Args[1] url := os.Args[1]
req, err := gemini.NewRequest(url) req, err := gemini.NewRequest(url)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -69,13 +79,13 @@ func main() {
if len(os.Args) == 3 { if len(os.Args) == 3 {
req.Host = os.Args[2] req.Host = os.Args[2]
} }
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
} }
// Handle response
if resp.Status.Class() == gemini.StatusClassSuccess { if resp.Status.Class() == gemini.StatusClassSuccess {
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
@ -84,6 +94,7 @@ func main() {
} }
fmt.Print(string(body)) fmt.Print(string(body))
} else { } else {
fmt.Printf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta) fmt.Printf("%d %s: %s\n", resp.Status, resp.Status.Message(), resp.Meta)
os.Exit(1)
} }
} }

View File

@ -12,15 +12,6 @@ import (
"time" "time"
) )
// Trust represents the trustworthiness of a certificate.
type Trust int
const (
TrustNone Trust = iota // The certificate is not trusted.
TrustOnce // The certificate is trusted once.
TrustAlways // The certificate is trusted always.
)
// KnownHosts maps hosts to fingerprints. // KnownHosts maps hosts to fingerprints.
type KnownHosts map[string]Fingerprint type KnownHosts map[string]Fingerprint