Initial commit

This commit is contained in:
Sasha Koshka
2023-05-25 18:08:56 -04:00
commit c300567c0c
51 changed files with 42251 additions and 0 deletions

37
router/mux/http.go Normal file
View File

@@ -0,0 +1,37 @@
package mux
import "net/url"
import "net/http"
type HTTP struct {
Mux[http.Handler]
}
func NewHTTP (resolver Resolver) *HTTP {
mux := &HTTP { }
mux.Mux.Redirect = mux.newRedirect
mux.Mux.NotFound = mux.newNotFound
mux.Mux.Resolver = resolver
return mux
}
func (mux *HTTP) ServeHTTP (res http.ResponseWriter, req *http.Request) {
if req.RequestURI == "*" {
if req.ProtoAtLeast(1, 1) {
res.Header().Set("Connection", "close")
}
res.WriteHeader(http.StatusBadRequest)
return
}
handler, _ := mux.Handler(req.URL)
handler.ServeHTTP(res, req)
}
func (mux *HTTP) newNotFound (where *url.URL) http.Handler {
return http.NotFoundHandler()
}
func (mux *HTTP) newRedirect (where *url.URL) http.Handler {
return http.RedirectHandler(where.String(), http.StatusMovedPermanently)
}

245
router/mux/mux.go Normal file
View File

