hnakra/router/service.go

393 lines
9.9 KiB
Go
Raw Normal View History

2023-05-25 16:08:56 -06:00
package router
import "io"
import "log"
import "net"
import "sync"
import "bufio"
import "net/url"
import "net/http"
import "hnakra/protocol"
// activeRequest tracks one HTTP request that has been forwarded to a service
// and whose response is still being streamed back to the client.
type activeRequest struct {
	service *Service

	// body carries response body chunks from the connection goroutine to
	// the goroutine serving the request. It is unbuffered, and closing it
	// tells the serving goroutine the response is over.
	body chan []byte
	// requestBody is the client's request body. It is handed off to the
	// service at most once, when the service asks for it.
	requestBody io.ReadCloser

	// header and status are the response header and status code reported
	// by the service. The status starts out as 200 OK.
	header http.Header
	status int

	url *url.URL
	res http.ResponseWriter

	// wroteHeader records whether header and status have already been
	// written to res. headerMutex guards it.
	wroteHeader bool
	headerMutex sync.Mutex
}
// ensureHeader copies the stored response header into res and writes the
// stored status code, unless that has already been done for this request.
// It takes headerMutex, so both the connection goroutine and the goroutine
// serving the request can call it.
func (request *activeRequest) ensureHeader() {
	request.headerMutex.Lock()
	defer request.headerMutex.Unlock()

	if !request.wroteHeader {
		outgoing := request.res.Header()
		for name, list := range request.header {
			for _, item := range list {
				outgoing.Add(name, item)
			}
		}
		request.res.WriteHeader(request.status)
		request.wroteHeader = true
	}
}
// Service is a service connected to the router over conn. Once logged in, it
// is mounted on the router's HTTP mux. HTTP requests matching its pattern are
// forwarded to it over the connection, and its responses are relayed back.
type Service struct {
	router   *Router
	validate func(user string, key []byte, pattern Pattern) protocol.Status

	// details supplied by the service when it logs in
	pattern     Pattern
	user        string
	name        string
	description string

	// requests currently being served, keyed by an ID from idFactory.
	// requestsLock guards the map.
	idFactory    *protocol.IDFactory
	requests     map[protocol.ID]*activeRequest
	requestsLock sync.RWMutex

	// conn is the connection to the service. connLock serializes outgoing
	// messages written through connReadWriter.
	conn           net.Conn
	connLock       sync.Mutex
	connReadWriter *bufio.ReadWriter
	// connErr is the read error that ended the connection, if any.
	connErr error
}
// newService wraps conn in a new Service belonging to router. It starts the
// service's connection goroutine, which authenticates the service before
// handling any of its messages.
func (router *Router) newService(conn net.Conn) (service *Service) {
	factory := protocol.NewRouterIDFactory()
	readWriter := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))

	service = &Service{
		router:         router,
		validate:       router.Validate,
		conn:           conn,
		connReadWriter: readWriter,
		idFactory:      factory,
		requests:       make(map[protocol.ID]*activeRequest),
	}

	go service.work()
	return service
}
// OverRequests calls callback with the ID and URL of each request the service
// is currently serving, in no particular order. It stops as soon as callback
// returns false. The request table is read-locked while callback runs.
func (service *Service) OverRequests(callback func(protocol.ID, *url.URL) bool) {
	service.requestsLock.RLock()
	defer service.requestsLock.RUnlock()

	for requestID, active := range service.requests {
		if keepGoing := callback(requestID, active.url); !keepGoing {
			return
		}
	}
}
// User returns the user the service logged in as.
func (service *Service) User() string { return service.user }
// Name returns the name the service gave when it logged in.
func (service *Service) Name() string { return service.name }
// Description returns the description the service gave when it logged in.
func (service *Service) Description() string { return service.description }
// Pattern returns the pattern the service is mounted on, with defaults
// filled in.
func (service *Service) Pattern() Pattern { return service.pattern }
// Close closes the connection to the service immediately, without waiting
// for in-flight requests.
func (service *Service) Close() (err error) {
	err = service.conn.Close()
	return
}
// Shutdown closes the connection to the service. It currently behaves
// exactly like Close.
func (service *Service) Shutdown() error {
	// TODO: wait for all requests to finish
	return service.Close()
}
// ServeHTTP forwards req to the service over its connection and streams the
// service's response back to res. If the service is mounted with the https
// scheme and req arrived without TLS, the client is permanently redirected to
// the https version of the same URL instead.
func (service *Service) ServeHTTP(res http.ResponseWriter, req *http.Request) {
	// if we are only accepting https requests and we receive an http one,
	// redirect to the https version
	if req.TLS == nil && service.pattern.Scheme == "https" {
		// copy the URL so the incoming request is left untouched, and
		// fill in the host: on server requests req.URL normally holds
		// only the path and query, which would otherwise produce a
		// broken "https:///path" redirect target.
		newURL := *req.URL
		newURL.Scheme = service.pattern.Scheme
		newURL.Host = req.Host
		http.Redirect(res, req, newURL.String(), http.StatusPermanentRedirect)
		return
	}

	// create a new request
	request, id := service.generateActiveRequest(res, req)
	defer service.freeRequest(id)
	// read the body channel before the service learns about this request.
	// Once it has, the connection goroutine may update the request
	// concurrently, and a field read made after that point would race
	// with it.
	body := request.body

	// ask the service to fulfill it
	err := service.send(protocol.MessageHTTPRequest{
		ID:      id,
		Address: req.RemoteAddr,
		Method:  req.Method,
		URL:     req.URL.String(),
		Header:  protocol.Map(req.Header),
	})
	if err != nil {
		http.Error(
			res, "connection to service closed abruptly: "+
				err.Error(), http.StatusBadGateway)
		return
	}

	// read in the body chunk by chunk from the channel, and finalize the
	// request when the channel is closed. Keep draining even if writing to
	// the client fails: the channel is unbuffered, so the connection
	// goroutine waits for every chunk to be taken.
	for chunk := range body {
		request.ensureHeader()
		res.Write(chunk)
	}

	// wroteHeader is also touched by ensureHeader on the connection
	// goroutine, so read it under the same mutex.
	request.headerMutex.Lock()
	wroteHeader := request.wroteHeader
	request.headerMutex.Unlock()

	// if we have stopped serving the page prematurely due to an error,
	// inform the user (if possible)
	// NOTE(review): connErr is written by the connection goroutine without
	// a lock. This read is only ordered after that write when the channel
	// was closed by the connection teardown.
	if service.connErr != nil && !wroteHeader {
		http.Error(
			res, "connection to service closed abruptly: "+
				service.connErr.Error(), http.StatusBadGateway)
		return
	}
}
// generateActiveRequest registers a new in-flight request for req under a
// fresh ID from the service's ID factory. It returns the request and the ID.
// The response status starts out as 200 OK until the service reports
// another one.
func (service *Service) generateActiveRequest(res http.ResponseWriter, req *http.Request) (*activeRequest, protocol.ID) {
	fresh := &activeRequest{
		service:     service,
		url:         req.URL,
		requestBody: req.Body,
		res:         res,
		header:      make(http.Header),
		status:      http.StatusOK,
		body:        make(chan []byte),
	}

	service.requestsLock.Lock()
	defer service.requestsLock.Unlock()
	requestID := service.idFactory.Next()
	service.requests[requestID] = fresh
	return fresh, requestID
}
// freeRequest removes the request registered under id, if it is still
// registered, and returns id to the ID factory so it can be reused.
func (service *Service) freeRequest(id protocol.ID) {
	service.requestsLock.Lock()
	defer service.requestsLock.Unlock()

	table := service.requests
	delete(table, id)
	service.idFactory.Free(id)
}
// work is the service's connection goroutine, started by newService.
//
// It first authenticates the service. It then reads messages from the
// connection and routes each one to the request it refers to, until reading
// fails. On exit it:
//   - unmounts and unlists the service,
//   - closes the connection,
//   - closes the body channel of every request still registered, so the
//     goroutines serving those requests can finish.
func (service *Service) work() {
	defer log.Println("X=- connection to", service.conn.RemoteAddr(), "closed")
	unhandle, ok := service.authenticate()
	if !ok {
		return
	}

	// dismantle service when we exit
	defer func() {
		unhandle()
		service.conn.Close()
		closedRequest := false
		service.requestsLock.Lock()
		for _, request := range service.requests {
			// requests the service already ended were removed from
			// the map when their channel was closed, so each channel
			// left here is still open.
			closedRequest = true
			close(request.body)
		}
		service.requestsLock.Unlock()
		if service.connErr != nil && closedRequest {
			log.Println("!!! service connection error:", service.connErr)
		}
	}()

	// respond to requests
	for {
		message, err := protocol.ReadMessage(service.connReadWriter)
		if err != nil {
			service.connErr = err
			return
		}

		switch message := message.(type) {
		// record the header and status code of the request and write
		// them to the client right away. ensureHeader makes sure this
		// happens at most once.
		case protocol.MessageHTTPResponse:
			service.requestsLock.RLock()
			if request, ok := service.requests[message.ID]; ok {
				// ensureHeader reads these fields under headerMutex,
				// possibly from the serving goroutine, so write them
				// under it too.
				request.headerMutex.Lock()
				request.header = http.Header(message.Header)
				request.status = int(message.Status)
				request.headerMutex.Unlock()
				request.ensureHeader()
			}
			service.requestsLock.RUnlock()

		// send the body chunk along the body channel to the goroutine
		// serving the request. This blocks until that goroutine takes it.
		case protocol.MessageHTTPBodySegment:
			service.requestsLock.RLock()
			if request, ok := service.requests[message.ID]; ok {
				request.body <- message.Data
			} else {
				log.Println("!!! service sent segment for nonexistent request", message.ID)
			}
			service.requestsLock.RUnlock()

		// close the body channel and remove the request from the list,
		// signaling to the goroutine serving the request that the body
		// has been read entirely. Removing it means a repeated end or a
		// late segment is reported as nonexistent. Otherwise it would
		// close or send on a finished channel. The ID itself is released
		// by freeRequest once the serving goroutine returns.
		case protocol.MessageHTTPBodyEnd:
			service.requestsLock.Lock()
			if request, ok := service.requests[message.ID]; ok {
				close(request.body)
				delete(service.requests, message.ID)
			} else {
				log.Println("!!! service ended nonexistent request", message.ID)
			}
			service.requestsLock.Unlock()

		// the client wants the request body
		case protocol.MessageHTTPBodyRequest:
			id := message.ID
			service.requestsLock.RLock()
			if request, ok := service.requests[id]; ok {
				// requestBody is only touched by this goroutine, so
				// the read lock on the map is enough here.
				body := request.requestBody
				request.requestBody = nil
				if body == nil {
					// TODO: send an error to the service bc this
					// means either the implementation of
					// http.Request is broken or the service has
					// requested the body twice.
				} else {
					go service.sendRequestBody(body, id)
				}
			} else {
				log.Println("!!! service requested body for nonexistent request", id)
			}
			service.requestsLock.RUnlock()

		// a message from the client.
		case protocol.MessageStatus:
			log.Println("(i)", service.name, "says:", message.Status)

		default:
			// closing the connection makes the next read fail, which
			// ends this loop.
			log.Println("ERR service sent unknown type code")
			service.Close()
		}
	}
}
// sendRequestBody streams the client's request body to the service. The body
// goes out as a series of body segments tagged with id, followed by a
// body-end message. It runs on its own goroutine, started when the service
// asks for the body. If sending a segment fails, it gives up without sending
// the end message.
func (service *Service) sendRequestBody(body io.ReadCloser, id protocol.ID) {
	// FIXME: is this a good max buffer size? maybe make it configurable
	buffer := [2048]byte{}
	for {
		n, err := body.Read(buffer[:])
		// per the io.Reader contract, forward any bytes read before
		// looking at err.
		if n > 0 {
			// send writes and flushes the segment before returning, so
			// buffer can be reused on the next iteration.
			if service.send(protocol.MessageHTTPBodySegment{
				ID:   id,
				Data: buffer[:n],
			}) != nil {
				return
			}
		}
		if err == io.EOF {
			break
		}
		if err != nil {
			// TODO: send the service an error instead of ending the
			// body as if it were complete. Until then, at least
			// record that the body the service receives is truncated.
			log.Println("!!! could not read body of request", id, "-", err)
			break
		}
	}
	service.send(protocol.MessageHTTPBodyEnd{ID: id})
}
// authenticate reads the login message the service must send first and
// checks its credentials against the requested pattern. It then mounts the
// service on the router's HTTP mux and lists it in the router. The service is
// told the outcome with a status message.
//
// On success it returns a function that unmounts and unlists the service,
// along with true. On failure the connection is closed and it returns nil,
// false. Only the http and https schemes can be mounted.
func (service *Service) authenticate() (func(), bool) {
	// get login details
	message, err := protocol.ReadMessage(service.connReadWriter)
	if err != nil {
		service.Close()
		return nil, false
	}
	login, isLogin := message.(protocol.MessageLogin)
	if !isLogin {
		log.Println(
			"ERR", service.conn.RemoteAddr(),
			"sent wrong type code, expected",
			protocol.TypeLogin)
		service.send(protocol.MessageStatus{
			Status: protocol.StatusBadMessageType,
		})
		service.Close()
		return nil, false
	}

	// create pattern
	service.pattern = Pattern{
		Scheme: login.Scheme,
		Host:   login.Host,
		Path:   login.Path,
	}
	service.pattern.FillDefaults()

	// validate credentials, and report the verdict either way
	status := service.validate(login.User, login.Key, service.pattern)
	service.send(protocol.MessageStatus{Status: status})
	if status != protocol.StatusOk {
		log.Println(
			"ERR", service.conn.RemoteAddr(),
			"failed login as", login.User, "on", service.pattern)
		service.conn.Close()
		return nil, false
	}
	service.user = login.User
	service.name = login.Name
	service.description = login.Description
	log.Println(
		"-->", service.conn.RemoteAddr(),
		"logged in as", login.User, "on", service.pattern)

	// mount service on the mux.
	muxPattern := service.pattern.MuxPattern()
	var unhandle func()
	if scheme := service.pattern.Scheme; scheme == "http" || scheme == "https" {
		err = service.router.HTTPMux().Handle(muxPattern, service)
		unhandle = func() { service.router.HTTPMux().Unhandle(muxPattern) }
	} else {
		err = protocol.StatusBadMount
	}
	if err != nil {
		log.Println(
			"ERR", service.name,
			"sent bad mount pattern:", service.pattern)
		service.send(protocol.MessageStatus{
			Status: protocol.StatusBadMount,
		})
		service.Close()
		return nil, false
	}

	// list the service in the router
	service.router.list(service)
	teardown := func() {
		unhandle()
		service.router.unlist(service)
	}
	return teardown, true
}
// send writes message to the service and flushes the connection. connLock
// keeps messages from concurrent senders from interleaving. It returns the
// first write or flush error.
func (service *Service) send(message protocol.Message) (err error) {
	service.connLock.Lock()
	defer service.connLock.Unlock()

	if err = message.Send(service.connReadWriter); err == nil {
		err = service.connReadWriter.Flush()
	}
	return err
}