client: Make Get and Do accept a Context
This removes the need for Request.Context.
This commit is contained in:
parent
a7f958b20d
commit
3f4fd10b6d
173
client.go
173
client.go
@ -5,8 +5,8 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,6 +28,10 @@ type Client struct {
|
|||||||
//
|
//
|
||||||
// A Timeout of zero means no timeout.
|
// A Timeout of zero means no timeout.
|
||||||
Timeout time.Duration
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// DialContext specifies the dial function for creating TCP connections.
|
||||||
|
// If DialContext is nil, the client dials using package net.
|
||||||
|
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get sends a Gemini request for the given URL.
|
// Get sends a Gemini request for the given URL.
|
||||||
@ -39,12 +43,12 @@ type Client struct {
|
|||||||
// which the user is expected to close.
|
// which the user is expected to close.
|
||||||
//
|
//
|
||||||
// For more control over requests, use NewRequest and Client.Do.
|
// For more control over requests, use NewRequest and Client.Do.
|
||||||
func (c *Client) Get(url string) (*Response, error) {
|
func (c *Client) Get(ctx context.Context, url string) (*Response, error) {
|
||||||
req, err := NewRequest(url)
|
req, err := NewRequest(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return c.Do(req)
|
return c.Do(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do sends a Gemini request and returns a Gemini response, following
|
// Do sends a Gemini request and returns a Gemini response, following
|
||||||
@ -57,48 +61,56 @@ func (c *Client) Get(url string) (*Response, error) {
|
|||||||
// which the user is expected to close.
|
// which the user is expected to close.
|
||||||
//
|
//
|
||||||
// Generally Get will be used instead of Do.
|
// Generally Get will be used instead of Do.
|
||||||
func (c *Client) Do(req *Request) (*Response, error) {
|
func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
|
||||||
// Punycode request URL host
|
if ctx == nil {
|
||||||
hostname, port, err := net.SplitHostPort(req.URL.Host)
|
panic("nil context")
|
||||||
if err != nil {
|
|
||||||
// Likely no port
|
|
||||||
hostname = req.URL.Host
|
|
||||||
port = "1965"
|
|
||||||
}
|
}
|
||||||
punycode, err := punycodeHostname(hostname)
|
|
||||||
|
// Punycode request URL host
|
||||||
|
host, port := splitHostPort(req.URL.Host)
|
||||||
|
punycode, err := punycodeHostname(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if hostname != punycode {
|
if host != punycode {
|
||||||
hostname = punycode
|
host = punycode
|
||||||
|
|
||||||
// Make a copy of the request
|
// Make a copy of the request
|
||||||
_req := *req
|
r2 := new(Request)
|
||||||
req = &_req
|
*r2 = *req
|
||||||
_url := *req.URL
|
r2.URL = new(url.URL)
|
||||||
req.URL = &_url
|
*r2.URL = *req.URL
|
||||||
|
req = r2
|
||||||
|
|
||||||
// Set the host
|
// Set the host
|
||||||
req.URL.Host = net.JoinHostPort(hostname, port)
|
req.URL.Host = net.JoinHostPort(host, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use request host if provided
|
// Use request host if provided
|
||||||
if req.Host != "" {
|
if req.Host != "" {
|
||||||
hostname, port, err = net.SplitHostPort(req.Host)
|
host, port = splitHostPort(req.Host)
|
||||||
if err != nil {
|
host, err = punycodeHostname(host)
|
||||||
// Likely no port
|
|
||||||
hostname = req.Host
|
|
||||||
port = "1965"
|
|
||||||
}
|
|
||||||
// Punycode hostname
|
|
||||||
hostname, err = punycodeHostname(hostname)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addr := net.JoinHostPort(host, port)
|
||||||
|
|
||||||
// Connect to the host
|
// Connect to the host
|
||||||
config := &tls.Config{
|
start := time.Now()
|
||||||
|
conn, err := c.dialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the connection deadline
|
||||||
|
if c.Timeout != 0 {
|
||||||
|
conn.SetDeadline(start.Add(c.Timeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup TLS
|
||||||
|
conn = tls.Client(conn, &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
GetClientCertificate: func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
|
||||||
@ -108,83 +120,84 @@ func (c *Client) Do(req *Request) (*Response, error) {
|
|||||||
return &tls.Certificate{}, nil
|
return &tls.Certificate{}, nil
|
||||||
},
|
},
|
||||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||||
return c.verifyConnection(hostname, punycode, cs)
|
return c.verifyConnection(cs, host)
|
||||||
},
|
},
|
||||||
ServerName: hostname,
|
ServerName: host,
|
||||||
|
})
|
||||||
|
|
||||||
|
res := make(chan result, 1)
|
||||||
|
go func() {
|
||||||
|
res <- c.do(conn, req)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
conn.Close()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case r := <-res:
|
||||||
|
return r.resp, r.err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := req.Context
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
dialer := net.Dialer{
|
|
||||||
Timeout: c.Timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
address := net.JoinHostPort(hostname, port)
|
|
||||||
netConn, err := dialer.DialContext(ctx, "tcp", address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn := tls.Client(netConn, config)
|
|
||||||
|
|
||||||
// Set connection deadline
|
|
||||||
if c.Timeout != 0 {
|
|
||||||
err := conn.SetDeadline(start.Add(c.Timeout))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set connection deadline: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := c.do(conn, req)
|
|
||||||
if err != nil {
|
|
||||||
// If we fail to perform the request/response we have
|
|
||||||
// to take responsibility for closing the connection.
|
|
||||||
_ = conn.Close()
|
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store connection state
|
|
||||||
state := conn.ConnectionState()
|
|
||||||
resp.TLS = &state
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) {
|
type result struct {
|
||||||
|
resp *Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) do(conn net.Conn, req *Request) result {
|
||||||
// Write the request
|
// Write the request
|
||||||
err := req.Write(conn)
|
if err := req.Write(conn); err != nil {
|
||||||
if err != nil {
|
return result{nil, err}
|
||||||
return nil, fmt.Errorf("failed to write request: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the response
|
// Read the response
|
||||||
resp, err := ReadResponse(conn)
|
resp, err := ReadResponse(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return result{nil, err}
|
||||||
}
|
}
|
||||||
|
|
||||||
return resp, nil
|
// Store TLS connection state
|
||||||
|
if tlsConn, ok := conn.(*tls.Conn); ok {
|
||||||
|
state := tlsConn.ConnectionState()
|
||||||
|
resp.TLS = &state
|
||||||
|
}
|
||||||
|
|
||||||
|
return result{resp, nil}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) verifyConnection(hostname, punycode string, cs tls.ConnectionState) error {
|
func (c *Client) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
if c.DialContext != nil {
|
||||||
|
return c.DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
return (&net.Dialer{
|
||||||
|
Timeout: c.Timeout,
|
||||||
|
}).DialContext(ctx, network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) verifyConnection(cs tls.ConnectionState, hostname string) error {
|
||||||
cert := cs.PeerCertificates[0]
|
cert := cs.PeerCertificates[0]
|
||||||
// Verify punycoded hostname
|
// Verify hostname
|
||||||
if err := verifyHostname(cert, punycode); err != nil {
|
if err := verifyHostname(cert, hostname); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Check expiration date
|
// Check expiration date
|
||||||
if !time.Now().Before(cert.NotAfter) {
|
if !time.Now().Before(cert.NotAfter) {
|
||||||
return errors.New("gemini: certificate expired")
|
return errors.New("gemini: certificate expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
// See if the client trusts the certificate
|
// See if the client trusts the certificate
|
||||||
if c.TrustCertificate != nil {
|
if c.TrustCertificate != nil {
|
||||||
return c.TrustCertificate(hostname, cert)
|
return c.TrustCertificate(hostname, cert)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitHostPort(hostport string) (host, port string) {
|
||||||
|
var err error
|
||||||
|
host, port, err = net.SplitHostPort(hostport)
|
||||||
|
if err != nil {
|
||||||
|
// Likely no port
|
||||||
|
host = hostport
|
||||||
|
port = "1965"
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -2,7 +2,6 @@ package gemini
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -48,14 +47,6 @@ type Request struct {
|
|||||||
// otherwise it leaves the field nil.
|
// otherwise it leaves the field nil.
|
||||||
// This field is ignored by the Gemini client.
|
// This field is ignored by the Gemini client.
|
||||||
TLS *tls.ConnectionState
|
TLS *tls.ConnectionState
|
||||||
|
|
||||||
// Context specifies the context to use for outgoing requests.
|
|
||||||
// The context controls the entire lifetime of a request and its
|
|
||||||
// response: obtaining a connection, sending the request, and
|
|
||||||
// reading the response header and body.
|
|
||||||
// If Context is nil, the background context will be used.
|
|
||||||
// This field is ignored by the Gemini server.
|
|
||||||
Context context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequest returns a new request.
|
// NewRequest returns a new request.
|
||||||
|
@ -28,13 +28,8 @@ type timeoutHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *timeoutHandler) ServeGemini(w ResponseWriter, r *Request) {
|
func (t *timeoutHandler) ServeGemini(w ResponseWriter, r *Request) {
|
||||||
ctx := r.Context
|
ctx, cancel := context.WithTimeout(context.TODO(), t.dt)
|
||||||
if ctx == nil {
|
defer cancel()
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
var cancelCtx func()
|
|
||||||
ctx, cancelCtx = context.WithTimeout(ctx, t.dt)
|
|
||||||
defer cancelCtx()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
tw := &timeoutWriter{}
|
tw := &timeoutWriter{}
|
||||||
|
Loading…
Reference in New Issue
Block a user