@@ -0,0 +1,245 @@
package mux
import "net"
import "sync"
import "path"
import "sort"
import "errors"
import "strings"
import "net/url"
type Resolver interface {
ResolveAlias (alias string) string
}
// Mux is a generic request multiplexer based off of http.ServeMux. It can be
// adapted to any protocol where requests to a URL are handled.
type Mux[HANDLER any] struct {
// These must be non-nil!
Redirect func (*url.URL) HANDLER
NotFound func (*url.URL) HANDLER
Resolver Resolver
mutex sync.RWMutex
exactEntries map[string]muxEntry[HANDLER]
sortedEntries []muxEntry[HANDLER] // slice of entries sorted from longest to shortest.
}
type muxEntry[HANDLER any] struct {
handler HANDLER
pattern string
}
func cleanPath (p string) string {
if p == "" {
return "/"
}
if p[0] != '/' {
p = "/" + p
}
np := path.Clean(p)
// path.Clean removes trailing slash except for root;
// put the trailing slash back if necessary.
if p[len(p)-1] == '/' && np != "/" {
// Fast path for common case of p being the string we want:
if len(p) == len(np)+1 && strings.HasPrefix(p, np) {
np = p
} else {
np += "/"
}
}
return np
}
func stripHostPort (h string) string {
// If no port on host, return unchanged
if !strings.Contains(h, ":") {
return h
}
host, _, err := net.SplitHostPort(h)
if err != nil {
return h // on error, return unchanged
}
return host
}
func (mux *Mux[HANDLER]) Handler (where *url.URL) (h HANDLER, pattern string) {
// All other requests have any port stripped and path cleaned
// before passing to mux.handler.
host := stripHostPort(where.Host)
path := cleanPath(where.Path)
host = mux.Resolver.ResolveAlias(host)
pat := host + path
// If the given path is /tree and its handler is not registered,
// redirect for /tree/.
u, shouldRedirect := mux.redirectToPathSlash(host, path, where)
if shouldRedirect {
return mux.Redirect(u), u.Path
}
if path != where.Path {
_, pattern = mux.match(pat, where)
u := &url.URL{Path: path, RawQuery: where.RawQuery}
return mux.Redirect(u), pattern
}
return mux.match(pat, where)
}
func (mux *Mux[HANDLER]) redirectToPathSlash (host, path string, u *url.URL) (*url.URL, bool) {
mux.mutex.RLock()
shouldRedirect := mux.shouldRedirectRLocked(host, path)
mux.mutex.RUnlock()
if !shouldRedirect {
return u, false
}
path = path + "/"
u = &url.URL{Path: path, RawQuery: u.RawQuery}
return u, true
}
func (mux *Mux[HANDLER]) shouldRedirectRLocked (host, path string) bool {
p := []string{path, host + path}
for _, c := range p {
if _, exist := mux.exactEntries[c]; exist {
return false
}
}
n := len(path)
if n == 0 {
return false
}
for _, c := range p {
if _, exist := mux.exactEntries[c+"/"]; exist {
return path[n-1] != '/'
}
}
return false
}
func (mux *Mux[HANDLER]) match (path string, original *url.URL) (h HANDLER, pattern string) {
// Check for exact match first.
entry, matchExists := mux.exactEntries[path]
if matchExists {
return entry.handler, entry.pattern
}
// Check for longest valid match. mux.es contains all patterns
// that end in / sorted from longest to shortest.
for _, entry := range mux.sortedEntries {
if strings.HasPrefix(path, entry.pattern) {
return entry.handler, entry.pattern
}
}
return mux.NotFound(original), ""
}
func (mux *Mux[HANDLER]) Handle (pattern string, handler HANDLER) error {
mux.mutex.Lock()
defer mux.mutex.Unlock()
if pattern == "" {
return errors.New(
"mux: invalid pattern " + pattern +
", cannot be empty.")
}
if pattern[0] == '/' {
return errors.New(
"mux: invalid pattern " + pattern +
", must be host specific.")
}
if _, exist := mux.exactEntries[pattern]; exist {
return errors.New("mux: existing mount on " + pattern)
}
if mux.exactEntries == nil {
mux.exactEntries = make(map[string] muxEntry[HANDLER])
}
entry := muxEntry[HANDLER] { handler: handler, pattern: pattern }
mux.exactEntries[pattern] = entry
if pattern[len(pattern)-1] == '/' {
mux.sortedEntries = appendSorted(mux.sortedEntries, entry)
}
return nil
}
func (mux *Mux[HANDLER]) Unhandle (pattern string) error {
mux.mutex.Lock()
defer mux.mutex.Unlock()
// delete from exact match list
if _, registered := mux.exactEntries[pattern]; !registered {
return errors.New(
"mux: pattern " + pattern + " is not mounted")
}
delete(mux.exactEntries, pattern)
// delete from sorted list, if its in there.
newLen := 0
for index, entry := range mux.sortedEntries {
if entry.pattern != pattern {
mux.sortedEntries[newLen] = mux.sortedEntries[index]
newLen++
}
}
mux.sortedEntries = mux.sortedEntries[:newLen]
return nil
}
func (mux *Mux[HANDLER]) OverHandlers (callback func (pattern string, handler HANDLER) bool) {
overSorted (mux.exactEntries, func (pattern string, entry muxEntry[HANDLER]) bool {
return callback(pattern, entry.handler)
})
}
func appendSorted[HANDLER any] (entries []muxEntry[HANDLER], entry muxEntry[HANDLER]) []muxEntry[HANDLER] {
entriesLen := len(entries)
index := sort.Search(entriesLen, func(index int) bool {
return len(entries[index].pattern) < len(entry.pattern)
})
if index == entriesLen {
return append(entries, entry)
}
// we now know that i points at where we want to insert
// try to grow the slice in place, any entry works
entries = append(entries, muxEntry[HANDLER] { })
copy(entries[index+1:], entries[index:]) // Move shorter entries down
entries[index] = entry
return entries
}
func overSorted[V any] (unsorted map[string] V, fn func (string, V) bool) {
keys := sortKeys(unsorted)
for _, key := range keys {
if !fn(key, unsorted[key]) {
break
}
}
}
func sortKeys[V any] (unsorted map[string] V) []string {
sortedKeys := make([]string, len(unsorted))
index := 0
for key, _ := range unsorted {
sortedKeys[index] = key
index ++
}
sort.Strings(sortedKeys)
return sortedKeys
}

93
router/router.go Normal file
View File

