Refactor client.TrustCertificate workflow
This commit is contained in:
parent
b9cb7fe71d
commit
65a5065250
2
cert.go
2
cert.go
@ -16,7 +16,7 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CertificateStore maps hostnames to certificates.
|
||||
// CertificateStore maps certificate scopes to certificates.
|
||||
// The zero value of CertificateStore is an empty store ready to use.
|
||||
type CertificateStore struct {
|
||||
store map[string]tls.Certificate
|
||||
|
31
client.go
31
client.go
@ -43,12 +43,10 @@ type Client struct {
|
||||
// the request will not be sent again and the response will be returned.
|
||||
CreateCertificate func(hostname, path string) (tls.Certificate, error)
|
||||
|
||||
// TrustCertificate determines whether the client should trust
|
||||
// the provided certificate.
|
||||
// If the returned error is not nil, the connection will be aborted.
|
||||
// If TrustCertificate is nil, the client will check KnownHosts
|
||||
// for the certificate.
|
||||
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
|
||||
// TrustCertificate is called to determine whether the client
|
||||
// should trust a certificate it has not seen before.
|
||||
// If TrustCertificate is nil, the certificate will not be trusted.
|
||||
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
|
||||
}
|
||||
|
||||
// Get performs a Gemini request for the given url.
|
||||
@ -198,10 +196,25 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||
if err := verifyHostname(cert, hostname); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check that the client trusts the certificate
|
||||
var err error
|
||||
// Check the known hosts
|
||||
err := c.KnownHosts.Lookup(hostname, cert)
|
||||
switch err {
|
||||
case ErrCertificateExpired, ErrCertificateNotFound:
|
||||
default:
|
||||
return err
|
||||
}
|
||||
// See if the client trusts the certificate
|
||||
if c.TrustCertificate != nil {
|
||||
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
|
||||
switch c.TrustCertificate(hostname, cert) {
|
||||
case TrustOnce:
|
||||
c.KnownHosts.AddTemporary(hostname, cert)
|
||||
return nil
|
||||
case TrustAlways:
|
||||
c.KnownHosts.Add(hostname, cert)
|
||||
return nil
|
||||
default:
|
||||
return ErrCertificateNotTrusted
|
||||
}
|
||||
} else {
|
||||
err = c.KnownHosts.Lookup(hostname, cert)
|
||||
}
|
||||
|
@ -8,12 +8,22 @@ import (
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.sr.ht/~adnano/go-gemini"
|
||||
)
|
||||
|
||||
const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is:
|
||||
%s
|
||||
|
||||
If you knew the fingerprint to expect in advance, verify that this matches.
|
||||
Otherwise, this should be safe to trust.
|
||||
|
||||
[t]rust always; trust [o]nce; [a]bort
|
||||
=> `
|
||||
|
||||
var (
|
||||
scanner = bufio.NewScanner(os.Stdin)
|
||||
client = &gemini.Client{}
|
||||
@ -22,30 +32,17 @@ var (
|
||||
func init() {
|
||||
client.Timeout = 2 * time.Minute
|
||||
client.KnownHosts.LoadDefault()
|
||||
client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error {
|
||||
err := knownHosts.Lookup(hostname, cert)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gemini.ErrCertificateNotTrusted:
|
||||
// Alert the user that the certificate is not trusted
|
||||
fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname)
|
||||
fmt.Println("This could indicate a Man-in-the-Middle attack.")
|
||||
case gemini.ErrCertificateUnknown:
|
||||
// Prompt the user to trust the certificate
|
||||
trust := trustCertificate(cert)
|
||||
switch trust {
|
||||
case trustOnce:
|
||||
// Temporarily trust the certificate
|
||||
knownHosts.AddTemporary(hostname, cert)
|
||||
return nil
|
||||
case trustAlways:
|
||||
// Add the certificate to the known hosts file
|
||||
knownHosts.Add(hostname, cert)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust {
|
||||
fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert))
|
||||
scanner.Scan()
|
||||
switch scanner.Text() {
|
||||
case "t":
|
||||
return gemini.TrustAlways
|
||||
case "o":
|
||||
return gemini.TrustOnce
|
||||
default:
|
||||
return gemini.TrustNone
|
||||
}
|
||||
return err
|
||||
}
|
||||
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
||||
fmt.Println("Generating client certificate for", hostname, path)
|
||||
@ -60,54 +57,6 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func doRequest(req *gemini.Request) error {
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Status.Class() == gemini.StatusClassSuccess {
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Print(string(body))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
|
||||
}
|
||||
|
||||
type trust int
|
||||
|
||||
const (
|
||||
trustAbort trust = iota
|
||||
trustOnce
|
||||
trustAlways
|
||||
)
|
||||
|
||||
const trustPrompt = `The certificate offered by this server is of unknown trust. Its fingerprint is:
|
||||
%s
|
||||
|
||||
If you knew the fingerprint to expect in advance, verify that this matches.
|
||||
Otherwise, this should be safe to trust.
|
||||
|
||||
[t]rust always; trust [o]nce; [a]bort
|
||||
=> `
|
||||
|
||||
func trustCertificate(cert *x509.Certificate) trust {
|
||||
fmt.Printf(trustPrompt, gemini.Fingerprint(cert))
|
||||
scanner.Scan()
|
||||
switch scanner.Text() {
|
||||
case "t":
|
||||
return trustAlways
|
||||
case "o":
|
||||
return trustOnce
|
||||
default:
|
||||
return trustAbort
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 2 {
|
||||
fmt.Printf("usage: %s gemini://... [host]", os.Args[0])
|
||||
@ -125,8 +74,19 @@ func main() {
|
||||
req.Host = os.Args[2]
|
||||
}
|
||||
|
||||
if err := doRequest(req); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Status.Class() == gemini.StatusClassSuccess {
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Print(string(body))
|
||||
} else {
|
||||
log.Fatalf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
|
||||
}
|
||||
}
|
||||
|
24
gemini.go
24
gemini.go
@ -1,11 +1,8 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var crlf = []byte("\r\n")
|
||||
@ -14,16 +11,15 @@ var crlf = []byte("\r\n")
|
||||
var (
|
||||
ErrInvalidURL = errors.New("gemini: invalid URL")
|
||||
ErrInvalidResponse = errors.New("gemini: invalid response")
|
||||
ErrCertificateUnknown = errors.New("gemini: unknown certificate")
|
||||
ErrCertificateExpired = errors.New("gemini: certificate expired")
|
||||
ErrCertificateNotFound = errors.New("gemini: certificate not found")
|
||||
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
|
||||
ErrCertificateRequired = errors.New("gemini: certificate required")
|
||||
ErrNotAFile = errors.New("gemini: not a file")
|
||||
ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL")
|
||||
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
|
||||
ErrTooManyRedirects = errors.New("gemini: too many redirects")
|
||||
ErrInputRequired = errors.New("gemini: input required")
|
||||
ErrCertificateRequired = errors.New("gemini: certificate required")
|
||||
ErrCertificateNotFound = errors.New("gemini: certificate not found")
|
||||
)
|
||||
|
||||
// DefaultClient is the default client. It is used by Get and Do.
|
||||
@ -35,6 +31,7 @@ var DefaultClient Client
|
||||
//
|
||||
// Get is a wrapper around DefaultClient.Get.
|
||||
func Get(url string) (*Response, error) {
|
||||
setupDefaultClientOnce()
|
||||
return DefaultClient.Get(url)
|
||||
}
|
||||
|
||||
@ -42,19 +39,14 @@ func Get(url string) (*Response, error) {
|
||||
//
|
||||
// Do is a wrapper around DefaultClient.Do.
|
||||
func Do(req *Request) (*Response, error) {
|
||||
setupDefaultClientOnce()
|
||||
return DefaultClient.Do(req)
|
||||
}
|
||||
|
||||
var defaultClientOnce sync.Once
|
||||
|
||||
func init() {
|
||||
DefaultClient.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error {
|
||||
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
|
||||
return knownHosts.Lookup(hostname, cert)
|
||||
}
|
||||
DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
||||
return CreateCertificate(CertificateOptions{
|
||||
Duration: time.Hour,
|
||||
})
|
||||
}
|
||||
func setupDefaultClientOnce() {
|
||||
defaultClientOnce.Do(func() {
|
||||
DefaultClient.KnownHosts.LoadDefault()
|
||||
})
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
|
||||
|
||||
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
|
||||
if _, ok := s.hosts[hostname]; !ok {
|
||||
return nil, ErrCertificateUnknown
|
||||
return nil, ErrCertificateNotFound
|
||||
}
|
||||
cert, err := s.Certificates.Lookup(hostname)
|
||||
|
||||
|
22
tofu.go
22
tofu.go
@ -13,6 +13,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Trust represents the trustworthiness of a certificate.
|
||||
type Trust int
|
||||
|
||||
const (
|
||||
TrustNone Trust = iota // The certificate is not trusted.
|
||||
TrustOnce // The certificate is trusted once.
|
||||
TrustAlways // The certificate is trusted always.
|
||||
)
|
||||
|
||||
// KnownHosts represents a list of known hosts.
|
||||
// The zero value for KnownHosts is an empty list ready to use.
|
||||
type KnownHosts struct {
|
||||
@ -86,26 +95,25 @@ func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) {
|
||||
}
|
||||
|
||||
// Lookup looks for the provided certificate in the list of known hosts.
|
||||
// If the hostname is in the list, but the fingerprint differs,
|
||||
// Lookup returns ErrCertificateNotTrusted.
|
||||
// If the hostname is not in the list, Lookup returns ErrCertificateUnknown.
|
||||
// If the certificate is found and the fingerprint matches, error will be nil.
|
||||
// If the hostname is not in the list, Lookup returns ErrCertificateNotFound.
|
||||
// If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted.
|
||||
// Otherwise, Lookup returns nil.
|
||||
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
|
||||
now := time.Now().Unix()
|
||||
fingerprint := Fingerprint(cert)
|
||||
if c, ok := k.hosts[hostname]; ok {
|
||||
if c.Expires <= now {
|
||||
// Certificate is expired
|
||||
return ErrCertificateUnknown
|
||||
return ErrCertificateExpired
|
||||
}
|
||||
if c.Fingerprint != fingerprint {
|
||||
// Fingerprint does not match
|
||||
return ErrCertificateNotTrusted
|
||||
}
|
||||
// Certificate is trusted
|
||||
// Certificate is found
|
||||
return nil
|
||||
}
|
||||
return ErrCertificateUnknown
|
||||
return ErrCertificateNotFound
|
||||
}
|
||||
|
||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
||||
|
Loading…
Reference in New Issue
Block a user