diff --git a/client.go b/client.go index 913a655..7585dfc 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,7 @@ package gemini import ( - "bufio" + "bytes" "crypto/tls" "errors" "io/ioutil" @@ -11,7 +11,8 @@ import ( ) var ( - ErrProtocol = errors.New("Protocol error") + ErrProtocol = errors.New("Protocol error") + ErrInvalidURL = errors.New("Invalid URL") ) // Client is a Gemini client. @@ -19,6 +20,10 @@ type Client struct{} // Request makes a request for the provided URL. The host is inferred from the URL. func (c *Client) Request(url string) (*Response, error) { + if len(url) > 1024 { + return nil, ErrInvalidURL + } + req, err := NewRequest(url) if err != nil { return nil, err @@ -28,6 +33,10 @@ func (c *Client) Request(url string) (*Response, error) { // ProxyRequest requests the provided URL from the provided host. func (c *Client) ProxyRequest(host, url string) (*Response, error) { + if len(url) > 1024 { + return nil, ErrInvalidURL + } + req, err := NewProxyRequest(host, url) if err != nil { return nil, err @@ -92,8 +101,8 @@ func (c *Client) Do(req *Request) (*Response, error) { host += ":1965" } - config := tls.Config{} // Allow self signed certificates + config := tls.Config{} config.InsecureSkipVerify = true config.Certificates = req.Certificates @@ -109,36 +118,43 @@ func (c *Client) Do(req *Request) (*Response, error) { return nil, err } - buf := bufio.NewReader(conn) - - // Read the response header - code := make([]byte, 2) - if _, err := buf.Read(code); err != nil { + // Read the response + b, err := ioutil.ReadAll(conn) + if err != nil { return nil, err } - status, err := strconv.Atoi(string(code)) + + // Ensure that the response is long enough + // The minimum response: (5 bytes) + if len(b) < 5 { + return nil, ErrProtocol + } + + // Parse the response header + status, err := strconv.Atoi(string(b[:2])) if err != nil { return nil, err } // Read one space - if space, err := buf.ReadByte(); err != nil { - return nil, err - } else if space != ' ' { + if b[2] != ' ' { + return nil, ErrProtocol + } + + // Find the first + i := bytes.Index(b, []byte("\r\n")) + if i < 3 { return nil, ErrProtocol } // Read the meta - meta, err := readLine(buf) - if err != nil { - return nil, err + meta := string(b[3:i]) + if len(meta) > 1024 { + return nil, ErrProtocol } // Read the response body - body, err := ioutil.ReadAll(buf) - if err != nil { - return nil, err - } + body := b[i+2:] return &Response{ Status: status, diff --git a/server.go b/server.go index 4cd7f1e..ad93738 100644 --- a/server.go +++ b/server.go @@ -92,22 +92,33 @@ func (s *Server) Serve(ln net.Listener) error { return err } - req, err := readLine(rw) - if err != nil { - continue - } - url, err := url.Parse(req) - if err != nil { - continue + var resp *Response + + if rawurl, err := readLine(rw); err != nil { + resp = &Response{ + Status: StatusBadRequest, + Meta: "Bad request", + } + } else if len(rawurl) > 1024 { + resp = &Response{ + Status: StatusBadRequest, + Meta: "URL exceeds 1024 bytes", + } + } else if url, err := url.Parse(rawurl); err != nil { + resp = &Response{ + Status: StatusBadRequest, + Meta: "Invalid URL", + } + } else { + // Gather information about the request + reqInfo := &RequestInfo{ + URL: url, + Certificates: rw.(*tls.Conn).ConnectionState().PeerCertificates, + RemoteAddr: rw.RemoteAddr(), + } + resp = s.Handler.Serve(reqInfo) } - // Gather information about the request - reqInfo := &RequestInfo{ - URL: url, - Certificates: rw.(*tls.Conn).ConnectionState().PeerCertificates, - RemoteAddr: rw.RemoteAddr(), - } - resp := s.Handler.Serve(reqInfo) resp.Write(rw) rw.Close() }