256 lines
6.4 KiB
Go
256 lines
6.4 KiB
Go
// Package mux provides request multiplexers for all protocols Hnakra supports.
|
|
package mux
|
|
|
|
import "net"
|
|
import "sync"
|
|
import "path"
|
|
import "sort"
|
|
import "errors"
|
|
import "strings"
|
|
import "net/url"
|
|
|
|
// Resolver represents an object capable of transforming a hostname alias into
// another hostname.
type Resolver interface {
	// ResolveAlias returns the hostname that alias stands for. Mux.Handler
	// uses the result as the host part of the key it matches patterns
	// against.
	ResolveAlias(alias string) string
}
|
|
|
|
// Mux is a generic request multiplexer based off of http.ServeMux. It can be
|
|
// adapted to any protocol where requests to a URL are handled.
|
|
type Mux[HANDLER any] struct {
|
|
// These must be non-nil!
|
|
Redirect func (*url.URL) HANDLER
|
|
NotFound func (*url.URL) HANDLER
|
|
Resolver Resolver
|
|
|
|
mutex sync.RWMutex
|
|
exactEntries map[string]muxEntry[HANDLER]
|
|
sortedEntries []muxEntry[HANDLER] // slice of entries sorted from longest to shortest.
|
|
}
|
|
|
|
// muxEntry pairs a registered handler with the pattern it was registered on.
type muxEntry[HANDLER any] struct {
	handler HANDLER // value returned to callers when the pattern matches
	pattern string  // pattern exactly as passed to Handle
}
|
|
|
|
// cleanPath returns the canonical form of p. The path is made absolute, and
// path.Clean removes "." and ".." elements and duplicate slashes. A trailing
// slash present on the input is kept, except that the root stays "/".
func cleanPath(p string) string {
	switch {
	case p == "":
		return "/"
	case !strings.HasPrefix(p, "/"):
		p = "/" + p
	}

	cleaned := path.Clean(p)
	if !strings.HasSuffix(p, "/") || cleaned == "/" {
		return cleaned
	}

	// path.Clean dropped the trailing slash. When p is already exactly
	// cleaned+"/", return it as-is instead of building a new string.
	if len(p) == len(cleaned)+1 && strings.HasPrefix(p, cleaned) {
		return p
	}
	return cleaned + "/"
}
|
|
|
|
// stripHostPort removes a ":port" suffix from h. Input that contains no colon,
// or that net.SplitHostPort cannot parse, is returned unchanged.
func stripHostPort(h string) string {
	if strings.Contains(h, ":") {
		if host, _, err := net.SplitHostPort(h); err == nil {
			return host
		}
	}
	return h
}
|
|
|
|
// Handler returns the handler for a particular URL.
|
|
func (mux *Mux[HANDLER]) Handler (where *url.URL) (h HANDLER, pattern string) {
|
|
// All other requests have any port stripped and path cleaned
|
|
// before passing to mux.handler.
|
|
host := stripHostPort(where.Host)
|
|
path := cleanPath(where.Path)
|
|
host = mux.Resolver.ResolveAlias(host)
|
|
pat := host + path
|
|
|
|
// If the given path is /tree and its handler is not registered,
|
|
// redirect for /tree/.
|
|
u, shouldRedirect := mux.redirectToPathSlash(host, path, where)
|
|
if shouldRedirect {
|
|
return mux.Redirect(u), u.Path
|
|
}
|
|
|
|
if path != where.Path {
|
|
_, pattern = mux.match(pat, where)
|
|
u := &url.URL{Path: path, RawQuery: where.RawQuery}
|
|
return mux.Redirect(u), pattern
|
|
}
|
|
|
|
return mux.match(pat, where)
|
|
}
|
|
|
|
func (mux *Mux[HANDLER]) redirectToPathSlash (host, path string, u *url.URL) (*url.URL, bool) {
|
|
mux.mutex.RLock()
|
|
shouldRedirect := mux.shouldRedirectRLocked(host, path)
|
|
mux.mutex.RUnlock()
|
|
if !shouldRedirect {
|
|
return u, false
|
|
}
|
|
path = path + "/"
|
|
u = &url.URL{Path: path, RawQuery: u.RawQuery}
|
|
return u, true
|
|
}
|
|
|
|
func (mux *Mux[HANDLER]) shouldRedirectRLocked (host, path string) bool {
|
|
p := []string{path, host + path}
|
|
|
|
for _, c := range p {
|
|
if _, exist := mux.exactEntries[c]; exist {
|
|
return false
|
|
}
|
|
}
|
|
|
|
n := len(path)
|
|
if n == 0 {
|
|
return false
|
|
}
|
|
for _, c := range p {
|
|
if _, exist := mux.exactEntries[c+"/"]; exist {
|
|
return path[n-1] != '/'
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (mux *Mux[HANDLER]) match (path string, original *url.URL) (h HANDLER, pattern string) {
|
|
// Check for exact match first.
|
|
entry, matchExists := mux.exactEntries[path]
|
|
if matchExists {
|
|
return entry.handler, entry.pattern
|
|
}
|
|
|
|
// Check for longest valid match. mux.es contains all patterns
|
|
// that end in / sorted from longest to shortest.
|
|
for _, entry := range mux.sortedEntries {
|
|
if strings.HasPrefix(path, entry.pattern) {
|
|
return entry.handler, entry.pattern
|
|
}
|
|
}
|
|
|
|
return mux.NotFound(original), ""
|
|
}
|
|
|
|
// Handle registers a handler on the specified pattern. If a pattern ends in
|
|
// '/', all requests for URLS under the pattern will be directed to the handler,
|
|
// as well as the pattern itself. Additionally, requests for the pattern without
|
|
// the trailing slash will be redirected to the pattern with the trailing slash.
|
|
func (mux *Mux[HANDLER]) Handle (pattern string, handler HANDLER) error {
|
|
mux.mutex.Lock()
|
|
defer mux.mutex.Unlock()
|
|
|
|
if pattern == "" {
|
|
return errors.New(
|
|
"mux: invalid pattern " + pattern +
|
|
", cannot be empty.")
|
|
}
|
|
|
|
if pattern[0] == '/' {
|
|
return errors.New(
|
|
"mux: invalid pattern " + pattern +
|
|
", must be host specific.")
|
|
}
|
|
|
|
if _, exist := mux.exactEntries[pattern]; exist {
|
|
return errors.New("mux: existing mount on " + pattern)
|
|
}
|
|
|
|
if mux.exactEntries == nil {
|
|
mux.exactEntries = make(map[string] muxEntry[HANDLER])
|
|
}
|
|
|
|
entry := muxEntry[HANDLER] { handler: handler, pattern: pattern }
|
|
mux.exactEntries[pattern] = entry
|
|
if pattern[len(pattern)-1] == '/' {
|
|
mux.sortedEntries = appendSorted(mux.sortedEntries, entry)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Unhandler removes the handler that was registered on the specified pattern.
|
|
func (mux *Mux[HANDLER]) Unhandle (pattern string) error {
|
|
mux.mutex.Lock()
|
|
defer mux.mutex.Unlock()
|
|
|
|
// delete from exact match list
|
|
if _, registered := mux.exactEntries[pattern]; !registered {
|
|
return errors.New(
|
|
"mux: pattern " + pattern + " is not mounted")
|
|
}
|
|
delete(mux.exactEntries, pattern)
|
|
|
|
// delete from sorted list, if its in there.
|
|
newLen := 0
|
|
for index, entry := range mux.sortedEntries {
|
|
if entry.pattern != pattern {
|
|
mux.sortedEntries[newLen] = mux.sortedEntries[index]
|
|
newLen++
|
|
}
|
|
}
|
|
mux.sortedEntries = mux.sortedEntries[:newLen]
|
|
return nil
|
|
}
|
|
|
|
// OverHandlers calls a function for each registered handler.
|
|
func (mux *Mux[HANDLER]) OverHandlers (callback func (pattern string, handler HANDLER) bool) {
|
|
overSorted (mux.exactEntries, func (pattern string, entry muxEntry[HANDLER]) bool {
|
|
return callback(pattern, entry.handler)
|
|
})
|
|
}
|
|
|
|
func appendSorted[HANDLER any] (entries []muxEntry[HANDLER], entry muxEntry[HANDLER]) []muxEntry[HANDLER] {
|
|
entriesLen := len(entries)
|
|
index := sort.Search(entriesLen, func(index int) bool {
|
|
return len(entries[index].pattern) < len(entry.pattern)
|
|
})
|
|
|
|
if index == entriesLen {
|
|
return append(entries, entry)
|
|
}
|
|
|
|
// we now know that i points at where we want to insert
|
|
// try to grow the slice in place, any entry works
|
|
entries = append(entries, muxEntry[HANDLER] { })
|
|
copy(entries[index+1:], entries[index:]) // Move shorter entries down
|
|
entries[index] = entry
|
|
return entries
|
|
}
|
|
|
|
func overSorted[V any] (unsorted map[string] V, fn func (string, V) bool) {
|
|
keys := sortKeys(unsorted)
|
|
for _, key := range keys {
|
|
if !fn(key, unsorted[key]) {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// sortKeys returns a newly allocated slice holding the keys of unsorted, in
// ascending order as defined by sort.Strings.
func sortKeys[V any](unsorted map[string]V) []string {
	keys := make([]string, 0, len(unsorted))
	for key := range unsorted {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	return keys
}
|