package router import "io" import "log" import "net" import "sync" import "bufio" import "net/url" import "net/http" import "hnakra/protocol" type activeRequest struct { service *Service body chan []byte requestBody io.ReadCloser header http.Header status int url *url.URL res http.ResponseWriter wroteHeader bool headerMutex sync.Mutex } func (request *activeRequest) ensureHeader () { request.headerMutex.Lock() defer request.headerMutex.Unlock() if request.wroteHeader { return } for key, values := range request.header { for _, value := range values { request.res.Header().Add(key, value) } } request.res.WriteHeader(request.status) request.wroteHeader = true } type Service struct { router *Router scheme, host, path string user, name, description string idFactory *protocol.IDFactory requests map[protocol.ID] *activeRequest requestsLock sync.RWMutex conn net.Conn connLock sync.Mutex connReadWriter *bufio.ReadWriter connErr error } func (router *Router) newService (conn net.Conn) (service *Service) { service = &Service { idFactory: protocol.NewRouterIDFactory(), router: router, conn: conn, requests: make(map[protocol.ID] *activeRequest), connReadWriter: bufio.NewReadWriter ( bufio.NewReader(conn), bufio.NewWriter(conn)), } go service.work() return } func (service *Service) OverRequests (callback func (protocol.ID, *url.URL) bool) { service.requestsLock.RLock() defer service.requestsLock.RUnlock() for id, request := range service.requests { if !callback(id, request.url) { break } } } func (service *Service) User () string { return service.user } func (service *Service) Name () string { return service.name } func (service *Service) Description () string { return service.description } func (service *Service) Pattern () (user, name, description string) { return service.user, service.name, service.description } func (service *Service) Close () error { return service.conn.Close() } func (service *Service) Shutdown () error { // TODO: wait for all requests to finish return service.conn.Close() } func (service *Service) ServeHTTP (res http.ResponseWriter, req *http.Request) { // if we are only accepting https requests and we recieve an http one, // redirect to the https version if req.TLS == nil && service.scheme == "https" { newURL := req.URL newURL.Scheme = service.scheme http.Redirect(res, req, newURL.String(), http.StatusPermanentRedirect) return } // create a new request request, id := service.generateActiveRequest(res, req) defer service.freeRequest(id) // ask the service to fulfill it err := service.send(protocol.MessageHTTPRequest { ID: id, Address: req.RemoteAddr, Method: req.Method, URL: req.URL.String(), Header: protocol.Map(req.Header), }) if err != nil { http.Error ( res, "connection to service closed abruptly: " + err.Error(), http.StatusBadGateway) return } // read in the body chunk bu chunk from the channel, and finalize the // request when the channel is closed for chunk := range request.body { request.ensureHeader() res.Write(chunk) } // if we have stopped serving the page prematurely due to an error, // inform the user (if possible) if service.connErr != nil && !request.wroteHeader { http.Error ( res, "connection to service closed abruptly: " + service.connErr.Error(), http.StatusBadGateway) return } } func (service *Service) generateActiveRequest ( res http.ResponseWriter, req *http.Request, ) ( *activeRequest, protocol.ID, ) { service.requestsLock.Lock() defer service.requestsLock.Unlock() id := service.idFactory.Next() request := &activeRequest { service: service, body: make(chan []byte), header: make(http.Header), status: http.StatusOK, url: req.URL, requestBody: req.Body, res: res, } service.requests[id] = request return request, id } func (service *Service) freeRequest (id protocol.ID) { service.requestsLock.Lock() defer service.requestsLock.Unlock() delete(service.requests, id) service.idFactory.Free(id) } func (service *Service) work () { defer log.Println("X=- connection to", service.conn.RemoteAddr(), "closed") unhandle, ok := service.authenticate() if !ok { return } // dismantle service when we exit defer func () { unhandle() service.conn.Close() closedRequest := false service.requestsLock.Lock() for _, request := range service.requests { closedRequest = true close(request.body) } service.requestsLock.Unlock() if service.connErr != nil && closedRequest { log.Println("!!! service connection error:", service.connErr) } } () // respond to requests for { message, err := protocol.ReadMessage(service.connReadWriter) if err != nil { service.connErr = err return } switch message.(type) { // set the header and status code of the request. when the first // body chunk is sent or the body is ended without any chunks // being sent, the serving goroutine will send the header data. case protocol.MessageHTTPResponse: response := message.(protocol.MessageHTTPResponse) service.requestsLock.RLock() if _, ok := service.requests[response.ID]; ok { service.requests[response.ID].header = http.Header(response.Header) service.requests[response.ID].status = int(response.Status) service.requests[response.ID].ensureHeader() } service.requestsLock.RUnlock() // send the body chunk along the body channel to the goroutine // serving the request. case protocol.MessageHTTPBodySegment: segment := message.(protocol.MessageHTTPBodySegment) service.requestsLock.RLock() if _, ok := service.requests[segment.ID]; ok { service.requests[segment.ID].body <- segment.Data } else { log.Println("!!! service sent segment for nonexistent request", segment.ID) } service.requestsLock.RUnlock() // close the body channel and remove the request from the list, // signaling to the goroutine serving the request that the body // has been read entirely. case protocol.MessageHTTPBodyEnd: end := message.(protocol.MessageHTTPBodyEnd) service.requestsLock.Lock() if _, ok := service.requests[end.ID]; ok { close(service.requests[end.ID].body) service.requests[end.ID].body = nil } else { log.Println("!!! service ended nonexistent request", end.ID) } service.requestsLock.Unlock() // the client wants the reqeust body case protocol.MessageHTTPBodyRequest: id := message.(protocol.MessageHTTPBodyRequest).ID service.requestsLock.RLock() if request, ok := service.requests[id]; ok { body := request.requestBody request.requestBody = nil if body == nil { // TODO: send an error to the service bc this // means either the implementation of // http.Request is broken or the service has // requested the body twice. } else { go service.sendRequestBody(body, id) } } else { log.Println("!!! service requested body for nonexistent request", id) } service.requestsLock.RUnlock() // a message from the client. case protocol.MessageStatus: status := message.(protocol.MessageStatus) log.Println("(i)", service.name, "says:", status.Status) default: log.Println("ERR service sent unknown type code") service.Close() } } } func (service *Service) sendRequestBody (body io.ReadCloser, id protocol.ID) { // FIXME: is this a good max buffer size? maybe make it configurable buffer := [2048]byte { } for { n, err := body.Read(buffer[:]) if n > 0 { if service.send(protocol.MessageHTTPBodySegment { ID: id, Data: buffer[:n], }) != nil { return } } if err != nil { if err != io.EOF { // TODO: send the service an error } break } } service.send(protocol.MessageHTTPBodyEnd { ID: id }) } func (service *Service) authenticate () (func(), bool) { // get login details message, err := protocol.ReadMessage(service.connReadWriter) if err != nil { service.Close() return nil, false } login, ok := message.(protocol.MessageLogin) if !ok { log.Println ( "ERR", service.conn.RemoteAddr(), "sent wrong type code, expected", protocol.TypeLogin) service.send(protocol.MessageStatus { Status: protocol.StatusBadMessageType, }) service.Close() return nil, false } // create pattern service.scheme = login.Scheme service.host = login.Host service.path = login.Path if service.scheme == "" { service.scheme = "https://" } if service.host == "" { service.host = "@" } if service.path == "" { service.path = "/" } // validate credentials status := service.router.Validate ( login.User, login.Key, service.scheme, service.host, service.path) service.send(protocol.MessageStatus { Status: status }) if status == protocol.StatusOk { service.user = login.User service.name = login.Name service.description = login.Description log.Println ( "-->", service.conn.RemoteAddr(), "logged in as", login.User, "on", FormatPattern(service.scheme, service.host, service.path)) } else { log.Println ( "ERR", service.conn.RemoteAddr(), "failed login as", login.User, "on", FormatPattern(service.scheme, service.host, service.path)) service.conn.Close() return nil, false } // mount service on the mux. var unhandle func () muxPattern := service.host + service.path switch service.scheme { case "http", "https": err = service.router.HTTPMux().Handle(muxPattern, service) unhandle = func () { service.router.HTTPMux().Unhandle(muxPattern) } default: err = protocol.StatusBadMount } if err != nil { log.Println ( "ERR", service.name, "sent bad mount pattern:", FormatPattern(service.scheme, service.host, service.path)) service.send(protocol.MessageStatus { Status: protocol.StatusBadMount, }) service.Close() return nil, false } // list the service in the router service.router.list(service) return func () { unhandle() service.router.unlist(service) }, true } func (service *Service) send (message protocol.Message) (err error) { service.connLock.Lock() defer service.connLock.Unlock() err = message.Send(service.connReadWriter) if err != nil { return } return service.connReadWriter.Flush() }