hnakra/router/mux/mux.go

256 lines
6.4 KiB
Go
Raw Normal View History

// Package mux provides request multiplexers for all protocols Hnakra supports.
2023-05-25 16:08:56 -06:00
package mux
import "net"
import "sync"
import "path"
import "sort"
import "errors"
import "strings"
import "net/url"
// Resolver represents an object capable of transforming a hostname alias into
// another hostname.
2023-05-25 16:08:56 -06:00
type Resolver interface {
	// ResolveAlias returns the hostname that alias stands for. Mux.Handler
	// calls it for every request host (after stripping any port), so
	// implementations presumably return non-alias hosts unchanged — confirm
	// against the implementations in use.
	ResolveAlias(alias string) string
}
// Mux is a generic request multiplexer modeled on http.ServeMux. It can be
// adapted to any protocol in which requests are addressed by URL; HANDLER is
// whatever value that protocol uses to serve a request.
//
// Redirect, NotFound and Resolver are used by Handler and must all be set
// (non-nil) before Handler is called.
type Mux[HANDLER any] struct {
	// Redirect is called with the URL a client should be sent to and returns
	// the handler used for that request. Must be non-nil.
	Redirect func(*url.URL) HANDLER

	// NotFound is called with the original request URL when no pattern
	// matches and returns the handler used for that request. Must be non-nil.
	NotFound func(*url.URL) HANDLER

	// Resolver maps the request host (port removed) onto the hostname used
	// for pattern lookup. Must be non-nil.
	Resolver Resolver

	// mutex guards exactEntries and sortedEntries.
	mutex sync.RWMutex

	// exactEntries maps every registered pattern to its entry.
	exactEntries map[string]muxEntry[HANDLER]

	// sortedEntries holds only the entries whose pattern ends in '/',
	// ordered from longest pattern to shortest for prefix matching.
	sortedEntries []muxEntry[HANDLER]
}
// muxEntry pairs a registered handler with the pattern it was registered on.
type muxEntry[HANDLER any] struct {
	// handler serves requests that match pattern.
	handler HANDLER
	// pattern is the host+path string passed to Handle.
	pattern string
}
// cleanPath returns the canonical form of the URL path p. The result is
// rooted at "/", has "." and ".." elements and repeated slashes resolved by
// path.Clean, and keeps a trailing slash if p had one. An empty p yields "/".
func cleanPath(p string) string {
	if p == "" {
		return "/"
	}
	if !strings.HasPrefix(p, "/") {
		p = "/" + p
	}
	cleaned := path.Clean(p)

	// path.Clean drops a trailing slash (except for the root), but the
	// trailing slash is significant for subtree patterns, so restore it.
	if cleaned == "/" || !strings.HasSuffix(p, "/") {
		return cleaned
	}
	if len(p) == len(cleaned)+1 && strings.HasPrefix(p, cleaned) {
		// p already equals cleaned+"/"; reuse it instead of concatenating.
		return p
	}
	return cleaned + "/"
}
// stripHostPort returns h with any ":port" suffix removed. Bracketed IPv6
// literals such as "[::1]:443" lose their brackets. Input that
// net.SplitHostPort cannot parse (including bare IPv6 addresses) is returned
// unchanged.
func stripHostPort(h string) string {
	if strings.Contains(h, ":") {
		if host, _, err := net.SplitHostPort(h); err == nil {
			return host
		}
	}
	return h
}
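// Handler returns the handler for a particular URL along with the pattern it
// matched. If the URL's path is not clean, or lacks only a trailing slash that
// a registered pattern has, the returned handler comes from Redirect instead.
// If nothing matches, the handler comes from NotFound and the pattern is empty.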
2023-05-25 16:08:56 -06:00
func (mux *Mux[HANDLER]) Handler(where *url.URL) (h HANDLER, pattern string) {
	// Normalize the request: drop any port, canonicalize the path, then map
	// host aliases onto the hostname patterns are registered under.
	bareHost := stripHostPort(where.Host)
	cleaned := cleanPath(where.Path)
	host := mux.Resolver.ResolveAlias(bareHost)
	key := host + cleaned

	// A registered "/tree/" also claims "/tree" by redirecting to "/tree/".
	if target, redirect := mux.redirectToPathSlash(host, cleaned, where); redirect {
		return mux.Redirect(target), target.Path
	}

	// Unclean paths are sent to their canonical form; the pattern reported
	// is the one the canonical path matches.
	if cleaned != where.Path {
		_, matched := mux.match(key, where)
		target := &url.URL{Path: cleaned, RawQuery: where.RawQuery}
		return mux.Redirect(target), matched
	}

	return mux.match(key, where)
}
// redirectToPathSlash reports whether a request for host+path should be
// redirected to the same path with a trailing slash appended. When it should,
// it returns the redirect target: a new URL carrying the slashed path and u's
// raw query. Otherwise it returns u itself and false.
func (mux *Mux[HANDLER]) redirectToPathSlash(host, path string, u *url.URL) (*url.URL, bool) {
	mux.mutex.RLock()
	redirect := mux.shouldRedirectRLocked(host, path)
	mux.mutex.RUnlock()

	if redirect {
		return &url.URL{Path: path + "/", RawQuery: u.RawQuery}, true
	}
	return u, false
}
// shouldRedirectRLocked reports whether path (or host+path) is unregistered
// while the same key plus a trailing slash is registered, and path does not
// already end in '/'. The caller must hold at least mux.mutex.RLock.
func (mux *Mux[HANDLER]) shouldRedirectRLocked(host, path string) bool {
	candidates := [2]string{path, host + path}

	// An exact registration always wins over a redirect.
	for _, candidate := range candidates {
		if _, ok := mux.exactEntries[candidate]; ok {
			return false
		}
	}

	if len(path) == 0 {
		return false
	}
	for _, candidate := range candidates {
		if _, ok := mux.exactEntries[candidate+"/"]; ok {
			return path[len(path)-1] != '/'
		}
	}
	return false
}
// match finds the handler for path, a host+path lookup key. An exact
// registration wins. Otherwise the longest registered pattern ending in '/'
// that is a prefix of path is used. When nothing matches, it returns
// mux.NotFound(original) and an empty pattern.
//
// match acquires mux.mutex for reading itself, so callers must NOT already
// hold the lock. (Handler calls match without the lock, which previously let
// lookups race with Handle/Unhandle.) NotFound is invoked only after the lock
// is released, so it may safely register or remove handlers.
func (mux *Mux[HANDLER]) match(path string, original *url.URL) (h HANDLER, pattern string) {
	mux.mutex.RLock()
	entry, found := mux.exactEntries[path]
	if !found {
		// sortedEntries holds every pattern ending in '/', longest first,
		// so the first prefix hit is the most specific one.
		for _, candidate := range mux.sortedEntries {
			if strings.HasPrefix(path, candidate.pattern) {
				entry, found = candidate, true
				break
			}
		}
	}
	mux.mutex.RUnlock()

	if found {
		return entry.handler, entry.pattern
	}
	return mux.NotFound(original), ""
}
// Handle registers a handler on the specified pattern. Patterns must be
// host-specific: they may not be empty or begin with '/'. If a pattern ends in
// '/', all requests for URLs under the pattern will be directed to the handler,
// as well as the pattern itself. Additionally, requests for the pattern without
// the trailing slash will be redirected to the pattern with the trailing slash.
// An error is returned if the pattern is invalid or already registered.
2023-05-25 16:08:56 -06:00
func (mux *Mux[HANDLER]) Handle(pattern string, handler HANDLER) error {
	mux.mutex.Lock()
	defer mux.mutex.Unlock()

	// Patterns must be non-empty and must name a host.
	switch {
	case pattern == "":
		return errors.New("mux: invalid pattern " + pattern + ", cannot be empty.")
	case pattern[0] == '/':
		return errors.New("mux: invalid pattern " + pattern + ", must be host specific.")
	}
	if _, taken := mux.exactEntries[pattern]; taken {
		return errors.New("mux: existing mount on " + pattern)
	}

	// The zero Mux has no map yet; create it on first registration.
	if mux.exactEntries == nil {
		mux.exactEntries = make(map[string]muxEntry[HANDLER])
	}
	entry := muxEntry[HANDLER]{handler: handler, pattern: pattern}
	mux.exactEntries[pattern] = entry

	// Subtree patterns additionally take part in longest-prefix matching.
	if strings.HasSuffix(pattern, "/") {
		mux.sortedEntries = appendSorted(mux.sortedEntries, entry)
	}
	return nil
}
// Unhandle removes the handler that was registered on the specified pattern.
// It returns an error if no handler is registered on that pattern.
2023-05-25 16:08:56 -06:00
func (mux *Mux[HANDLER]) Unhandle(pattern string) error {
	mux.mutex.Lock()
	defer mux.mutex.Unlock()

	// Remove from the exact-match table; this is the registration record.
	if _, registered := mux.exactEntries[pattern]; !registered {
		return errors.New(
			"mux: pattern " + pattern + " is not mounted")
	}
	delete(mux.exactEntries, pattern)

	// Compact sortedEntries in place (it only holds patterns ending in '/'),
	// preserving the longest-first order that match relies on.
	kept := 0
	for _, entry := range mux.sortedEntries {
		if entry.pattern != pattern {
			mux.sortedEntries[kept] = entry
			kept++
		}
	}
	// Zero the vacated tail. Otherwise the backing array keeps referencing the
	// removed handler until a later append overwrites that slot.
	for i := kept; i < len(mux.sortedEntries); i++ {
		mux.sortedEntries[i] = muxEntry[HANDLER]{}
	}
	mux.sortedEntries = mux.sortedEntries[:kept]
	return nil
}
// OverHandlers calls callback once for each registered pattern and its
// handler, in lexical order of pattern, stopping early if callback returns
// false.
2023-05-25 16:08:56 -06:00
func (mux *Mux[HANDLER]) OverHandlers(callback func(pattern string, handler HANDLER) bool) {
	// Snapshot the registrations under the read lock. Previously the map was
	// iterated with no lock at all, racing with Handle/Unhandle.
	mux.mutex.RLock()
	snapshot := make([]muxEntry[HANDLER], 0, len(mux.exactEntries))
	overSorted(mux.exactEntries, func(pattern string, entry muxEntry[HANDLER]) bool {
		snapshot = append(snapshot, muxEntry[HANDLER]{handler: entry.handler, pattern: pattern})
		return true
	})
	mux.mutex.RUnlock()

	// Run the callback outside the lock so it may call Handle or Unhandle
	// without deadlocking.
	for _, entry := range snapshot {
		if !callback(entry.pattern, entry.handler) {
			break
		}
	}
}
// appendSorted inserts entry into entries, which must be ordered from longest
// pattern to shortest, and returns the resulting slice, which keeps that
// order. Entries whose pattern length equals entry's stay ahead of it.
func appendSorted[HANDLER any](entries []muxEntry[HANDLER], entry muxEntry[HANDLER]) []muxEntry[HANDLER] {
	// First position holding a strictly shorter pattern (len(entries) if none).
	at := sort.Search(len(entries), func(i int) bool {
		return len(entries[i].pattern) < len(entry.pattern)
	})

	// Grow by one placeholder (in place when capacity allows), shift the
	// shorter tail right, and drop entry into the gap. When at is the end,
	// the copy moves nothing and this is a plain append.
	entries = append(entries, muxEntry[HANDLER]{})
	copy(entries[at+1:], entries[at:])
	entries[at] = entry
	return entries
}
// overSorted calls fn for each key/value pair of unsorted in ascending key
// order, stopping as soon as fn returns false.
func overSorted[V any](unsorted map[string]V, fn func(string, V) bool) {
	for _, key := range sortKeys(unsorted) {
		if !fn(key, unsorted[key]) {
			return
		}
	}
}
// sortKeys returns the keys of unsorted in ascending order. An empty map
// yields an empty, non-nil slice.
func sortKeys[V any](unsorted map[string]V) []string {
	keys := make([]string, 0, len(unsorted))
	for key := range unsorted {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}