From 42c95f8c8d9cb69e5783a440d9b156ca1337a832 Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Sat, 31 Oct 2020 20:55:56 -0400 Subject: [PATCH] Implement Client connection timeout --- client.go | 23 ++++++++++++++++++++++- examples/client.go | 1 + 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 5d58895..8af997d 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "net/url" "path" "strings" + "time" ) // Client is a Gemini client. @@ -18,6 +19,14 @@ type Client struct { // Certificates stores client-side certificates. Certificates CertificateStore + // Timeout specifies a time limit for requests made by this + // Client. The timeout includes connection time and reading + // the response body. The timer remains running after + // Get and Do return and will interrupt reading of the Response.Body. + // + // A Timeout of zero means no timeout. + Timeout time.Duration + // GetInput is called to retrieve input when the server requests it. // If GetInput is nil or returns false, no input will be sent and // the response will be returned. @@ -56,6 +65,13 @@ func (c *Client) Do(req *Request) (*Response, error) { return c.do(req, nil) } +func (c *Client) deadline() time.Time { + if c.Timeout > 0 { + return time.Now().Add(c.Timeout) + } + return time.Time{} +} + func (c *Client) do(req *Request, via []*Request) (*Response, error) { // Connect to the host config := &tls.Config{ @@ -72,7 +88,12 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { if err != nil { return nil, err } - // TODO: Set connection deadline + // Set connection deadline + if deadline := c.deadline(); !deadline.IsZero() { + if err := conn.SetDeadline(deadline); err != nil { + return nil, err + } + } // Write the request w := bufio.NewWriter(conn) diff --git a/examples/client.go b/examples/client.go index 5ef9db6..1fea082 100644 --- a/examples/client.go +++ b/examples/client.go @@ -20,6 +20,7 @@ var ( ) func init() { + client.Timeout = 2 * time.Minute client.KnownHosts.LoadDefault() client.TrustCertificate = func(hostname string, cert *x509.Certificate, knownHosts *gemini.KnownHosts) error { err := knownHosts.Lookup(hostname, cert)