client: Make Get and Do accept a Context

This removes the need for Request.Context.
This commit is contained in:
Adnan Maolood 2021-02-20 13:37:08 -05:00
parent a7f958b20d
commit 3f4fd10b6d
3 changed files with 95 additions and 96 deletions

173
client.go
View File

@ -5,8 +5,8 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"net" "net"
"net/url"
"time" "time"
) )
@ -28,6 +28,10 @@ type Client struct {
// //
// A Timeout of zero means no timeout. // A Timeout of zero means no timeout.
Timeout time.Duration Timeout time.Duration
// DialContext specifies the dial function for creating TCP connections.
// If DialContext is nil, the client dials using package net.
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
} }
// Get sends a Gemini request for the given URL. // Get sends a Gemini request for the given URL.
@ -39,12 +43,12 @@ type Client struct {
// which the user is expected to close. // which the user is expected to close.
// //
// For more control over requests, use NewRequest and Client.Do. // For more control over requests, use NewRequest and Client.Do.
func (c *Client) Get(url string) (*Response, error) { func (c *Client) Get(ctx context.Context, url string) (*Response, error) {
req, err := NewRequest(url) req, err := NewRequest(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.Do(req) return c.Do(ctx, req)
} }
// Do sends a Gemini request and returns a Gemini response, following // Do sends a Gemini request and returns a Gemini response, following
@ -57,48 +61,56 @@ func (c *Client) Get(url string) (*Response, error) {
// which the user is expected to close. // which the user is expected to close.
// //
// Generally Get will be used instead of Do. // Generally Get will be used instead of Do.
func (c *Client) Do(req *Request) (*Response, error) { func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
// Punycode request URL host if ctx == nil {
hostname, port, err := net.SplitHostPort(req.URL.Host) panic("nil context")
if err != nil {
// Likely no port
hostname = req.URL.Host
port = "1965"
} }
punycode, err := punycodeHostname(hostname)
// Punycode request URL host
host, port := splitHostPort(req.URL.Host)
punycode, err := punycodeHostname(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if hostname != punycode { if host != punycode {
hostname = punycode host = punycode
// Make a copy of the request // Make a copy of the request
_req := *req r2 := new(Request)
req = &_req *r2 = *req
_url := *req.URL r2.URL = new(url.URL)
req.URL = &_url *r2.URL = *req.URL
req = r2
// Set the host // Set the host
req.URL.Host = net.JoinHostPort(hostname, port) req.URL.Host = net.JoinHostPort(host, port)
} }
// Use request host if provided // Use request host if provided
if req.Host != "" { if req.Host != "" {
hostname, port, err = net.SplitHostPort(req.Host) host, port = splitHostPort(req.Host)
if err != nil { host, err = punycodeHostname(host)
// Likely no port
hostname = req.Host
port = "1965"
}
// Punycode hostname
hostname, err = punycodeHostname(hostname)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
addr := net.JoinHostPort(host, port)
// Connect to the host // Connect to the host
config := &tls.Config{ start := time.Now()
conn, err := c.dialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
// Set the connection deadline
if c.Timeout != 0 {
conn.SetDeadline(start.Add(c.Timeout))
}
// Setup TLS
conn = tls.Client(conn, &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
@ -108,83 +120,84 @@ func (c *Client) Do(req *Request) (*Response, error) {
return &tls.Certificate{}, nil return &tls.Certificate{}, nil
}, },
VerifyConnection: func(cs tls.ConnectionState) error { VerifyConnection: func(cs tls.ConnectionState) error {
return c.verifyConnection(hostname, punycode, cs) return c.verifyConnection(cs, host)
}, },
ServerName: hostname, ServerName: host,
})
res := make(chan result, 1)
go func() {
res <- c.do(conn, req)
}()
select {
case <-ctx.Done():
conn.Close()
return nil, ctx.Err()
case r := <-res:
return r.resp, r.err
} }
ctx := req.Context
if ctx == nil {
ctx = context.Background()
}
start := time.Now()
dialer := net.Dialer{
Timeout: c.Timeout,
}
address := net.JoinHostPort(hostname, port)
netConn, err := dialer.DialContext(ctx, "tcp", address)
if err != nil {
return nil, err
}
conn := tls.Client(netConn, config)
// Set connection deadline
if c.Timeout != 0 {
err := conn.SetDeadline(start.Add(c.Timeout))
if err != nil {
return nil, fmt.Errorf("failed to set connection deadline: %w", err)
}
}
resp, err := c.do(conn, req)
if err != nil {
// If we fail to perform the request/response we have
// to take responsibility for closing the connection.
_ = conn.Close()
return nil, err
}
// Store connection state
state := conn.ConnectionState()
resp.TLS = &state
return resp, nil
} }
func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) { type result struct {
resp *Response
err error
}
func (c *Client) do(conn net.Conn, req *Request) result {
// Write the request // Write the request
err := req.Write(conn) if err := req.Write(conn); err != nil {
if err != nil { return result{nil, err}
return nil, fmt.Errorf("failed to write request: %w", err)
} }
// Read the response // Read the response
resp, err := ReadResponse(conn) resp, err := ReadResponse(conn)
if err != nil { if err != nil {
return nil, err return result{nil, err}
} }
return resp, nil // Store TLS connection state
if tlsConn, ok := conn.(*tls.Conn); ok {
state := tlsConn.ConnectionState()
resp.TLS = &state
}
return result{resp, nil}
} }
func (c *Client) verifyConnection(hostname, punycode string, cs tls.ConnectionState) error { func (c *Client) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if c.DialContext != nil {
return c.DialContext(ctx, network, addr)
}
return (&net.Dialer{
Timeout: c.Timeout,
}).DialContext(ctx, network, addr)
}
func (c *Client) verifyConnection(cs tls.ConnectionState, hostname string) error {
cert := cs.PeerCertificates[0] cert := cs.PeerCertificates[0]
// Verify punycoded hostname // Verify hostname
if err := verifyHostname(cert, punycode); err != nil { if err := verifyHostname(cert, hostname); err != nil {
return err return err
} }
// Check expiration date // Check expiration date
if !time.Now().Before(cert.NotAfter) { if !time.Now().Before(cert.NotAfter) {
return errors.New("gemini: certificate expired") return errors.New("gemini: certificate expired")
} }
// See if the client trusts the certificate // See if the client trusts the certificate
if c.TrustCertificate != nil { if c.TrustCertificate != nil {
return c.TrustCertificate(hostname, cert) return c.TrustCertificate(hostname, cert)
} }
return nil return nil
} }
func splitHostPort(hostport string) (host, port string) {
var err error
host, port, err = net.SplitHostPort(hostport)
if err != nil {
// Likely no port
host = hostport
port = "1965"
}
return
}

