From 7d470c5fb1c8b1bfc2c734d35acdda755e30965d Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Sat, 31 Oct 2020 21:07:02 -0400 Subject: [PATCH] Implement Server read and write timeouts --- client.go | 13 ++----------- examples/server.go | 2 ++ server.go | 14 ++++++++++++++ 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 8af997d..4e6b926 100644 --- a/client.go +++ b/client.go @@ -65,13 +65,6 @@ func (c *Client) Do(req *Request) (*Response, error) { return c.do(req, nil) } -func (c *Client) deadline() time.Time { - if c.Timeout > 0 { - return time.Now().Add(c.Timeout) - } - return time.Time{} -} - func (c *Client) do(req *Request, via []*Request) (*Response, error) { // Connect to the host config := &tls.Config{ @@ -89,10 +82,8 @@ func (c *Client) do(req *Request, via []*Request) (*Response, error) { return nil, err } // Set connection deadline - if deadline := c.deadline(); !deadline.IsZero() { - if err := conn.SetDeadline(deadline); err != nil { - return nil, err - } + if d := c.Timeout; d != 0 { + conn.SetDeadline(time.Now().Add(d)) } // Write the request diff --git a/examples/server.go b/examples/server.go index c26e66b..75aac41 100644 --- a/examples/server.go +++ b/examples/server.go @@ -12,6 +12,8 @@ import ( func main() { var server gemini.Server + server.ReadTimeout = 1 * time.Minute + server.WriteTimeout = 2 * time.Minute if err := server.Certificates.Load("/var/lib/gemini/certs"); err != nil { log.Fatal(err) } diff --git a/server.go b/server.go index 57bfaa0..4fbd73a 100644 --- a/server.go +++ b/server.go @@ -21,6 +21,13 @@ type Server struct { // Certificates contains the certificates used by the server. Certificates CertificateStore + // ReadTimeout is the maximum duration for reading a request. + ReadTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. + WriteTimeout time.Duration + // CreateCertificate, if not nil, will be called to create a new certificate // if the current one is expired or missing. CreateCertificate func(hostname string) (tls.Certificate, error) @@ -159,6 +166,13 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { // respond responds to a connection. func (s *Server) respond(conn net.Conn) { + if d := s.ReadTimeout; d != 0 { + conn.SetReadDeadline(time.Now().Add(d)) + } + if d := s.WriteTimeout; d != 0 { + conn.SetWriteDeadline(time.Now().Add(d)) + } + r := bufio.NewReader(conn) w := newResponseWriter(conn) // Read requested URL