diff --git a/client.go b/client.go index 8c94f23..4372464 100644 --- a/client.go +++ b/client.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "io" "net" "net/url" "time" @@ -124,7 +125,22 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { res := make(chan result, 1) go func() { - resp, err := c.do(conn, req) + ctx, cancel := context.WithCancel(ctx) + done := ctx.Done() + cw := &contextWriter{ + ctx: ctx, + done: done, + cancel: cancel, + wc: conn, + } + cr := &contextReader{ + ctx: ctx, + done: done, + cancel: cancel, + rc: conn, + } + + resp, err := c.do(cw, cr, req) res <- result{resp, err} }() @@ -137,14 +153,14 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { } } -func (c *Client) do(conn net.Conn, req *Request) (*Response, error) { +func (c *Client) do(w io.Writer, rc io.ReadCloser, req *Request) (*Response, error) { // Write the request - if err := req.Write(conn); err != nil { + if err := req.Write(w); err != nil { return nil, err } // Read the response - resp, err := ReadResponse(conn) + resp, err := ReadResponse(rc) if err != nil { return nil, err } @@ -206,3 +222,55 @@ func punycodeHostname(hostname string) (string, error) { } return idna.Lookup.ToASCII(hostname) } + +type contextReader struct { + ctx context.Context + done <-chan struct{} + cancel func() + rc io.ReadCloser +} + +func (r *contextReader) Read(p []byte) (int, error) { + select { + case <-r.done: + r.rc.Close() + return 0, r.ctx.Err() + default: + } + n, err := r.rc.Read(p) + if err != nil { + r.cancel() + } + return n, err +} + +func (r *contextReader) Close() error { + r.cancel() + return r.rc.Close() +} + +type contextWriter struct { + ctx context.Context + done <-chan struct{} + cancel func() + wc io.WriteCloser +} + +func (w *contextWriter) Write(b []byte) (int, error) { + select { + case <-w.done: + w.wc.Close() + return 0, w.ctx.Err() + default: + } + n, err := w.wc.Write(b) + if err != nil { + w.cancel() + } + return n, err +} + +func (w *contextWriter) Close() error { + w.cancel() + return w.wc.Close() +}