Refactor client requests and server responses

This commit is contained in:
adnano 2020-09-21 21:31:09 -04:00
parent 5a0a515874
commit 01294f6661
2 changed files with 60 additions and 33 deletions

View File

@ -1,7 +1,7 @@
package gemini
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"io/ioutil"
@ -11,7 +11,8 @@ import (
)
var (
ErrProtocol = errors.New("Protocol error")
ErrProtocol = errors.New("Protocol error")
ErrInvalidURL = errors.New("Invalid URL")
)
// Client is a Gemini client.
@ -19,6 +20,10 @@ type Client struct{}
// Request makes a request for the provided URL. The host is inferred from the URL.
func (c *Client) Request(url string) (*Response, error) {
if len(url) > 1024 {
return nil, ErrInvalidURL
}
req, err := NewRequest(url)
if err != nil {
return nil, err
@ -28,6 +33,10 @@ func (c *Client) Request(url string) (*Response, error) {
// ProxyRequest requests the provided URL from the provided host.
func (c *Client) ProxyRequest(host, url string) (*Response, error) {
if len(url) > 1024 {
return nil, ErrInvalidURL
}
req, err := NewProxyRequest(host, url)
if err != nil {
return nil, err
@ -92,8 +101,8 @@ func (c *Client) Do(req *Request) (*Response, error) {
host += ":1965"
}
config := tls.Config{}
// Allow self signed certificates
config := tls.Config{}
config.InsecureSkipVerify = true
config.Certificates = req.Certificates
@ -109,36 +118,43 @@ func (c *Client) Do(req *Request) (*Response, error) {
return nil, err
}
buf := bufio.NewReader(conn)
// Read the response header
code := make([]byte, 2)
if _, err := buf.Read(code); err != nil {
// Read the response
b, err := ioutil.ReadAll(conn)
if err != nil {
return nil, err
}
status, err := strconv.Atoi(string(code))
// Ensure that the response is long enough
// The minimum response: <STATUS><SPACE><CR><LF> (5 bytes)
if len(b) < 5 {
return nil, ErrProtocol
}
// Parse the response header
status, err := strconv.Atoi(string(b[:2]))
if err != nil {
return nil, err
}
// Read one space
if space, err := buf.ReadByte(); err != nil {
return nil, err
} else if space != ' ' {
if b[2] != ' ' {
return nil, ErrProtocol
}
// Find the first <CR><LF>
i := bytes.Index(b, []byte("\r\n"))
if i < 3 {
return nil, ErrProtocol
}
// Read the meta
meta, err := readLine(buf)
if err != nil {
return nil, err
meta := string(b[3:i])
if len(meta) > 1024 {
return nil, ErrProtocol
}
// Read the response body
body, err := ioutil.ReadAll(buf)
if err != nil {
return nil, err
}
body := b[i+2:]
return &Response{
Status: status,

View File

@ -92,22 +92,33 @@ func (s *Server) Serve(ln net.Listener) error {
return err
}
req, err := readLine(rw)
if err != nil {
continue
}
url, err := url.Parse(req)
if err != nil {
continue
var resp *Response
if rawurl, err := readLine(rw); err != nil {
resp = &Response{
Status: StatusBadRequest,
Meta: "Bad request",
}
} else if len(rawurl) > 1024 {
resp = &Response{
Status: StatusBadRequest,
Meta: "URL exceeds 1024 bytes",
}
} else if url, err := url.Parse(rawurl); err != nil {
resp = &Response{
Status: StatusBadRequest,
Meta: "Invalid URL",
}
} else {
// Gather information about the request
reqInfo := &RequestInfo{
URL: url,
Certificates: rw.(*tls.Conn).ConnectionState().PeerCertificates,
RemoteAddr: rw.RemoteAddr(),
}
resp = s.Handler.Serve(reqInfo)
}
// Gather information about the request
reqInfo := &RequestInfo{
URL: url,
Certificates: rw.(*tls.Conn).ConnectionState().PeerCertificates,
RemoteAddr: rw.RemoteAddr(),
}
resp := s.Handler.Serve(reqInfo)
resp.Write(rw)
rw.Close()
}