diff --git a/client.go b/client.go index 4efd95f..57a7d67 100644 --- a/client.go +++ b/client.go @@ -27,6 +27,12 @@ type Client struct { // A Timeout of zero means no timeout. Timeout time.Duration + // InsecureTrustAlways specifies whether the client should trust + // any certificate it recieves without checking KnownHosts + // or calling TrustCertificate. + // Use with caution. + InsecureTrustAlways bool + // GetInput is called to retrieve input when the server requests it. // If GetInput is nil or returns false, no input will be sent and // the response will be returned. @@ -196,27 +202,25 @@ func (c *Client) verifyConnection(req *Request, cs tls.ConnectionState) error { if err := verifyHostname(cert, hostname); err != nil { return err } + if c.InsecureTrustAlways { + return nil + } // Check the known hosts err := c.KnownHosts.Lookup(hostname, cert) switch err { case ErrCertificateExpired, ErrCertificateNotFound: - default: - return err - } - // See if the client trusts the certificate - if c.TrustCertificate != nil { - switch c.TrustCertificate(hostname, cert) { - case TrustOnce: - c.KnownHosts.AddTemporary(hostname, cert) - return nil - case TrustAlways: - c.KnownHosts.Add(hostname, cert) - return nil - default: - return ErrCertificateNotTrusted + // See if the client trusts the certificate + if c.TrustCertificate != nil { + switch c.TrustCertificate(hostname, cert) { + case TrustOnce: + c.KnownHosts.AddTemporary(hostname, cert) + return nil + case TrustAlways: + c.KnownHosts.Add(hostname, cert) + return nil + } } - } else { - err = c.KnownHosts.Lookup(hostname, cert) + return ErrCertificateNotTrusted } return err } diff --git a/examples/client.go b/examples/client.go index 4c9b172..0a35cca 100644 --- a/examples/client.go +++ b/examples/client.go @@ -5,7 +5,6 @@ package main import ( "bufio" "crypto/tls" - "crypto/x509" "fmt" "io/ioutil" "log" @@ -31,19 +30,6 @@ var ( func init() { client.Timeout = 2 * time.Minute - client.KnownHosts.LoadDefault() - client.TrustCertificate = func(hostname string, cert *x509.Certificate) gemini.Trust { - fmt.Printf(trustPrompt, hostname, gemini.Fingerprint(cert)) - scanner.Scan() - switch scanner.Text() { - case "t": - return gemini.TrustAlways - case "o": - return gemini.TrustOnce - default: - return gemini.TrustNone - } - } client.CreateCertificate = func(hostname, path string) (tls.Certificate, error) { fmt.Println("Generating client certificate for", hostname, path) return gemini.CreateCertificate(gemini.CertificateOptions{ diff --git a/gemini.go b/gemini.go index b03400d..e5ac439 100644 --- a/gemini.go +++ b/gemini.go @@ -13,7 +13,7 @@ var ( ErrInvalidResponse = errors.New("gemini: invalid response") ErrCertificateExpired = errors.New("gemini: certificate expired") ErrCertificateNotFound = errors.New("gemini: certificate not found") - ErrCertificateNotTrusted = errors.New("gemini: certificate is not trusted") + ErrCertificateNotTrusted = errors.New("gemini: certificate not trusted") ErrCertificateRequired = errors.New("gemini: certificate required") ErrNotAFile = errors.New("gemini: not a file") ErrNotAGeminiURL = errors.New("gemini: not a Gemini URL")