diff --git a/client.go b/client.go index 147d920..d507960 100644 --- a/client.go +++ b/client.go @@ -45,6 +45,11 @@ type Request struct { TLS tls.ConnectionState } +// Hostname returns the request host without the port. +func (r *Request) Hostname() string { + return hostname(r.Host) +} + // NewRequest returns a new request. The host is inferred from the provided url. func NewRequest(rawurl string) (*Request, error) { u, err := url.Parse(rawurl) @@ -180,7 +185,7 @@ type Client struct { // TrustCertificate, if not nil, will be called to determine whether the // client should trust the given certificate. // If error is not nil, the connection will be aborted. - TrustCertificate func(cert *x509.Certificate, knownHosts *KnownHosts) error + TrustCertificate func(req *Request, cert *x509.Certificate, knownHosts *KnownHosts) error } // Send sends a Gemini request and returns a Gemini response. @@ -205,15 +210,15 @@ func (c *Client) Send(req *Request) (*Response, error) { return err } // Check that the certificate is valid for the hostname - if err := cert.VerifyHostname(hostname(req.Host)); err != nil { + if err := cert.VerifyHostname(req.Hostname()); err != nil { return err } // Check that the client trusts the certificate if c.TrustCertificate == nil { - if err := c.KnownHosts.Lookup(cert); err != nil { + if err := c.KnownHosts.Lookup(req.Hostname(), cert); err != nil { return err } - } else if err := c.TrustCertificate(cert, &c.KnownHosts); err != nil { + } else if err := c.TrustCertificate(req, cert, &c.KnownHosts); err != nil { return err } return nil diff --git a/examples/client/client.go b/examples/client/client.go index d3e2156..e6be51d 100644 --- a/examples/client/client.go +++ b/examples/client/client.go @@ -22,8 +22,8 @@ func init() { client = &gemini.Client{} client.KnownHosts.Load() - client.TrustCertificate = func(cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { - err := knownHosts.Lookup(cert) + client.TrustCertificate = func(req *gemini.Request, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { + err := knownHosts.Lookup(req.Hostname(), cert) if err != nil { switch err { case gemini.ErrCertificateNotTrusted: diff --git a/gemini.go b/gemini.go index 99ae901..f2185e3 100644 --- a/gemini.go +++ b/gemini.go @@ -48,11 +48,11 @@ var DefaultClient *Client func init() { DefaultClient = &Client{ - TrustCertificate: func(cert *x509.Certificate, knownHosts *KnownHosts) error { + TrustCertificate: func(req *Request, cert *x509.Certificate, knownHosts *KnownHosts) error { // Load the hosts only once. This is so that the hosts don't have to be loaded // for those using their own clients. setupDefaultClientOnce.Do(setupDefaultClient) - return knownHosts.Lookup(cert) + return knownHosts.Lookup(req.Hostname(), cert) }, } } diff --git a/tofu.go b/tofu.go index bdda919..8143099 100644 --- a/tofu.go +++ b/tofu.go @@ -77,9 +77,8 @@ func (k *KnownHosts) Add(cert *x509.Certificate) { // Lookup returns ErrCertificateNotTrusted. // If the hostname is not in the list, Lookup returns ErrCertificateUnknown. // If the certificate is found and the fingerprint matches, error will be nil. -func (k *KnownHosts) Lookup(cert *x509.Certificate) error { +func (k *KnownHosts) Lookup(hostname string, cert *x509.Certificate) error { now := time.Now().Unix() - hostname := cert.Subject.CommonName fingerprint := Fingerprint(cert) for i := range k.hosts { if k.hosts[i].Hostname != hostname {