Add ReadRequest and ReadResponse functions

This commit is contained in:
Adnan Maolood 2020-12-18 01:41:14 -05:00
parent decd72cc23
commit d34d5df89e
5 changed files with 64 additions and 54 deletions

View File

@ -84,8 +84,8 @@ func (c *Client) Do(req *Request) (*Response, error) {
} }
// Read the response // Read the response
resp := &Response{} resp, err := ReadResponse(conn)
if err := resp.Read(conn); err != nil { if err != nil {
return nil, err return nil, err
} }
// Store connection state // Store connection state

View File

@ -58,6 +58,7 @@ var crlf = []byte("\r\n")
// Errors. // Errors.
var ( var (
ErrInvalidURL = errors.New("gemini: invalid URL") ErrInvalidURL = errors.New("gemini: invalid URL")
ErrInvalidRequest = errors.New("gemini: invalid request")
ErrInvalidResponse = errors.New("gemini: invalid response") ErrInvalidResponse = errors.New("gemini: invalid response")
ErrBodyNotAllowed = errors.New("gemini: response body not allowed") ErrBodyNotAllowed = errors.New("gemini: response body not allowed")
) )

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"crypto/tls" "crypto/tls"
"io"
"net" "net"
"net/url" "net/url"
) )
@ -65,6 +66,37 @@ func NewRequestFromURL(url *url.URL) *Request {
} }
} }
// ReadRequest reads a Gemini request from the provided io.Reader
func ReadRequest(r io.Reader) (*Request, error) {
// Read URL
br := bufio.NewReader(r)
rawurl, err := br.ReadString('\r')
if err != nil {
return nil, err
}
// Read terminating line feed
if b, err := br.ReadByte(); err != nil {
return nil, err
} else if b != '\n' {
return nil, ErrInvalidRequest
}
// Trim carriage return
rawurl = rawurl[:len(rawurl)-1]
// Validate URL
if len(rawurl) > 1024 {
return nil, ErrInvalidRequest
}
u, err := url.Parse(rawurl)
if err != nil {
return nil, err
}
if u.User != nil {
// User is not allowed
return nil, ErrInvalidURL
}
return &Request{URL: u}, nil
}
// Write writes the Gemini request to the provided buffered writer. // Write writes the Gemini request to the provided buffered writer.
func (r *Request) Write(w *bufio.Writer) error { func (r *Request) Write(w *bufio.Writer) error {
url := r.URL.String() url := r.URL.String()

View File

@ -26,17 +26,19 @@ type Response struct {
TLS tls.ConnectionState TLS tls.ConnectionState
} }
// Read reads a Gemini response from the provided io.ReadCloser. // ReadResponse reads a Gemini response from the provided io.ReadCloser.
func (resp *Response) Read(rc io.ReadCloser) error { func ReadResponse(rc io.ReadCloser) (*Response, error) {
resp := &Response{}
br := bufio.NewReader(rc) br := bufio.NewReader(rc)
// Read the status // Read the status
statusB := make([]byte, 2) statusB := make([]byte, 2)
if _, err := br.Read(statusB); err != nil { if _, err := br.Read(statusB); err != nil {
return err return nil, err
} }
status, err := strconv.Atoi(string(statusB)) status, err := strconv.Atoi(string(statusB))
if err != nil { if err != nil {
return err return nil, err
} }
resp.Status = Status(status) resp.Status = Status(status)
@ -44,26 +46,26 @@ func (resp *Response) Read(rc io.ReadCloser) error {
const minStatus, maxStatus = 1, 6 const minStatus, maxStatus = 1, 6
statusClass := resp.Status.Class() statusClass := resp.Status.Class()
if statusClass < minStatus || statusClass > maxStatus { if statusClass < minStatus || statusClass > maxStatus {
return ErrInvalidResponse return nil, ErrInvalidResponse
} }
// Read one space // Read one space
if b, err := br.ReadByte(); err != nil { if b, err := br.ReadByte(); err != nil {
return err return nil, err
} else if b != ' ' { } else if b != ' ' {
return ErrInvalidResponse return nil, ErrInvalidResponse
} }
// Read the meta // Read the meta
meta, err := br.ReadString('\r') meta, err := br.ReadString('\r')
if err != nil { if err != nil {
return err return nil, err
} }
// Trim carriage return // Trim carriage return
meta = meta[:len(meta)-1] meta = meta[:len(meta)-1]
// Ensure meta is less than or equal to 1024 bytes // Ensure meta is less than or equal to 1024 bytes
if len(meta) > 1024 { if len(meta) > 1024 {
return ErrInvalidResponse return nil, ErrInvalidResponse
} }
// Default mime type of text/gemini; charset=utf-8 // Default mime type of text/gemini; charset=utf-8
if statusClass == StatusClassSuccess && meta == "" { if statusClass == StatusClassSuccess && meta == "" {
@ -73,15 +75,15 @@ func (resp *Response) Read(rc io.ReadCloser) error {
// Read terminating newline // Read terminating newline
if b, err := br.ReadByte(); err != nil { if b, err := br.ReadByte(); err != nil {
return err return nil, err
} else if b != '\n' { } else if b != '\n' {
return ErrInvalidResponse return nil, ErrInvalidResponse
} }
if resp.Status.Class() == StatusClassSuccess { if resp.Status.Class() == StatusClassSuccess {
resp.Body = newReadCloserBody(br, rc) resp.Body = newReadCloserBody(br, rc)
} }
return nil return resp, nil
} }
type readCloserBody struct { type readCloserBody struct {

View File

@ -6,7 +6,6 @@ import (
"errors" "errors"
"log" "log"
"net" "net"
"net/url"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -174,6 +173,7 @@ func (s *Server) getCertificateFor(hostname string) (*tls.Certificate, error) {
// respond responds to a connection. // respond responds to a connection.
func (s *Server) respond(conn net.Conn) { func (s *Server) respond(conn net.Conn) {
defer conn.Close()
if d := s.ReadTimeout; d != 0 { if d := s.ReadTimeout; d != 0 {
conn.SetReadDeadline(time.Now().Add(d)) conn.SetReadDeadline(time.Now().Add(d))
} }
@ -181,58 +181,33 @@ func (s *Server) respond(conn net.Conn) {
conn.SetWriteDeadline(time.Now().Add(d)) conn.SetWriteDeadline(time.Now().Add(d))
} }
r := bufio.NewReader(conn)
w := NewResponseWriter(conn) w := NewResponseWriter(conn)
// Read requested URL defer w.b.Flush()
rawurl, err := r.ReadString('\r')
req, err := ReadRequest(conn)
if err != nil { if err != nil {
return
}
// Read terminating line feed
if b, err := r.ReadByte(); err != nil {
return
} else if b != '\n' {
w.WriteStatus(StatusBadRequest)
}
// Trim carriage return
rawurl = rawurl[:len(rawurl)-1]
// Ensure URL is valid
if len(rawurl) > 1024 {
w.WriteStatus(StatusBadRequest)
} else if url, err := url.Parse(rawurl); err != nil || url.User != nil {
// Note that we return an error status if User is specified in the URL
w.WriteStatus(StatusBadRequest) w.WriteStatus(StatusBadRequest)
} else { } else {
// Store information about the TLS connection // Store information about the TLS connection
var connState tls.ConnectionState
var cert *tls.Certificate
if tlsConn, ok := conn.(*tls.Conn); ok { if tlsConn, ok := conn.(*tls.Conn); ok {
connState = tlsConn.ConnectionState() req.TLS = tlsConn.ConnectionState()
if len(connState.PeerCertificates) > 0 { if len(req.TLS.PeerCertificates) > 0 {
peerCert := connState.PeerCertificates[0] peerCert := req.TLS.PeerCertificates[0]
// Store the TLS certificate // Store the TLS certificate
cert = &tls.Certificate{ req.Certificate = &tls.Certificate{
Certificate: [][]byte{peerCert.Raw}, Certificate: [][]byte{peerCert.Raw},
Leaf: peerCert, Leaf: peerCert,
} }
} }
} }
req := &Request{
URL: url,
RemoteAddr: conn.RemoteAddr(),
TLS: connState,
Certificate: cert,
}
resp := s.responder(req)
if resp != nil {
resp.Respond(w, req)
} else {
w.WriteStatus(StatusNotFound)
}
} }
w.b.Flush()
conn.Close() resp := s.responder(req)
if resp != nil {
resp.Respond(w, req)
} else {
w.WriteStatus(StatusNotFound)
}
} }
func (s *Server) responder(r *Request) Responder { func (s *Server) responder(r *Request) Responder {