diff --git a/client.go b/client.go index 5afebdd..f64f371 100644 --- a/client.go +++ b/client.go @@ -68,24 +68,49 @@ func (c *Client) Do(req *Request) (*Response, error) { if ctx == nil { ctx = context.Background() } - netConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", req.Host) + + start := time.Now() + dialer := net.Dialer{ + Timeout: c.Timeout, + } + + netConn, err := dialer.DialContext(ctx, "tcp", req.Host) if err != nil { return nil, err } + conn := tls.Client(netConn, config) + // Set connection deadline if c.Timeout != 0 { - err := conn.SetDeadline(time.Now().Add(c.Timeout)) + err := conn.SetDeadline(start.Add(c.Timeout)) if err != nil { return nil, fmt.Errorf( "failed to set connection deadline: %w", err) } } + resp, err := c.do(conn, req) + if err != nil { + // If we fail to perform the request/response we have + // to take responsibility for closing the connection. + _ = conn.Close() + + return nil, err + } + + // Store connection state + resp.TLS = conn.ConnectionState() + + return resp, nil +} + +func (c *Client) do(conn *tls.Conn, req *Request) (*Response, error) { + // Write the request w := bufio.NewWriter(conn) - err = req.Write(w) + err := req.Write(w) if err != nil { return nil, fmt.Errorf( "failed to write request data: %w", err) @@ -100,8 +125,6 @@ func (c *Client) Do(req *Request) (*Response, error) { if err != nil { return nil, err } - // Store connection state - resp.TLS = conn.ConnectionState() return resp, nil }