diff --git a/client.go b/client.go index 1969f1e..5a94031 100644 --- a/client.go +++ b/client.go @@ -5,8 +5,8 @@ import ( "crypto/tls" "crypto/x509" "errors" - "fmt" "net" + "net/url" "time" ) @@ -28,6 +28,10 @@ type Client struct { // // A Timeout of zero means no timeout. Timeout time.Duration + + // DialContext specifies the dial function for creating TCP connections. + // If DialContext is nil, the client dials using package net. + DialContext func(ctx context.Context, network, addr string) (net.Conn, error) } // Get sends a Gemini request for the given URL. @@ -39,12 +43,12 @@ type Client struct { // which the user is expected to close. // // For more control over requests, use NewRequest and Client.Do. -func (c *Client) Get(url string) (*Response, error) { +func (c *Client) Get(ctx context.Context, url string) (*Response, error) { req, err := NewRequest(url) if err != nil { return nil, err } - return c.Do(req) + return c.Do(ctx, req) } // Do sends a Gemini request and returns a Gemini response, following @@ -57,48 +61,56 @@ func (c *Client) Get(url string) (*Response, error) { // which the user is expected to close. // // Generally Get will be used instead of Do. -func (c *Client) Do(req *Request) (*Response, error) { - // Punycode request URL host - hostname, port, err := net.SplitHostPort(req.URL.Host) - if err != nil { - // Likely no port - hostname = req.URL.Host - port = "1965" +func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { + if ctx == nil { + panic("nil context") } - punycode, err := punycodeHostname(hostname) + + // Punycode request URL host + host, port := splitHostPort(req.URL.Host) + punycode, err := punycodeHostname(host) if err != nil { return nil, err } - if hostname != punycode { - hostname = punycode + if host != punycode { + host = punycode // Make a copy of the request - _req := *req - req = &_req - _url := *req.URL - req.URL = &_url + r2 := new(Request) + *r2 = *req + r2.URL = new(url.URL) + *r2.URL = *req.URL + req = r2 // Set the host - req.URL.Host = net.JoinHostPort(hostname, port) + req.URL.Host = net.JoinHostPort(host, port) } // Use request host if provided if req.Host != "" { - hostname, port, err = net.SplitHostPort(req.Host) - if err != nil { - // Likely no port - hostname = req.Host - port = "1965" - } - // Punycode hostname - hostname, err = punycodeHostname(hostname) + host, port = splitHostPort(req.Host) + host, err = punycodeHostname(host) if err != nil { return nil, err } } + addr := net.JoinHostPort(host, port) + // Connect to the host - config := &tls.Config{ + start := time.Now() + conn, err := c.dialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + // Set the connection deadline + if c.Timeout != 0 { + conn.SetDeadline(start.Add(c.Timeout)) + } + + // Setup TLS + conn = tls.Client(conn, &tls.Config{ InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { @@ -108,83 +120,84 @@ func (c *Client) Do(req *Request) (*Response, error) { return &tls.Certificate{}, nil }, VerifyConnection: func(cs tls.ConnectionState) error { - return c.verifyConnection(hostname, punycode, cs) + return c.verifyConnection(cs, host) }, - ServerName: hostname, + ServerName: host, + }) + + res := make(chan result, 1) + go func() { + res <- c.do(conn, req) + }() + + select { + case <-ctx.Done(): + conn.Close() + return nil, ctx.Err() + case r := <-res: + return r.resp, r.err } - - ctx := req.Context - if ctx == nil { - ctx = context.Background() - } - - start := time.Now() - dialer := net.Dialer{ - Timeout: c.Timeout, - } - - address := net.JoinHostPort(hostname, port) - netConn, err := dialer.DialContext(ctx, "tcp", address) - if err != nil { - return nil, err - } - - conn := tls.Client(netConn, config) - - // Set connection deadline - if c.Timeout != 0 { - err := conn.SetDeadline(start.Add(c.Timeout)) - if err != nil { - return nil, fmt.Errorf("failed to set connection deadline: %w", err) - } - } - - resp, err := c.do(conn, req) - if err != nil { - // If we fail to perform the request/response we have - // to take responsibility for closing the connection. - _ = conn.Close() - - return nil, err - } - - // Store connection state - state := conn.ConnectionState() - resp.TLS = &state - - return resp, nil } -func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) { +type result struct { + resp *Response + err error +} + +func (c *Client) do(conn net.Conn, req *Request) result { // Write the request - err := req.Write(conn) - if err != nil { - return nil, fmt.Errorf("failed to write request: %w", err) + if err := req.Write(conn); err != nil { + return result{nil, err} } // Read the response resp, err := ReadResponse(conn) if err != nil { - return nil, err + return result{nil, err} } - return resp, nil + // Store TLS connection state + if tlsConn, ok := conn.(*tls.Conn); ok { + state := tlsConn.ConnectionState() + resp.TLS = &state + } + + return result{resp, nil} } -func (c *Client) verifyConnection(hostname, punycode string, cs tls.ConnectionState) error { +func (c *Client) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if c.DialContext != nil { + return c.DialContext(ctx, network, addr) + } + return (&net.Dialer{ + Timeout: c.Timeout, + }).DialContext(ctx, network, addr) +} + +func (c *Client) verifyConnection(cs tls.ConnectionState, hostname string) error { cert := cs.PeerCertificates[0] - // Verify punycoded hostname - if err := verifyHostname(cert, punycode); err != nil { + // Verify hostname + if err := verifyHostname(cert, hostname); err != nil { return err } // Check expiration date if !time.Now().Before(cert.NotAfter) { return errors.New("gemini: certificate expired") } - // See if the client trusts the certificate if c.TrustCertificate != nil { return c.TrustCertificate(hostname, cert) } return nil } + +func splitHostPort(hostport string) (host, port string) { + var err error + host, port, err = net.SplitHostPort(hostport) + if err != nil { + // Likely no port + host = hostport + port = "1965" + } + return +} diff --git a/request.go b/request.go index 2d10f50..b50176f 100644 --- a/request.go +++ b/request.go @@ -2,7 +2,6 @@ package gemini import ( "bufio" - "context" "crypto/tls" "io" "net/url" @@ -48,14 +47,6 @@ type Request struct { // otherwise it leaves the field nil. // This field is ignored by the Gemini client. TLS *tls.ConnectionState - - // Context specifies the context to use for outgoing requests. - // The context controls the entire lifetime of a request and its - // response: obtaining a connection, sending the request, and - // reading the response header and body. - // If Context is nil, the background context will be used. - // This field is ignored by the Gemini server. - Context context.Context } // NewRequest returns a new request. diff --git a/timeout.go b/timeout.go index cff5307..14b7860 100644 --- a/timeout.go +++ b/timeout.go @@ -28,13 +28,8 @@ type timeoutHandler struct { } func (t *timeoutHandler) ServeGemini(w ResponseWriter, r *Request) { - ctx := r.Context - if ctx == nil { - ctx = context.Background() - } - var cancelCtx func() - ctx, cancelCtx = context.WithTimeout(ctx, t.dt) - defer cancelCtx() + ctx, cancel := context.WithTimeout(context.TODO(), t.dt) + defer cancel() done := make(chan struct{}) tw := &timeoutWriter{}