Add ReadRequest and ReadResponse functions

This commit is contained in:
Adnan Maolood 2020-12-18 01:41:14 -05:00
parent decd72cc23
commit d34d5df89e
5 changed files with 64 additions and 54 deletions

View File

@ -84,8 +84,8 @@ func (c *Client) Do(req *Request) (*Response, error) {
}
// Read the response
resp := &Response{}
if err := resp.Read(conn); err != nil {
resp, err := ReadResponse(conn)
if err != nil {
return nil, err
}
// Store connection state

View File

@ -58,6 +58,7 @@ var crlf = []byte("\r\n")
// Errors.
var (
ErrInvalidURL = errors.New("gemini: invalid URL")
ErrInvalidRequest = errors.New("gemini: invalid request")
ErrInvalidResponse = errors.New("gemini: invalid response")
ErrBodyNotAllowed = errors.New("gemini: response body not allowed")
)

View File

@ -4,6 +4,7 @@ import (
"bufio"
"context"
"crypto/tls"
"io"
"net"
"net/url"
)
@ -65,6 +66,37 @@ func NewRequestFromURL(url *url.URL) *Request {
}
}
// ReadRequest reads a Gemini request from the provided io.Reader
func ReadRequest(r io.Reader) (*Request, error) {
// Read URL
br := bufio.NewReader(r)
rawurl, err := br.ReadString('\r')
if err != nil {
return nil, err
}
// Read terminating line feed
if b, err := br.ReadByte(); err != nil {
return nil, err
} else if b != '\n' {
return nil, ErrInvalidRequest
}
// Trim carriage return
rawurl = rawurl[:len(rawurl)-1]
// Validate URL
if len(rawurl) > 1024 {
return nil, ErrInvalidRequest
}
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
if u.User != nil {
// User is not allowed
return nil, ErrInvalidURL
}
return &Request{URL: u}, nil
}
// Write writes the Gemini request to the provided buffered writer.
func (r *Request) Write(w *bufio.Writer) error {
url := r.URL.String()

View File

@ -26,17 +26,19 @@ type Response struct {
TLS tls.ConnectionState
}
// Read reads a Gemini response from the provided io.ReadCloser.
func (resp *Response) Read(rc io.ReadCloser) error {
// ReadResponse reads a Gemini response from the provided io.ReadCloser.
func ReadResponse(rc io.ReadCloser) (*Response, error) {
resp := &Response{}
br := bufio.NewReader(rc)
// Read the status
statusB := make([]byte, 2)
if _, err := br.Read(statusB); err != nil {
return err
return nil, err
}
status, err := strconv.Atoi(string(statusB))
if err != nil {
return err
return nil, err
}
resp.Status = Status(status)
@ -44,26 +46,26 @@ func (resp *Response) Read(rc io.ReadCloser) error {
const minStatus, maxStatus = 1, 6
statusClass := resp.Status.Class()
if statusClass < minStatus || statusClass > maxStatus {
return ErrInvalidResponse
return nil, ErrInvalidResponse
}
// Read one space
if b, err := br.ReadByte(); err != nil {
return err
return nil, err
} else if b != ' ' {
return ErrInvalidResponse
return nil, ErrInvalidResponse
}
// Read the meta
meta, err := br.ReadString('\r')
if err != nil {
return err
return nil, err
}
// Trim carriage return
meta = meta[:len(meta)-1]
// Ensure meta is less than or equal to 1024 bytes
if len(meta) > 1024 {
return ErrInvalidResponse
return nil, ErrInvalidResponse
}
// Default mime type of text/gemini; charset=utf-8
if statusClass == StatusClassSuccess && meta == "" {
@ -73,15 +75,15 @@ func (resp *Response) Read(rc io.ReadCloser) error {
// Read terminating newline
if b, err := br.ReadByte(); err != nil {
return err
return nil, err
} else if b != '\n' {
return ErrInvalidResponse
return nil, ErrInvalidResponse
}
if resp.Status.Class() == StatusClassSuccess {
resp.Body = newReadCloserBody(br, rc)
}
return nil
return resp, nil
}
type readCloserBody struct {

View File

@ -6,7 +6,6 @@ import (
"errors"
"log"
"net"
"net/url"
"strconv"
"strings"
"time"
@ -174,6 +173,7 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
// respond responds to a connection.
func (s *Server) respond(conn net.Conn) {
defer conn.Close()
if d := s.ReadTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d))
}
@ -181,58 +181,33 @@ func (s *Server) respond(conn net.Conn) {
conn.SetWriteDeadline(time.Now().Add(d))
}
r := bufio.NewReader(conn)
w := NewResponseWriter(conn)
// Read requested URL
rawurl, err := r.ReadString('\r')
defer w.b.Flush()
req, err := ReadRequest(conn)
if err != nil {
return
}
// Read terminating line feed
if b, err := r.ReadByte(); err != nil {
return
} else if b != '\n' {
w.WriteStatus(StatusBadRequest)
}
// Trim carriage return
rawurl = rawurl[:len(rawurl)-1]
// Ensure URL is valid
if len(rawurl) > 1024 {
w.WriteStatus(StatusBadRequest)
} else if url, err := url.Parse(rawurl); err != nil || url.User != nil {
// Note that we return an error status if User is specified in the URL
w.WriteStatus(StatusBadRequest)
} else {
// Store information about the TLS connection
var connState tls.ConnectionState
var cert *tls.Certificate
if tlsConn, ok := conn.(*tls.Conn); ok {
connState = tlsConn.ConnectionState()
if len(connState.PeerCertificates) > 0 {
peerCert := connState.PeerCertificates[0]
req.TLS = tlsConn.ConnectionState()
if len(req.TLS.PeerCertificates) > 0 {
peerCert := req.TLS.PeerCertificates[0]
// Store the TLS certificate
cert = &tls.Certificate{
req.Certificate = &tls.Certificate{
Certificate: [][]byte{peerCert.Raw},
Leaf: peerCert,
}
}
}
req := &Request{
URL: url,
RemoteAddr: conn.RemoteAddr(),
TLS: connState,
Certificate: cert,
}
resp := s.responder(req)
if resp != nil {
resp.Respond(w, req)
} else {
w.WriteStatus(StatusNotFound)
}
}
w.b.Flush()
conn.Close()
resp := s.responder(req)
if resp != nil {
resp.Respond(w, req)
} else {
w.WriteStatus(StatusNotFound)
}
}
func (s *Server) responder(r *Request) Responder {