Remove (*Request).Hostname function

This commit is contained in:
adnano 2020-10-13 13:31:50 -04:00
parent 4770ad9112
commit a0aa135534
2 changed files with 12 additions and 14 deletions

View File

@ -9,7 +9,6 @@ import (
"net" "net"
"net/url" "net/url"
"strconv" "strconv"
"strings"
"time" "time"
) )
@ -48,14 +47,13 @@ type Request struct {
TLS tls.ConnectionState TLS tls.ConnectionState
} }
// Hostname returns the request host without the port. // hostname returns the host without the port.
// It assumes that r.Host contains a valid host:port. func hostname(host string) string {
func (r *Request) Hostname() string { hostname, _, err := net.SplitHostPort(host)
colon := strings.LastIndexByte(r.Host, ':') if err != nil {
if colon != -1 { return host
return r.Host[:colon]
} }
return r.Host return hostname
} }
// NewRequest returns a new request. The host is inferred from the provided URL. // NewRequest returns a new request. The host is inferred from the provided URL.
@ -218,7 +216,7 @@ func (c *Client) Send(req *Request) (*Response, error) {
return req.Certificate, nil return req.Certificate, nil
} }
// If we have already stored the certificate, return it // If we have already stored the certificate, return it
if cert, err := c.CertificateStore.Lookup(req.Hostname()); err == nil { if cert, err := c.CertificateStore.Lookup(hostname(req.Host)); err == nil {
return cert, nil return cert, nil
} }
return &tls.Certificate{}, nil return &tls.Certificate{}, nil
@ -235,15 +233,15 @@ func (c *Client) Send(req *Request) (*Response, error) {
} }
// Check that the certificate is valid for the hostname // Check that the certificate is valid for the hostname
// Use our own implementation of verifyHostname // Use our own implementation of verifyHostname
if err := verifyHostname(cert, req.Hostname()); err != nil { if err := verifyHostname(cert, hostname(req.Host)); err != nil {
return err return err
} }
// Check that the client trusts the certificate // Check that the client trusts the certificate
if c.TrustCertificate == nil { if c.TrustCertificate == nil {
if err := c.KnownHosts.Lookup(req.Hostname(), cert); err != nil { if err := c.KnownHosts.Lookup(hostname(req.Host), cert); err != nil {
return err return err
} }
} else if err := c.TrustCertificate(req.Hostname(), cert, &c.KnownHosts); err != nil { } else if err := c.TrustCertificate(hostname(req.Host), cert, &c.KnownHosts); err != nil {
return err return err
} }
return nil return nil
@ -279,7 +277,7 @@ func (c *Client) Send(req *Request) (*Response, error) {
return resp, nil return resp, nil
} }
if c.GetCertificate != nil { if c.GetCertificate != nil {
if cert := c.GetCertificate(req.Hostname(), &c.CertificateStore); cert != nil { if cert := c.GetCertificate(hostname(req.Host), &c.CertificateStore); cert != nil {
req.Certificate = cert req.Certificate = cert
return c.Send(req) return c.Send(req)
} }

View File

@ -240,7 +240,7 @@ func (s *Server) respond(conn net.Conn) {
} }
func (s *Server) handler(req *Request) Handler { func (s *Server) handler(req *Request) Handler {
if h, ok := s.handlers[handlerKey{req.URL.Scheme, req.Hostname()}]; ok { if h, ok := s.handlers[handlerKey{req.URL.Scheme, req.URL.Hostname()}]; ok {
return h return h
} }
return NotFoundHandler() return NotFoundHandler()