From d34d5df89e9a0d1cf8b3f8e95ce10d08ff91600b Mon Sep 17 00:00:00 2001 From: Adnan Maolood Date: Fri, 18 Dec 2020 01:41:14 -0500 Subject: [PATCH] Add ReadRequest and ReadResponse functions --- client.go | 4 ++-- gemini.go | 1 + request.go | 32 +++++++++++++++++++++++++++++++ response.go | 26 +++++++++++++------------ server.go | 55 +++++++++++++++-------------------------------------- 5 files changed, 64 insertions(+), 54 deletions(-) diff --git a/client.go b/client.go index 6464255..a0f571a 100644 --- a/client.go +++ b/client.go @@ -84,8 +84,8 @@ func (c *Client) Do(req *Request) (*Response, error) { } // Read the response - resp := &Response{} - if err := resp.Read(conn); err != nil { + resp, err := ReadResponse(conn) + if err != nil { return nil, err } // Store connection state diff --git a/gemini.go b/gemini.go index 3623fd0..c756845 100644 --- a/gemini.go +++ b/gemini.go @@ -58,6 +58,7 @@ var crlf = []byte("\r\n") // Errors. var ( ErrInvalidURL = errors.New("gemini: invalid URL") + ErrInvalidRequest = errors.New("gemini: invalid request") ErrInvalidResponse = errors.New("gemini: invalid response") ErrBodyNotAllowed = errors.New("gemini: response body not allowed") ) diff --git a/request.go b/request.go index a18dcab..8313e81 100644 --- a/request.go +++ b/request.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "crypto/tls" + "io" "net" "net/url" ) @@ -65,6 +66,37 @@ func NewRequestFromURL(url *url.URL) *Request { } } +// ReadRequest reads a Gemini request from the provided io.Reader +func ReadRequest(r io.Reader) (*Request, error) { + // Read URL + br := bufio.NewReader(r) + rawurl, err := br.ReadString('\r') + if err != nil { + return nil, err + } + // Read terminating line feed + if b, err := br.ReadByte(); err != nil { + return nil, err + } else if b != '\n' { + return nil, ErrInvalidRequest + } + // Trim carriage return + rawurl = rawurl[:len(rawurl)-1] + // Validate URL + if len(rawurl) > 1024 { + return nil, ErrInvalidRequest + } + u, err := url.Parse(rawurl) + if err != nil { + return nil, err + } + if u.User != nil { + // User is not allowed + return nil, ErrInvalidURL + } + return &Request{URL: u}, nil +} + // Write writes the Gemini request to the provided buffered writer. func (r *Request) Write(w *bufio.Writer) error { url := r.URL.String() diff --git a/response.go b/response.go index 88baeeb..568507b 100644 --- a/response.go +++ b/response.go @@ -26,17 +26,19 @@ type Response struct { TLS tls.ConnectionState } -// Read reads a Gemini response from the provided io.ReadCloser. -func (resp *Response) Read(rc io.ReadCloser) error { +// ReadResponse reads a Gemini response from the provided io.ReadCloser. +func ReadResponse(rc io.ReadCloser) (*Response, error) { + resp := &Response{} br := bufio.NewReader(rc) + // Read the status statusB := make([]byte, 2) if _, err := br.Read(statusB); err != nil { - return err + return nil, err } status, err := strconv.Atoi(string(statusB)) if err != nil { - return err + return nil, err } resp.Status = Status(status) @@ -44,26 +46,26 @@ func (resp *Response) Read(rc io.ReadCloser) error { const minStatus, maxStatus = 1, 6 statusClass := resp.Status.Class() if statusClass < minStatus || statusClass > maxStatus { - return ErrInvalidResponse + return nil, ErrInvalidResponse } // Read one space if b, err := br.ReadByte(); err != nil { - return err + return nil, err } else if b != ' ' { - return ErrInvalidResponse + return nil, ErrInvalidResponse } // Read the meta meta, err := br.ReadString('\r') if err != nil { - return err + return nil, err } // Trim carriage return meta = meta[:len(meta)-1] // Ensure meta is less than or equal to 1024 bytes if len(meta) > 1024 { - return ErrInvalidResponse + return nil, ErrInvalidResponse } // Default mime type of text/gemini; charset=utf-8 if statusClass == StatusClassSuccess && meta == "" { @@ -73,15 +75,15 @@ func (resp *Response) Read(rc io.ReadCloser) error { // Read terminating newline if b, err := br.ReadByte(); err != nil { - return err + return nil, err } else if b != '\n' { - return ErrInvalidResponse + return nil, ErrInvalidResponse } if resp.Status.Class() == StatusClassSuccess { resp.Body = newReadCloserBody(br, rc) } - return nil + return resp, nil } type readCloserBody struct { diff --git a/server.go b/server.go index f589c19..3e069ae 100644 --- a/server.go +++ b/server.go @@ -6,7 +6,6 @@ import ( "errors" "log" "net" - "net/url" "strconv" "strings" "time" @@ -174,6 +173,7 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) { // respond responds to a connection. func (s *Server) respond(conn net.Conn) { + defer conn.Close() if d := s.ReadTimeout; d != 0 { conn.SetReadDeadline(time.Now().Add(d)) } @@ -181,58 +181,33 @@ func (s *Server) respond(conn net.Conn) { conn.SetWriteDeadline(time.Now().Add(d)) } - r := bufio.NewReader(conn) w := NewResponseWriter(conn) - // Read requested URL - rawurl, err := r.ReadString('\r') + defer w.b.Flush() + + req, err := ReadRequest(conn) if err != nil { - return - } - // Read terminating line feed - if b, err := r.ReadByte(); err != nil { - return - } else if b != '\n' { - w.WriteStatus(StatusBadRequest) - } - // Trim carriage return - rawurl = rawurl[:len(rawurl)-1] - // Ensure URL is valid - if len(rawurl) > 1024 { - w.WriteStatus(StatusBadRequest) - } else if url, err := url.Parse(rawurl); err != nil || url.User != nil { - // Note that we return an error status if User is specified in the URL w.WriteStatus(StatusBadRequest) } else { // Store information about the TLS connection - var connState tls.ConnectionState - var cert *tls.Certificate if tlsConn, ok := conn.(*tls.Conn); ok { - connState = tlsConn.ConnectionState() - if len(connState.PeerCertificates) > 0 { - peerCert := connState.PeerCertificates[0] + req.TLS = tlsConn.ConnectionState() + if len(req.TLS.PeerCertificates) > 0 { + peerCert := req.TLS.PeerCertificates[0] // Store the TLS certificate - cert = &tls.Certificate{ + req.Certificate = &tls.Certificate{ Certificate: [][]byte{peerCert.Raw}, Leaf: peerCert, } } } - - req := &Request{ - URL: url, - RemoteAddr: conn.RemoteAddr(), - TLS: connState, - Certificate: cert, - } - resp := s.responder(req) - if resp != nil { - resp.Respond(w, req) - } else { - w.WriteStatus(StatusNotFound) - } } - w.b.Flush() - conn.Close() + + resp := s.responder(req) + if resp != nil { + resp.Respond(w, req) + } else { + w.WriteStatus(StatusNotFound) + } } func (s *Server) responder(r *Request) Responder {