Refactor client requests and server responses

This commit is contained in:
adnano 2020-09-21 21:31:09 -04:00
parent 5a0a515874
commit 01294f6661
2 changed files with 60 additions and 33 deletions

View File

@ -1,7 +1,7 @@
package gemini package gemini
import ( import (
"bufio" "bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"io/ioutil" "io/ioutil"
@ -11,7 +11,8 @@ import (
) )
var ( var (
ErrProtocol = errors.New("Protocol error") ErrProtocol = errors.New("Protocol error")
ErrInvalidURL = errors.New("Invalid URL")
) )
// Client is a Gemini client. // Client is a Gemini client.
@ -19,6 +20,10 @@ type Client struct{}
// Request makes a request for the provided URL. The host is inferred from the URL. // Request makes a request for the provided URL. The host is inferred from the URL.
func (c *Client) Request(url string) (*Response, error) { func (c *Client) Request(url string) (*Response, error) {
if len(url) > 1024 {
return nil, ErrInvalidURL
}
req, err := NewRequest(url) req, err := NewRequest(url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -28,6 +33,10 @@ func (c *Client) Request(url string) (*Response, error) {
// ProxyRequest requests the provided URL from the provided host. // ProxyRequest requests the provided URL from the provided host.
func (c *Client) ProxyRequest(host, url string) (*Response, error) { func (c *Client) ProxyRequest(host, url string) (*Response, error) {
if len(url) > 1024 {
return nil, ErrInvalidURL
}
req, err := NewProxyRequest(host, url) req, err := NewProxyRequest(host, url)
if err != nil { if err != nil {
return nil, err return nil, err
@ -92,8 +101,8 @@ func (c *Client) Do(req *Request) (*Response, error) {
host += ":1965" host += ":1965"
} }
config := tls.Config{}
// Allow self signed certificates // Allow self signed certificates
config := tls.Config{}
config.InsecureSkipVerify = true config.InsecureSkipVerify = true
config.Certificates = req.Certificates config.Certificates = req.Certificates
@ -109,36 +118,43 @@ func (c *Client) Do(req *Request) (*Response, error) {
return nil, err return nil, err
} }
buf := bufio.NewReader(conn) // Read the response
b, err := ioutil.ReadAll(conn)
// Read the response header if err != nil {
code := make([]byte, 2)
if _, err := buf.Read(code); err != nil {
return nil, err return nil, err
} }
status, err := strconv.Atoi(string(code))
// Ensure that the response is long enough
// The minimum response: <STATUS><SPACE><CR><LF> (5 bytes)
if len(b) < 5 {
return nil, ErrProtocol
}
// Parse the response header
status, err := strconv.Atoi(string(b[:2]))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Read one space // Read one space
if space, err := buf.ReadByte(); err != nil { if b[2] != ' ' {
return nil, err return nil, ErrProtocol
} else if space != ' ' { }
// Find the first <CR><LF>
i := bytes.Index(b, []byte("\r\n"))
if i < 3 {
return nil, ErrProtocol return nil, ErrProtocol
} }
// Read the meta // Read the meta
meta, err := readLine(buf) meta := string(b[3:i])
if err != nil { if len(meta) > 1024 {
return nil, err return nil, ErrProtocol
} }
// Read the response body // Read the response body
body, err := ioutil.ReadAll(buf) body := b[i+2:]
if err != nil {
return nil, err
}
return &Response{ return &Response{
Status: status, Status: status,

View File

@ -92,22 +92,33 @@ func (s *Server) Serve(ln net.Listener) error {
return err return err
} }
req, err := readLine(rw) var resp *Response
if err != nil {
continue if rawurl, err := readLine(rw); err != nil {
} resp = &Response{
url, err := url.Parse(req) Status: StatusBadRequest,
if err != nil { Meta: "Bad request",
continue }
} else if len(rawurl) > 1024 {
resp = &Response{
Status: StatusBadRequest,
Meta: "URL exceeds 1024 bytes",
}
} else if url, err := url.Parse(rawurl); err != nil {
resp = &Response{
Status: StatusBadRequest,
Meta: "Invalid URL",
}
} else {
// Gather information about the request
reqInfo := &RequestInfo{
URL: url,
Certificates: rw.(*tls.Conn).ConnectionState().PeerCertificates,
RemoteAddr: rw.RemoteAddr(),
}
resp = s.Handler.Serve(reqInfo)
} }
// Gather information about the request
reqInfo := &RequestInfo{
URL: url,
Certificates: rw.(*tls.Conn).ConnectionState().PeerCertificates,
RemoteAddr: rw.RemoteAddr(),
}
resp := s.Handler.Serve(reqInfo)
resp.Write(rw) resp.Write(rw)
rw.Close() rw.Close()
} }