Refactor client.TrustCertificate workflow
This commit is contained in:
parent
b9cb7fe71d
commit
65a5065250
2
cert.go
2
cert.go
|
@ -16,7 +16,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CertificateStore maps hostnames to certificates.
|
// CertificateStore maps certificate scopes to certificates.
|
||||||
// The zero value of CertificateStore is an empty store ready to use.
|
// The zero value of CertificateStore is an empty store ready to use.
|
||||||
type CertificateStore struct {
|
type CertificateStore struct {
|
||||||
store map[string]tls.Certificate
|
store map[string]tls.Certificate
|
||||||
|
|
31
client.go
31
client.go
|
@ -43,12 +43,10 @@ type Client struct {
|
||||||
// the request will not be sent again and the response will be returned.
|
// the request will not be sent again and the response will be returned.
|
||||||
CreateCertificate func(hostname, path string) (tls.Certificate, error)
|
CreateCertificate func(hostname, path string) (tls.Certificate, error)
|
||||||
|
|
||||||
// TrustCertificate determines whether the client should trust
|
// TrustCertificate is called to determine whether the client
|
||||||
// the provided certificate.
|
// should trust a certificate it has not seen before.
|
||||||
// If the returned error is not nil, the connection will be aborted.
|
// If TrustCertificate is nil, the certificate will not be trusted.
|
||||||
// If TrustCertificate is nil, the client will check KnownHosts
|
TrustCertificate func(hostname string, cert *x509.Certificate) Trust
|
||||||
// for the certificate.
|
|
||||||
TrustCertificate func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get performs a Gemini request for the given url.
|
// Get performs a Gemini request for the given url.
|
||||||
|
@ -198,10 +196,25 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error {
|
||||||
if err := verifyHostname(cert, hostname); err != nil {
|
if err := verifyHostname(cert, hostname); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Check that the client trusts the certificate
|
// Check the known hosts
|
||||||
var err error
|
err := c.KnownHosts.Lookup(hostname, cert)
|
||||||
|
switch err {
|
||||||
|
case ErrCertificateExpired, ErrCertificateNotFound:
|
||||||
|
default:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// See if the client trusts the certificate
|
||||||
if c.TrustCertificate != nil {
|
if c.TrustCertificate != nil {
|
||||||
return c.TrustCertificate(hostname, cert, &c.KnownHosts)
|
switch c.TrustCertificate(hostname, cert) {
|
||||||
|
case TrustOnce:
|
||||||
|
c.KnownHosts.AddTemporary(hostname, cert)
|
||||||
|
return nil
|
||||||
|
case TrustAlways:
|
||||||
|
c.KnownHosts.Add(hostname, cert)
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return ErrCertificateNotTrusted
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
err = c.KnownHosts.Lookup(hostname, cert)
|
err = c.KnownHosts.Lookup(hostname, cert)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,12 +8,22 @@ import (
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.sr.ht/~adnano/go-gemini"
|
"git.sr.ht/~adnano/go-gemini"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const trustPrompt = `The certificate offered by %s is of unknown trust. Its fingerprint is:
|
||||||
|
%s
|
||||||
|
|
||||||
|
If you knew the fingerprint to expect in advance, verify that this matches.
|
||||||
|
Otherwise, this should be safe to trust.
|
||||||
|
|
||||||
|
[t]rust always; trust [o]nce; [a]bort
|
||||||
|
=> `
|
||||||
|
|
||||||
var (
|
var (
|
||||||
scanner = bufio.NewScanner(os.Stdin)
|
scanner = bufio.NewScanner(os.Stdin)
|
||||||
client = &gemini.Client{}
|
client = &gemini.Client{}
|
||||||
|
@ -22,30 +32,17 @@ var (
|
||||||
func init() {
|
func init() {
|
||||||
client.Timeout = 2 * time.Minute
|
client.Timeout = 2 * time.Minute
|
||||||
client.KnownHosts.LoadDefault()
|
client.KnownHosts.LoadDefault()
|
||||||
client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error {
|
client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust {
|
||||||
err := knownHosts.Lookup(hostname, cert)
|
fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert))
|
||||||
if err != nil {
|
scanner.Scan()
|
||||||
switch err {
|
switch scanner.Text() {
|
||||||
case gemini.ErrCertificateNotTrusted:
|
case "t":
|
||||||
// Alert the user that the certificate is not trusted
|
return gemini.TrustAlways
|
||||||
fmt.Printf("Warning: Certificate for %s is not trusted!\n", hostname)
|
case "o":
|
||||||
fmt.Println("This could indicate a Man-in-the-Middle attack.")
|
return gemini.TrustOnce
|
||||||
case gemini.ErrCertificateUnknown:
|
default:
|
||||||
// Prompt the user to trust the certificate
|
return gemini.TrustNone
|
||||||
trust := trustCertificate(cert)
|
|
||||||
switch trust {
|
|
||||||
case trustOnce:
|
|
||||||
// Temporarily trust the certificate
|
|
||||||
knownHosts.AddTemporary(hostname, cert)
|
|
||||||
return nil
|
|
||||||
case trustAlways:
|
|
||||||
// Add the certificate to the known hosts file
|
|
||||||
knownHosts.Add(hostname, cert)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
||||||
fmt.Println("Generating client certificate for", hostname, path)
|
fmt.Println("Generating client certificate for", hostname, path)
|
||||||
|
@ -60,54 +57,6 @@ func init() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func doRequest(req *gemini.Request) error {
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.Status.Class() == gemini.StatusClassSuccess {
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Print(string(body))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
|
|
||||||
}
|
|
||||||
|
|
||||||
type trust int
|
|
||||||
|
|
||||||
const (
|
|
||||||
trustAbort trust = iota
|
|
||||||
trustOnce
|
|
||||||
trustAlways
|
|
||||||
)
|
|
||||||
|
|
||||||
const trustPrompt = `The certificate offered by this server is of unknown trust. Its fingerprint is:
|
|
||||||
%s
|
|
||||||
|
|
||||||
If you knew the fingerprint to expect in advance, verify that this matches.
|
|
||||||
Otherwise, this should be safe to trust.
|
|
||||||
|
|
||||||
[t]rust always; trust [o]nce; [a]bort
|
|
||||||
=> `
|
|
||||||
|
|
||||||
func trustCertificate(cert *x509.Certificate) trust {
|
|
||||||
fmt.Printf(trustPrompt, gemini.Fingerprint(cert))
|
|
||||||
scanner.Scan()
|
|
||||||
switch scanner.Text() {
|
|
||||||
case "t":
|
|
||||||
return trustAlways
|
|
||||||
case "o":
|
|
||||||
return trustOnce
|
|
||||||
default:
|
|
||||||
return trustAbort
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) < 2 {
|
if len(os.Args) < 2 {
|
||||||
fmt.Printf("usage: %s gemini://... [host]", os.Args[0])
|
fmt.Printf("usage: %s gemini://... [host]", os.Args[0])
|
||||||
|
@ -125,8 +74,19 @@ func main() {
|
||||||
req.Host = os.Args[2]
|
req.Host = os.Args[2]
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := doRequest(req); err != nil {
|
resp, err := client.Do(req)
|
||||||
fmt.Println(err)
|
if err != nil {
|
||||||
os.Exit(1)
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Status.Class() == gemini.StatusClassSuccess {
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
fmt.Print(string(body))
|
||||||
|
} else {
|
||||||
|
log.Fatalf("request failed: %d %s: %s", resp.Status, resp.Status.Message(), resp.Meta)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
24
gemini.go
24
gemini.go
|
@ -1,11 +1,8 @@
|
||||||
package gemini
|
package gemini
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var crlf = []byte("\r\n")
|
var crlf = []byte("\r\n")
|
||||||
|
@ -14,16 +11,15 @@ var crlf = []byte("\r\n")
|
||||||
var (
|
var (
|
||||||
ErrInvalidURL = errors.New("gemini: invalid URL")
|
ErrInvalidURL = errors.New("gemini: invalid URL")
|
||||||
ErrInvalidResponse = errors.New("gemini: invalid response")
|
ErrInvalidResponse = errors.New("gemini: invalid response")
|
||||||
ErrCertificateUnknown = errors.New("gemini: unknown certificate")
|
|
||||||
ErrCertificateExpired = errors.New("gemini: certificate expired")
|
ErrCertificateExpired = errors.New("gemini: certificate expired")
|
||||||
|
ErrCertificateNotFound = errors.New("gemini: certificate not found")
|
||||||
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
|
ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted")
|
||||||
|
ErrCertificateRequired = errors.New("gemini: certificate required")
|
||||||
ErrNotAFile = errors.New("gemini: not a file")
|
ErrNotAFile = errors.New("gemini: not a file")
|
||||||
ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL")
|
ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL")
|
||||||
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
|
ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body")
|
||||||
ErrTooManyRedirects = errors.New("gemini: too many redirects")
|
ErrTooManyRedirects = errors.New("gemini: too many redirects")
|
||||||
ErrInputRequired = errors.New("gemini: input required")
|
ErrInputRequired = errors.New("gemini: input required")
|
||||||
ErrCertificateRequired = errors.New("gemini: certificate required")
|
|
||||||
ErrCertificateNotFound = errors.New("gemini: certificate not found")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultClient is the default client. It is used by Get and Do.
|
// DefaultClient is the default client. It is used by Get and Do.
|
||||||
|
@ -35,6 +31,7 @@ var DefaultClient Client
|
||||||
//
|
//
|
||||||
// Get is a wrapper around DefaultClient.Get.
|
// Get is a wrapper around DefaultClient.Get.
|
||||||
func Get(url string) (*Response, error) {
|
func Get(url string) (*Response, error) {
|
||||||
|
setupDefaultClientOnce()
|
||||||
return DefaultClient.Get(url)
|
return DefaultClient.Get(url)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,19 +39,14 @@ func Get(url string) (*Response, error) {
|
||||||
//
|
//
|
||||||
// Do is a wrapper around DefaultClient.Do.
|
// Do is a wrapper around DefaultClient.Do.
|
||||||
func Do(req *Request) (*Response, error) {
|
func Do(req *Request) (*Response, error) {
|
||||||
|
setupDefaultClientOnce()
|
||||||
return DefaultClient.Do(req)
|
return DefaultClient.Do(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultClientOnce sync.Once
|
var defaultClientOnce sync.Once
|
||||||
|
|
||||||
func init() {
|
func setupDefaultClientOnce() {
|
||||||
DefaultClient.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *KnownHosts) error {
|
defaultClientOnce.Do(func() {
|
||||||
defaultClientOnce.Do(func() { knownHosts.LoadDefault() })
|
DefaultClient.KnownHosts.LoadDefault()
|
||||||
return knownHosts.Lookup(hostname, cert)
|
})
|
||||||
}
|
|
||||||
DefaultClient.CreateCertificate = func(hostname, path string) (tls.Certificate, error) {
|
|
||||||
return CreateCertificate(CertificateOptions{
|
|
||||||
Duration: time.Hour,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -146,7 +146,7 @@ func (s *Server) getCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error
|
||||||
|
|
||||||
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
|
func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
|
||||||
if _, ok := s.hosts[hostname]; !ok {
|
if _, ok := s.hosts[hostname]; !ok {
|
||||||
return nil, ErrCertificateUnknown
|
return nil, ErrCertificateNotFound
|
||||||
}
|
}
|
||||||
cert, err := s.Certificates.Lookup(hostname)
|
cert, err := s.Certificates.Lookup(hostname)
|
||||||
|
|
||||||
|
|
22
tofu.go
22
tofu.go
|
@ -13,6 +13,15 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Trust represents the trustworthiness of a certificate.
|
||||||
|
type Trust int
|
||||||
|
|
||||||
|
const (
|
||||||
|
TrustNone Trust = iota // The certificate is not trusted.
|
||||||
|
TrustOnce // The certificate is trusted once.
|
||||||
|
TrustAlways // The certificate is trusted always.
|
||||||
|
)
|
||||||
|
|
||||||
// KnownHosts represents a list of known hosts.
|
// KnownHosts represents a list of known hosts.
|
||||||
// The zero value for KnownHosts is an empty list ready to use.
|
// The zero value for KnownHosts is an empty list ready to use.
|
||||||
type KnownHosts struct {
|
type KnownHosts struct {
|
||||||
|
@ -86,26 +95,25 @@ func (k *KnownHosts) add(hostname string, cert *x509.Certificate, write bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup looks for the provided certificate in the list of known hosts.
|
// Lookup looks for the provided certificate in the list of known hosts.
|
||||||
// If the hostname is in the list, but the fingerprint differs,
|
// If the hostname is not in the list, Lookup returns ErrCertificateNotFound.
|
||||||
// Lookup returns ErrCertificateNotTrusted.
|
// If the fingerprint doesn't match, Lookup returns ErrCertificateNotTrusted.
|
||||||
// If the hostname is not in the list, Lookup returns ErrCertificateUnknown.
|
// Otherwise, Lookup returns nil.
|
||||||
// If the certificate is found and the fingerprint matches, error will be nil.
|
|
||||||
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
|
func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
fingerprint := Fingerprint(cert)
|
fingerprint := Fingerprint(cert)
|
||||||
if c, ok := k.hosts[hostname]; ok {
|
if c, ok := k.hosts[hostname]; ok {
|
||||||
if c.Expires <= now {
|
if c.Expires <= now {
|
||||||
// Certificate is expired
|
// Certificate is expired
|
||||||
return ErrCertificateUnknown
|
return ErrCertificateExpired
|
||||||
}
|
}
|
||||||
if c.Fingerprint != fingerprint {
|
if c.Fingerprint != fingerprint {
|
||||||
// Fingerprint does not match
|
// Fingerprint does not match
|
||||||
return ErrCertificateNotTrusted
|
return ErrCertificateNotTrusted
|
||||||
}
|
}
|
||||||
// Certificate is trusted
|
// Certificate is found
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return ErrCertificateUnknown
|
return ErrCertificateNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
// Parse parses the provided reader and adds the parsed known hosts to the list.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user