@@ -0,0 +1,93 @@
package router
import "log"
import "net"
// import "errors"
import "net/http"
import "hnakra/config"
import "hnakra/protocol"
import "hnakra/router/mux"
type Pattern struct {
Scheme string
Host string
Path string
}
func (pattern Pattern) MuxPattern () string {
return pattern.Host + pattern.Path
}
func (pattern Pattern) String () string {
return pattern.Scheme + "://" + pattern.Host + pattern.Path
}
func (pattern *Pattern) FillDefaults () {
if pattern.Scheme == "" { pattern.Scheme = "https" }
if pattern.Host == "" { pattern.Host = "@" }
if pattern.Path == "" { pattern.Path = "/" }
}
type Router struct {
config config.Config
http *mux.HTTP
services map[*Service] struct { }
}
func New (config config.Config) *Router {
router := &Router {
config: config,
http: mux.NewHTTP(config),
services: make(map[*Service] struct { }),
}
return router
}
func (router *Router) HTTPMux () *mux.HTTP {
return router.http
}
func (router *Router) Accept (conn net.Conn) {
router.newService(conn)
}
func (router *Router) ServeHTTP (res http.ResponseWriter, req *http.Request) {
log.Println("->?", req.RemoteAddr, "requests", req.URL.String())
router.http.ServeHTTP(res, req)
}
// TODO: uniquely identify connected services by an ID
func (router *Router) Service (name string) *Service {
for service := range router.services {
if service.Name() == name {
return service
}
}
return nil
}
func (router *Router) OverServices (callback func (service *Service) bool) {
for service := range router.services {
if !callback(service) { break }
}
}
func (router *Router) list (service *Service) {
router.services[service] = struct { } { }
}
func (router *Router) unlist (service *Service) {
delete(router.services, service)
}
func (router *Router) Validate (name string, key []byte, pattern Pattern) protocol.Status {
user := router.config.User(name)
if user == nil || !user.Validate(key) {
return protocol.StatusBadCredentials
}
if !user.CanMountOn(pattern.Scheme, pattern.Host, pattern.Path) {
return protocol.StatusBadMount
}
return protocol.StatusOk
}

392
router/service.go Normal file
View File

