393 lines
9.9 KiB
Go
393 lines
9.9 KiB
Go
package router
|
|
|
|
import "io"
|
|
import "log"
|
|
import "net"
|
|
import "sync"
|
|
import "bufio"
|
|
import "net/url"
|
|
import "net/http"
|
|
import "hnakra/protocol"
|
|
|
|
type activeRequest struct {
|
|
service *Service
|
|
body chan []byte
|
|
requestBody io.ReadCloser
|
|
header http.Header
|
|
status int
|
|
url *url.URL
|
|
|
|
res http.ResponseWriter
|
|
wroteHeader bool
|
|
headerMutex sync.Mutex
|
|
}
|
|
|
|
func (request *activeRequest) ensureHeader () {
|
|
request.headerMutex.Lock()
|
|
defer request.headerMutex.Unlock()
|
|
if request.wroteHeader { return }
|
|
|
|
for key, values := range request.header {
|
|
for _, value := range values {
|
|
request.res.Header().Add(key, value)
|
|
}
|
|
}
|
|
request.res.WriteHeader(request.status)
|
|
request.wroteHeader = true
|
|
}
|
|
|
|
type Service struct {
|
|
router *Router
|
|
validate func (user string, key []byte, pattern Pattern) protocol.Status
|
|
pattern Pattern
|
|
|
|
user, name, description string
|
|
|
|
idFactory *protocol.IDFactory
|
|
requests map[protocol.ID] *activeRequest
|
|
requestsLock sync.RWMutex
|
|
|
|
conn net.Conn
|
|
connLock sync.Mutex
|
|
connReadWriter *bufio.ReadWriter
|
|
connErr error
|
|
}
|
|
|
|
func (router *Router) newService (conn net.Conn) (service *Service) {
|
|
service = &Service {
|
|
idFactory: protocol.NewRouterIDFactory(),
|
|
router: router,
|
|
validate: router.Validate,
|
|
conn: conn,
|
|
requests: make(map[protocol.ID] *activeRequest),
|
|
connReadWriter: bufio.NewReadWriter (
|
|
bufio.NewReader(conn),
|
|
bufio.NewWriter(conn)),
|
|
}
|
|
go service.work()
|
|
return
|
|
}
|
|
|
|
func (service *Service) OverRequests (callback func (protocol.ID, *url.URL) bool) {
|
|
service.requestsLock.RLock()
|
|
defer service.requestsLock.RUnlock()
|
|
for id, request := range service.requests {
|
|
if !callback(id, request.url) { break }
|
|
}
|
|
}
|
|
|
|
func (service *Service) User () string {
|
|
return service.user
|
|
}
|
|
|
|
func (service *Service) Name () string {
|
|
return service.name
|
|
}
|
|
|
|
func (service *Service) Description () string {
|
|
return service.description
|
|
}
|
|
|
|
func (service *Service) Pattern () Pattern {
|
|
return service.pattern
|
|
}
|
|
|
|
func (service *Service) Close () error {
|
|
return service.conn.Close()
|
|
}
|
|
|
|
func (service *Service) Shutdown () error {
|
|
// TODO: wait for all requests to finish
|
|
return service.conn.Close()
|
|
}
|
|
|
|
func (service *Service) ServeHTTP (res http.ResponseWriter, req *http.Request) {
|
|
// if we are only accepting https requests and we recieve an http one,
|
|
// redirect to the https version
|
|
if req.TLS == nil && service.pattern.Scheme == "https" {
|
|
newURL := req.URL
|
|
newURL.Scheme = service.pattern.Scheme
|
|
http.Redirect(res, req, newURL.String(), http.StatusPermanentRedirect)
|
|
return
|
|
}
|
|
|
|
// create a new request
|
|
request, id := service.generateActiveRequest(res, req)
|
|
defer service.freeRequest(id)
|
|
|
|
// ask the service to fulfill it
|
|
err := service.send(protocol.MessageHTTPRequest {
|
|
ID: id,
|
|
Address: req.RemoteAddr,
|
|
Method: req.Method,
|
|
URL: req.URL.String(),
|
|
Header: protocol.Map(req.Header),
|
|
})
|
|
if err != nil {
|
|
http.Error (
|
|
res, "connection to service closed abruptly: " +
|
|
err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
|
|
// read in the body chunk bu chunk from the channel, and finalize the
|
|
// request when the channel is closed
|
|
for chunk := range request.body {
|
|
request.ensureHeader()
|
|
res.Write(chunk)
|
|
}
|
|
|
|
// if we have stopped serving the page prematurely due to an error,
|
|
// inform the user (if possible)
|
|
if service.connErr != nil && !request.wroteHeader {
|
|
http.Error (
|
|
res, "connection to service closed abruptly: " +
|
|
service.connErr.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (service *Service) generateActiveRequest (
|
|
res http.ResponseWriter,
|
|
req *http.Request,
|
|
) (
|
|
*activeRequest,
|
|
protocol.ID,
|
|
) {
|
|
service.requestsLock.Lock()
|
|
defer service.requestsLock.Unlock()
|
|
|
|
id := service.idFactory.Next()
|
|
request := &activeRequest {
|
|
service: service,
|
|
body: make(chan []byte),
|
|
header: make(http.Header),
|
|
status: http.StatusOK,
|
|
url: req.URL,
|
|
requestBody: req.Body,
|
|
res: res,
|
|
}
|
|
service.requests[id] = request
|
|
return request, id
|
|
}
|
|
|
|
func (service *Service) freeRequest (id protocol.ID) {
|
|
service.requestsLock.Lock()
|
|
defer service.requestsLock.Unlock()
|
|
delete(service.requests, id)
|
|
service.idFactory.Free(id)
|
|
}
|
|
|
|
func (service *Service) work () {
|
|
defer log.Println("X=- connection to", service.conn.RemoteAddr(), "closed")
|
|
unhandle, ok := service.authenticate()
|
|
if !ok { return }
|
|
|
|
// dismantle service when we exit
|
|
defer func () {
|
|
unhandle()
|
|
service.conn.Close()
|
|
|
|
closedRequest := false
|
|
service.requestsLock.Lock()
|
|
for _, request := range service.requests {
|
|
closedRequest = true
|
|
close(request.body)
|
|
}
|
|
service.requestsLock.Unlock()
|
|
|
|
if service.connErr != nil && closedRequest {
|
|
log.Println("!!! service connection error:", service.connErr)
|
|
}
|
|
} ()
|
|
|
|
// respond to requests
|
|
for {
|
|
message, err := protocol.ReadMessage(service.connReadWriter)
|
|
if err != nil {
|
|
service.connErr = err
|
|
return
|
|
}
|
|
|
|
switch message.(type) {
|
|
// set the header and status code of the request. when the first
|
|
// body chunk is sent or the body is ended without any chunks
|
|
// being sent, the serving goroutine will send the header data.
|
|
case protocol.MessageHTTPResponse:
|
|
response := message.(protocol.MessageHTTPResponse)
|
|
service.requestsLock.RLock()
|
|
if _, ok := service.requests[response.ID]; ok {
|
|
service.requests[response.ID].header = http.Header(response.Header)
|
|
service.requests[response.ID].status = int(response.Status)
|
|
service.requests[response.ID].ensureHeader()
|
|
}
|
|
service.requestsLock.RUnlock()
|
|
|
|
// send the body chunk along the body channel to the goroutine
|
|
// serving the request.
|
|
case protocol.MessageHTTPBodySegment:
|
|
segment := message.(protocol.MessageHTTPBodySegment)
|
|
service.requestsLock.RLock()
|
|
if _, ok := service.requests[segment.ID]; ok {
|
|
service.requests[segment.ID].body <- segment.Data
|
|
} else {
|
|
log.Println("!!! service sent segment for nonexistent request", segment.ID)
|
|
}
|
|
service.requestsLock.RUnlock()
|
|
|
|
// close the body channel and remove the request from the list,
|
|
// signaling to the goroutine serving the request that the body
|
|
// has been read entirely.
|
|
case protocol.MessageHTTPBodyEnd:
|
|
end := message.(protocol.MessageHTTPBodyEnd)
|
|
service.requestsLock.Lock()
|
|
if _, ok := service.requests[end.ID]; ok {
|
|
close(service.requests[end.ID].body)
|
|
service.requests[end.ID].body = nil
|
|
} else {
|
|
log.Println("!!! service ended nonexistent request", end.ID)
|
|
}
|
|
service.requestsLock.Unlock()
|
|
|
|
// the client wants the reqeust body
|
|
case protocol.MessageHTTPBodyRequest:
|
|
id := message.(protocol.MessageHTTPBodyRequest).ID
|
|
service.requestsLock.RLock()
|
|
if request, ok := service.requests[id]; ok {
|
|
body := request.requestBody
|
|
request.requestBody = nil
|
|
|
|
if body == nil {
|
|
// TODO: send an error to the service bc this
|
|
// means either the implementation of
|
|
// http.Request is broken or the service has
|
|
// requested the body twice.
|
|
} else {
|
|
go service.sendRequestBody(body, id)
|
|
}
|
|
} else {
|
|
log.Println("!!! service requested body for nonexistent request", id)
|
|
}
|
|
service.requestsLock.RUnlock()
|
|
|
|
// a message from the client.
|
|
case protocol.MessageStatus:
|
|
status := message.(protocol.MessageStatus)
|
|
log.Println("(i)", service.name, "says:", status.Status)
|
|
|
|
default:
|
|
log.Println("ERR service sent unknown type code")
|
|
service.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (service *Service) sendRequestBody (body io.ReadCloser, id protocol.ID) {
|
|
// FIXME: is this a good max buffer size? maybe make it configurable
|
|
buffer := [2048]byte { }
|
|
for {
|
|
n, err := body.Read(buffer[:])
|
|
|
|
if n > 0 {
|
|
if service.send(protocol.MessageHTTPBodySegment {
|
|
ID: id,
|
|
Data: buffer[:n],
|
|
}) != nil { return }
|
|
}
|
|
|
|
if err != nil {
|
|
if err != io.EOF {
|
|
// TODO: send the service an error
|
|
}
|
|
break
|
|
}
|
|
}
|
|
service.send(protocol.MessageHTTPBodyEnd { ID: id })
|
|
}
|
|
|
|
func (service *Service) authenticate () (func(), bool) {
|
|
// get login details
|
|
message, err := protocol.ReadMessage(service.connReadWriter)
|
|
if err != nil {
|
|
service.Close()
|
|
return nil, false
|
|
}
|
|
login, ok := message.(protocol.MessageLogin)
|
|
if !ok {
|
|
log.Println (
|
|
"ERR", service.conn.RemoteAddr(),
|
|
"sent wrong type code, expected",
|
|
protocol.TypeLogin)
|
|
service.send(protocol.MessageStatus {
|
|
Status: protocol.StatusBadMessageType,
|
|
})
|
|
service.Close()
|
|
return nil, false
|
|
}
|
|
|
|
// create pattern
|
|
service.pattern = Pattern {
|
|
Scheme: login.Scheme,
|
|
Host: login.Host,
|
|
Path: login.Path,
|
|
}
|
|
service.pattern.FillDefaults()
|
|
|
|
// validate credentials
|
|
status := service.validate(login.User, login.Key, service.pattern)
|
|
service.send(protocol.MessageStatus { Status: status })
|
|
if status == protocol.StatusOk {
|
|
service.user = login.User
|
|
service.name = login.Name
|
|
service.description = login.Description
|
|
|
|
log.Println (
|
|
"-->", service.conn.RemoteAddr(),
|
|
"logged in as", login.User, "on", service.pattern)
|
|
} else {
|
|
log.Println (
|
|
"ERR", service.conn.RemoteAddr(),
|
|
"failed login as", login.User, "on", service.pattern)
|
|
service.conn.Close()
|
|
return nil, false
|
|
}
|
|
|
|
// mount service on the mux.
|
|
var unhandle func ()
|
|
muxPattern := service.pattern.MuxPattern()
|
|
switch service.pattern.Scheme {
|
|
case "http", "https":
|
|
err = service.router.HTTPMux().Handle(muxPattern, service)
|
|
unhandle = func () {
|
|
service.router.HTTPMux().Unhandle(muxPattern)
|
|
}
|
|
default:
|
|
err = protocol.StatusBadMount
|
|
}
|
|
if err != nil {
|
|
log.Println (
|
|
"ERR", service.name,
|
|
"sent bad mount pattern:", service.pattern)
|
|
service.send(protocol.MessageStatus {
|
|
Status: protocol.StatusBadMount,
|
|
})
|
|
service.Close()
|
|
return nil, false
|
|
}
|
|
|
|
// list the service in the router
|
|
service.router.list(service)
|
|
return func () {
|
|
unhandle()
|
|
service.router.unlist(service)
|
|
}, true
|
|
}
|
|
|
|
func (service *Service) send (message protocol.Message) (err error) {
|
|
service.connLock.Lock()
|
|
defer service.connLock.Unlock()
|
|
err = message.Send(service.connReadWriter)
|
|
if err != nil { return }
|
|
return service.connReadWriter.Flush()
|
|
}
|