diff --git a/gemini.go b/gemini.go index 25160f9..f4b211e 100644 --- a/gemini.go +++ b/gemini.go @@ -11,8 +11,6 @@ func init() { mime.AddExtensionType(".gemini", "text/gemini") } -var crlf = []byte("\r\n") - // Errors. var ( ErrInvalidRequest = errors.New("gemini: invalid request") @@ -22,3 +20,15 @@ var ( // when the response status code does not permit a body. ErrBodyNotAllowed = errors.New("gemini: response status code does not allow body") ) + +var crlf = []byte("\r\n") + +func trimCRLF(b []byte) ([]byte, bool) { + // Check for CR + if len(b) < 2 || b[len(b)-2] != '\r' { + return nil, false + } + // Trim CRLF + b = b[:len(b)-2] + return b, true +} diff --git a/request.go b/request.go index e87d2ca..0613716 100644 --- a/request.go +++ b/request.go @@ -51,26 +51,25 @@ func NewRequest(rawurl string) (*Request, error) { // for specialized applications; most code should use the Server // to read requests and handle them via the Handler interface. func ReadRequest(r io.Reader) (*Request, error) { - // Read URL + // Limit request size r = io.LimitReader(r, 1026) br := bufio.NewReaderSize(r, 1026) - rawurl, err := br.ReadString('\r') + b, err := br.ReadBytes('\n') if err != nil { + if err == io.EOF { + return nil, ErrInvalidRequest + } return nil, err } - // Read terminating line feed - if b, err := br.ReadByte(); err != nil { - return nil, err - } else if b != '\n' { + // Read URL + rawurl, ok := trimCRLF(b) + if !ok { return nil, ErrInvalidRequest } - // Trim carriage return - rawurl = rawurl[:len(rawurl)-1] - // Validate URL - if len(rawurl) > 1024 { + if len(rawurl) == 0 { return nil, ErrInvalidRequest } - u, err := url.Parse(rawurl) + u, err := url.Parse(string(rawurl)) if err != nil { return nil, err } diff --git a/request_test.go b/request_test.go index 90ca57a..3471930 100644 --- a/request_test.go +++ b/request_test.go @@ -2,7 +2,6 @@ package gemini import ( "bufio" - "io" "net/url" "strings" "testing" @@ -36,25 +35,25 @@ func TestReadRequest(t *testing.T) { }, { Raw: "\r\n", - URL: &url.URL{}, + Err: ErrInvalidRequest, }, { Raw: "gemini://example.com\n", - Err: io.EOF, + Err: ErrInvalidRequest, }, { Raw: "gemini://example.com", - Err: io.EOF, + Err: ErrInvalidRequest, }, { // 1030 bytes Raw: maxURL + "xxxxxx", - Err: io.EOF, + Err: ErrInvalidRequest, }, { // 1027 bytes Raw: maxURL + "x" + "\r\n", - Err: io.EOF, + Err: ErrInvalidRequest, }, { // 1024 bytes diff --git a/response.go b/response.go index 03d65ef..48d2e53 100644 --- a/response.go +++ b/response.go @@ -46,43 +46,39 @@ func ReadResponse(r io.ReadCloser) (*Response, error) { resp := &Response{} br := bufio.NewReader(r) - // Read the status - statusB := make([]byte, 2) - if _, err := br.Read(statusB); err != nil { + // Read response header + b, err := br.ReadBytes('\n') + if err != nil { + if err == io.EOF { + return nil, ErrInvalidResponse + } return nil, err } - status, err := strconv.Atoi(string(statusB)) + if len(b) < 3 { + return nil, ErrInvalidResponse + } + + // Read the status + status, err := strconv.Atoi(string(b[:2])) if err != nil { return nil, ErrInvalidResponse } resp.Status = Status(status) // Read one space - if b, err := br.ReadByte(); err != nil { - return nil, err - } else if b != ' ' { + if b[2] != ' ' { return nil, ErrInvalidResponse } // Read the meta - meta, err := br.ReadString('\r') - if err != nil { - return nil, err + meta, ok := trimCRLF(b[3:]) + if !ok { + return nil, ErrInvalidResponse } - // Trim carriage return - meta = meta[:len(meta)-1] - // Ensure meta is less than or equal to 1024 bytes if len(meta) == 0 || len(meta) > 1024 { return nil, ErrInvalidResponse } - resp.Meta = meta - - // Read terminating newline - if b, err := br.ReadByte(); err != nil { - return nil, err - } else if b != '\n' { - return nil, ErrInvalidResponse - } + resp.Meta = string(meta) if resp.Status.Class() == StatusSuccess { resp.Body = newBufReadCloser(br, r) diff --git a/response_test.go b/response_test.go index ccd279b..d5e6a6b 100644 --- a/response_test.go +++ b/response_test.go @@ -65,15 +65,15 @@ func TestReadWriteResponse(t *testing.T) { }, { Raw: "", - Err: io.EOF, + Err: ErrInvalidResponse, }, { Raw: "10 Search query", - Err: io.EOF, + Err: ErrInvalidResponse, }, { Raw: "20 text/gemini\nHello, world!", - Err: io.EOF, + Err: ErrInvalidResponse, }, { Raw: "20 text/gemini\rHello, world!", @@ -81,7 +81,7 @@ func TestReadWriteResponse(t *testing.T) { }, { Raw: "20 text/gemini\r", - Err: io.EOF, + Err: ErrInvalidResponse, }, { Raw: "abcdefghijklmnopqrstuvwxyz",