diff --git a/examples/server/server.go b/examples/server/server.go index 75e401f..50c311d 100644 --- a/examples/server/server.go +++ b/examples/server/server.go @@ -26,13 +26,10 @@ func main() { config.ClientAuth = tls.RequestClientCert mux := &gemini.Mux{} - mux.HandleFunc("/", func(req *gemini.Request) *gemini.Response { + mux.HandleFunc("/", func(rw gemini.ResponseWriter, req *gemini.Request) { log.Printf("Request from %s for %s with certificates %v", req.RemoteAddr.String(), req.URL.String(), req.TLS.PeerCertificates) - return &gemini.Response{ - Status: gemini.StatusSuccess, - Meta: "text/gemini", - Body: []byte("You requested " + req.URL.String()), - } + rw.WriteHeader(gemini.StatusSuccess, "text/gemini") + rw.Write([]byte("You requested " + req.URL.String())) }) server := gemini.Server{ diff --git a/gemini.go b/gemini.go index b2bd14b..36e0ff1 100644 --- a/gemini.go +++ b/gemini.go @@ -3,7 +3,6 @@ package gemini import ( "bufio" - "bytes" "crypto/tls" "errors" "io" @@ -50,8 +49,13 @@ const ( // Errors. var ( - ErrProtocol = errors.New("Protocol error") - ErrInvalidURL = errors.New("Invalid URL") + ErrProtocol = errors.New("gemini: protocol error") + ErrInvalidURL = errors.New("gemini: requested URL is invalid") + ErrBodyNotAllowed = errors.New("gemini: response status code does not allow for body") +) + +var ( + crlf = []byte("\r\n") ) // Request is a Gemini request. @@ -70,7 +74,8 @@ type Request struct { URL *url.URL // For client requests, Host specifies the host on which the URL is sought. - // If this field is not set, the host will be inferred from the URL. + // If this field is empty, the host will be inferred from the URL. + // This field is ignored by the server. Host string @@ -126,16 +131,20 @@ func NewProxyRequest(host, rawurl string) (*Request, error) { }, nil } -// Write writes the Gemini request to the provided io.Writer. -func (r *Request) Write(w io.Writer) error { +// write writes the Gemini request to the provided buffered writer. +func (r *Request) write(w *bufio.Writer) error { url := r.URL.String() // UserInfo is invalid if r.URL.User != nil || len(url) > 1024 { return ErrInvalidURL } - request := url + "\r\n" - _, err := w.Write([]byte(request)) - return err + if _, err := w.WriteString(url); err != nil { + return err + } + if _, err := w.Write(crlf); err != nil { + return err + } + return nil } // Response is a Gemini response. @@ -159,21 +168,21 @@ type Response struct { TLS tls.ConnectionState } -// Write writes the Gemini response header and body to the provided io.Writer. -func (r *Response) Write(w io.Writer) error { - header := strconv.Itoa(r.Status) + " " + r.Meta + "\r\n" - if _, err := w.Write([]byte(header)); err != nil { - return err - } +// ResponseWriter is used by a Gemini handler to construct a Gemini response. +type ResponseWriter interface { + // WriteHeader writes the response header. + // + // Meta contains more information related to the response status. + // For successful responses, Meta should contain the mimetype of the response. + // For failure responses, Meta should contain a short description of the failure. + // Meta should not be longer than 1024 bytes. + WriteHeader(status int, meta string) - // Only write the response body on success - if r.Status/10 == StatusClassSuccess { - if _, err := w.Write(r.Body); err != nil { - return err - } - } - - return nil + // Write writes the response body. + // If the response status does not allow for a response body, Write returns + // ErrBodyNotAllowed. + // WriteHeader must be called before Write. + Write([]byte) (int, error) } // Get makes a request for the provided URL. The host is inferred from the URL. @@ -204,47 +213,61 @@ func Do(req *Request) (*Response, error) { defer conn.Close() // Write the request - if err := req.Write(conn); err != nil { + // TODO: Is buffered I/O necessary here? + w := bufio.NewWriter(conn) + req.write(w) + if err := w.Flush(); err != nil { return nil, err } - // Read the response - b, err := ioutil.ReadAll(conn) - if err != nil { + // Read the response status + r := bufio.NewReader(conn) + statusB := make([]byte, 2) + if _, err := r.Read(statusB); err != nil { return nil, err } - - // Ensure that the response is long enough - // The minimum response: (5 bytes) - if len(b) < 5 { - return nil, ErrProtocol - } - - // Parse the response header - status, err := strconv.Atoi(string(b[:2])) + status, err := strconv.Atoi(string(statusB)) if err != nil { return nil, err } // Read one space - if b[2] != ' ' { - return nil, ErrProtocol - } - - // Find the first - i := bytes.Index(b, []byte("\r\n")) - if i < 3 { + if b, err := r.ReadByte(); err != nil { + return nil, err + } else if b != ' ' { return nil, ErrProtocol } // Read the meta - meta := string(b[3:i]) + meta, err := r.ReadString('\r') + if err != nil { + return nil, err + } + + // Read terminating newline + if b, err := r.ReadByte(); err != nil { + return nil, err + } else if b != '\n' { + return nil, ErrProtocol + } + + // Trim carriage return + meta = meta[:len(meta)-1] + + // Ensure meta is less than 1024 bytes if len(meta) > 1024 { return nil, ErrProtocol } - // Read the response body - body := b[i+2:] + // Read response body + var body []byte + if status/10 == StatusClassSuccess { + var err error + body, err = ioutil.ReadAll(r) + if err != nil { + return nil, err + } + } return &Response{ Status: status, @@ -316,26 +339,47 @@ func (s *Server) Serve(l net.Listener) error { } } +// response represents the server side of a Gemini response. +type response struct { + w *bufio.Writer + bodyAllowed bool +} + +func newResponse(conn net.Conn) *response { + return &response{ + w: bufio.NewWriter(conn), + } +} + +func (r *response) WriteHeader(status int, meta string) { + r.w.WriteString(strconv.Itoa(status)) + r.w.WriteByte(' ') + r.w.WriteString(meta) + r.w.Write(crlf) + + // Only allow body to be written on successful status codes. + if status/10 == StatusClassSuccess { + r.bodyAllowed = true + } +} + +func (r *response) Write(b []byte) (int, error) { + if !r.bodyAllowed { + return 0, ErrBodyNotAllowed + } + return r.w.Write(b) +} + // respond responds to a connection. func (s *Server) respond(rw net.Conn) { - var resp *Response - + resp := newResponse(rw) if rawurl, err := readLine(rw); err != nil { - resp = &Response{ - Status: StatusBadRequest, - Meta: "Bad request", - } + resp.WriteHeader(StatusBadRequest, "Bad request") } else if len(rawurl) > 1024 { - resp = &Response{ - Status: StatusBadRequest, - Meta: "Requested URL exceeds 1024 bytes", - } + resp.WriteHeader(StatusBadRequest, "Requested URL exceeds 1024 bytes") } else if url, err := url.Parse(rawurl); err != nil || url.User != nil { // Note that we return an error if User is specified in the URL. - resp = &Response{ - Status: StatusBadRequest, - Meta: "Requested URL is invalid", - } + resp.WriteHeader(StatusBadRequest, "Requested URL is invalid") } else { // Gather information about the request req := &Request{ @@ -343,17 +387,16 @@ func (s *Server) respond(rw net.Conn) { RemoteAddr: rw.RemoteAddr(), TLS: rw.(*tls.Conn).ConnectionState(), } - resp = s.Handler.Serve(req) + s.Handler.Serve(resp, req) } - - resp.Write(rw) + resp.w.Flush() rw.Close() } // A Handler responds to a Gemini request. type Handler interface { - // Serve accepts a Request and returns a Response. - Serve(*Request) *Response + // Serve accepts a Request and constructs a Response. + Serve(ResponseWriter, *Request) } // Mux is a Gemini request multiplexer. @@ -397,28 +440,26 @@ func (m *Mux) Handle(pattern string, handler Handler) { } // HandleFunc registers a HandlerFunc for the given pattern. -func (m *Mux) HandleFunc(pattern string, handlerFunc func(req *Request) *Response) { +func (m *Mux) HandleFunc(pattern string, handlerFunc func(ResponseWriter, *Request)) { handler := HandlerFunc(handlerFunc) m.Handle(pattern, handler) } // Serve responds to the request with the appropriate handler. -func (m *Mux) Serve(req *Request) *Response { +func (m *Mux) Serve(rw ResponseWriter, req *Request) { h := m.match(req.URL) if h == nil { - return &Response{ - Status: StatusNotFound, - Meta: "Not found", - } + rw.WriteHeader(StatusNotFound, "Not found") + return } - return h.Serve(req) + h.Serve(rw, req) } // A wrapper around a bare function that implements Handler. -type HandlerFunc func(req *Request) *Response +type HandlerFunc func(ResponseWriter, *Request) -func (f HandlerFunc) Serve(req *Request) *Response { - return f(req) +func (f HandlerFunc) Serve(rw ResponseWriter, req *Request) { + f(rw, req) } // readLine reads a line.