View File

@ -2,7 +2,6 @@ package gemini
import ( import (
"bufio" "bufio"
"context"
"crypto/tls" "crypto/tls"
"io" "io"
"net/url" "net/url"
@ -48,14 +47,6 @@ type Request struct {
// otherwise it leaves the field nil. // otherwise it leaves the field nil.
// This field is ignored by the Gemini client. // This field is ignored by the Gemini client.
TLS *tls.ConnectionState TLS *tls.ConnectionState
// Context specifies the context to use for outgoing requests.
// The context controls the entire lifetime of a request and its
// response: obtaining a connection, sending the request, and
// reading the response header and body.
// If Context is nil, the background context will be used.
// This field is ignored by the Gemini server.
Context context.Context
} }
// NewRequest returns a new request. // NewRequest returns a new request.

View File

@ -28,13 +28,8 @@ type timeoutHandler struct {
} }
func (t *timeoutHandler) ServeGemini(w ResponseWriter, r *Request) { func (t *timeoutHandler) ServeGemini(w ResponseWriter, r *Request) {
ctx := r.Context ctx, cancel := context.WithTimeout(context.TODO(), t.dt)
if ctx == nil { defer cancel()
ctx = context.Background()
}
var cancelCtx func()
ctx, cancelCtx = context.WithTimeout(ctx, t.dt)
defer cancelCtx()
done := make(chan struct{}) done := make(chan struct{})
tw := &timeoutWriter{} tw := &timeoutWriter{}