@@ -0,0 +1,392 @@
package router
import "io"
import "log"
import "net"
import "sync"
import "bufio"
import "net/url"
import "net/http"
import "hnakra/protocol"
type activeRequest struct {
service *Service
body chan []byte
requestBody io.ReadCloser
header http.Header
status int
url *url.URL
res http.ResponseWriter
wroteHeader bool
headerMutex sync.Mutex
}
func (request *activeRequest) ensureHeader () {
request.headerMutex.Lock()
defer request.headerMutex.Unlock()
if request.wroteHeader { return }
for key, values := range request.header {
for _, value := range values {
request.res.Header().Add(key, value)
}
}
request.res.WriteHeader(request.status)
request.wroteHeader = true
}
type Service struct {
router *Router
validate func (user string, key []byte, pattern Pattern) protocol.Status
pattern Pattern
user, name, description string
idFactory *protocol.IDFactory
requests map[protocol.ID] *activeRequest
requestsLock sync.RWMutex
conn net.Conn
connLock sync.Mutex
connReadWriter *bufio.ReadWriter
connErr error
}
func (router *Router) newService (conn net.Conn) (service *Service) {
service = &Service {
idFactory: protocol.NewRouterIDFactory(),
router: router,
validate: router.Validate,
conn: conn,
requests: make(map[protocol.ID] *activeRequest),
connReadWriter: bufio.NewReadWriter (
bufio.NewReader(conn),
bufio.NewWriter(conn)),
}
go service.work()
return
}
func (service *Service) OverRequests (callback func (protocol.ID, *url.URL) bool) {
service.requestsLock.RLock()
defer service.requestsLock.RUnlock()
for id, request := range service.requests {
if !callback(id, request.url) { break }
}
}
func (service *Service) User () string {
return service.user
}
func (service *Service) Name () string {
return service.name
}
func (service *Service) Description () string {
return service.description
}
func (service *Service) Pattern () Pattern {
return service.pattern
}
func (service *Service) Close () error {
return service.conn.Close()
}
func (service *Service) Shutdown () error {
// TODO: wait for all requests to finish
return service.conn.Close()
}
func (service *Service) ServeHTTP (res http.ResponseWriter, req *http.Request) {
// if we are only accepting https requests and we recieve an http one,
// redirect to the https version
if req.TLS == nil && service.pattern.Scheme == "https" {
newURL := req.URL
newURL.Scheme = service.pattern.Scheme
http.Redirect(res, req, newURL.String(), http.StatusPermanentRedirect)
return
}
// create a new request
request, id := service.generateActiveRequest(res, req)
defer service.freeRequest(id)
// ask the service to fulfill it
err := service.send(protocol.MessageHTTPRequest {
ID: id,
Address: req.RemoteAddr,
Method: req.Method,
URL: req.URL.String(),
Header: protocol.Map(req.Header),
})
if err != nil {
http.Error (
res, "connection to service closed abruptly: " +
err.Error(), http.StatusBadGateway)
return
}
// read in the body chunk bu chunk from the channel, and finalize the
// request when the channel is closed
for chunk := range request.body {
request.ensureHeader()
res.Write(chunk)
}
// if we have stopped serving the page prematurely due to an error,
// inform the user (if possible)
if service.connErr != nil && !request.wroteHeader {
http.Error (
res, "connection to service closed abruptly: " +
service.connErr.Error(), http.StatusBadGateway)
return
}
}
func (service *Service) generateActiveRequest (
res http.ResponseWriter,
req *http.Request,
) (
*activeRequest,
protocol.ID,
) {
service.requestsLock.Lock()
defer service.requestsLock.Unlock()
id := service.idFactory.Next()
request := &activeRequest {
service: service,
body: make(chan []byte),
header: make(http.Header),
status: http.StatusOK,
url: req.URL,
requestBody: req.Body,
res: res,
}
service.requests[id] = request
return request, id
}
func (service *Service) freeRequest (id protocol.ID) {
service.requestsLock.Lock()
defer service.requestsLock.Unlock()
delete(service.requests, id)
service.idFactory.Free(id)
}
func (service *Service) work () {
defer log.Println("X=- connection to", service.conn.RemoteAddr(), "closed")
unhandle, ok := service.authenticate()
if !ok { return }
// dismantle service when we exit
defer func () {
unhandle()
service.conn.Close()
closedRequest := false
service.requestsLock.Lock()
for _, request := range service.requests {
closedRequest = true
close(request.body)
}
service.requestsLock.Unlock()
if service.connErr != nil && closedRequest {
log.Println("!!! service connection error:", service.connErr)
}
} ()
// respond to requests
for {
message, err := protocol.ReadMessage(service.connReadWriter)
if err != nil {
service.connErr = err
return
}
switch message.(type) {
// set the header and status code of the request. when the first
// body chunk is sent or the body is ended without any chunks
// being sent, the serving goroutine will send the header data.
case protocol.MessageHTTPResponse:
response := message.(protocol.MessageHTTPResponse)
service.requestsLock.RLock()
if _, ok := service.requests[response.ID]; ok {
service.requests[response.ID].header = http.Header(response.Header)
service.requests[response.ID].status = int(response.Status)
service.requests[response.ID].ensureHeader()
}
service.requestsLock.RUnlock()
// send the body chunk along the body channel to the goroutine
// serving the request.
case protocol.MessageHTTPBodySegment:
segment := message.(protocol.MessageHTTPBodySegment)
service.requestsLock.RLock()
if _, ok := service.requests[segment.ID]; ok {
service.requests[segment.ID].body <- segment.Data
} else {
log.Println("!!! service sent segment for nonexistent request", segment.ID)
}
service.requestsLock.RUnlock()
// close the body channel and remove the request from the list,
// signaling to the goroutine serving the request that the body
// has been read entirely.
case protocol.MessageHTTPBodyEnd:
end := message.(protocol.MessageHTTPBodyEnd)
service.requestsLock.Lock()
if _, ok := service.requests[end.ID]; ok {
close(service.requests[end.ID].body)
service.requests[end.ID].body = nil
} else {
log.Println("!!! service ended nonexistent request", end.ID)
}
service.requestsLock.Unlock()
// the client wants the reqeust body
case protocol.MessageHTTPBodyRequest:
id := message.(protocol.MessageHTTPBodyRequest).ID
service.requestsLock.RLock()
if request, ok := service.requests[id]; ok {
body := request.requestBody
request.requestBody = nil
if body == nil {
// TODO: send an error to the service bc this
// means either the implementation of
// http.Request is broken or the service has
// requested the body twice.
} else {
go service.sendRequestBody(body, id)
}
} else {
log.Println("!!! service requested body for nonexistent request", id)
}
service.requestsLock.RUnlock()
// a message from the client.
case protocol.MessageStatus:
status := message.(protocol.MessageStatus)
log.Println("(i)", service.name, "says:", status.Status)
default:
log.Println("ERR service sent unknown type code")
service.Close()
}
}
}
func (service *Service) sendRequestBody (body io.ReadCloser, id protocol.ID) {
// FIXME: is this a good max buffer size? maybe make it configurable
buffer := [2048]byte { }
for {
n, err := body.Read(buffer[:])
if n > 0 {
if service.send(protocol.MessageHTTPBodySegment {
ID: id,
Data: buffer[:n],
}) != nil { return }
}
if err != nil {
if err != io.EOF {
// TODO: send the service an error
}
break
}
}
service.send(protocol.MessageHTTPBodyEnd { ID: id })
}
func (service *Service) authenticate () (func(), bool) {
// get login details
message, err := protocol.ReadMessage(service.connReadWriter)
if err != nil {
service.Close()
return nil, false
}
login, ok := message.(protocol.MessageLogin)
if !ok {
log.Println (
"ERR", service.conn.RemoteAddr(),
"sent wrong type code, expected",
protocol.TypeLogin)
service.send(protocol.MessageStatus {
Status: protocol.StatusBadMessageType,
})
service.Close()
return nil, false
}
// create pattern
service.pattern = Pattern {
Scheme: login.Scheme,
Host: login.Host,
Path: login.Path,
}
service.pattern.FillDefaults()
// validate credentials
status := service.validate(login.User, login.Key, service.pattern)
service.send(protocol.MessageStatus { Status: status })
if status == protocol.StatusOk {
service.user = login.User
service.name = login.Name
service.description = login.Description
log.Println (
"-->", service.conn.RemoteAddr(),
"logged in as", login.User, "on", service.pattern)
} else {
log.Println (
"ERR", service.conn.RemoteAddr(),
"failed login as", login.User, "on", service.pattern)
service.conn.Close()
return nil, false
}
// mount service on the mux.
var unhandle func ()
muxPattern := service.pattern.MuxPattern()
switch service.pattern.Scheme {
case "http", "https":
err = service.router.HTTPMux().Handle(muxPattern, service)
unhandle = func () {
service.router.HTTPMux().Unhandle(muxPattern)
}
default:
err = protocol.StatusBadMount
}
if err != nil {
log.Println (
"ERR", service.name,
"sent bad mount pattern:", service.pattern)
service.send(protocol.MessageStatus {
Status: protocol.StatusBadMount,
})
service.Close()
return nil, false
}
// list the service in the router
service.router.list(service)
return func () {
unhandle()
service.router.unlist(service)
}, true
}
func (service *Service) send (message protocol.Message) (err error) {
service.connLock.Lock()
defer service.connLock.Unlock()
err = message.Send(service.connReadWriter)
if err != nil { return }
return service.connReadWriter.Flush()
}