// Package mux provides request multiplexers for all protocols Hnakra supports. package mux import "net" import "sync" import "path" import "sort" import "errors" import "strings" import "net/url" // Resolver represents an object capable of transforming a hosname alias into // another hostname. type Resolver interface { ResolveAlias (alias string) string } // Mux is a generic request multiplexer based off of http.ServeMux. It can be // adapted to any protocol where requests to a URL are handled. type Mux[HANDLER any] struct { // These must be non-nil! Redirect func (*url.URL) HANDLER NotFound func (*url.URL) HANDLER Resolver Resolver mutex sync.RWMutex exactEntries map[string]muxEntry[HANDLER] sortedEntries []muxEntry[HANDLER] // slice of entries sorted from longest to shortest. } type muxEntry[HANDLER any] struct { handler HANDLER pattern string } func cleanPath (p string) string { if p == "" { return "/" } if p[0] != '/' { p = "/" + p } np := path.Clean(p) // path.Clean removes trailing slash except for root; // put the trailing slash back if necessary. if p[len(p)-1] == '/' && np != "/" { // Fast path for common case of p being the string we want: if len(p) == len(np)+1 && strings.HasPrefix(p, np) { np = p } else { np += "/" } } return np } func stripHostPort (h string) string { // If no port on host, return unchanged if !strings.Contains(h, ":") { return h } host, _, err := net.SplitHostPort(h) if err != nil { return h // on error, return unchanged } return host } // Handler returns the handler for a particular URL. func (mux *Mux[HANDLER]) Handler (where *url.URL) (h HANDLER, pattern string) { // All other requests have any port stripped and path cleaned // before passing to mux.handler. host := stripHostPort(where.Host) path := cleanPath(where.Path) host = mux.Resolver.ResolveAlias(host) pat := host + path // If the given path is /tree and its handler is not registered, // redirect for /tree/. u, shouldRedirect := mux.redirectToPathSlash(host, path, where) if shouldRedirect { return mux.Redirect(u), u.Path } if path != where.Path { _, pattern = mux.match(pat, where) u := &url.URL{Path: path, RawQuery: where.RawQuery} return mux.Redirect(u), pattern } return mux.match(pat, where) } func (mux *Mux[HANDLER]) redirectToPathSlash (host, path string, u *url.URL) (*url.URL, bool) { mux.mutex.RLock() shouldRedirect := mux.shouldRedirectRLocked(host, path) mux.mutex.RUnlock() if !shouldRedirect { return u, false } path = path + "/" u = &url.URL{Path: path, RawQuery: u.RawQuery} return u, true } func (mux *Mux[HANDLER]) shouldRedirectRLocked (host, path string) bool { p := []string{path, host + path} for _, c := range p { if _, exist := mux.exactEntries[c]; exist { return false } } n := len(path) if n == 0 { return false } for _, c := range p { if _, exist := mux.exactEntries[c+"/"]; exist { return path[n-1] != '/' } } return false } func (mux *Mux[HANDLER]) match (path string, original *url.URL) (h HANDLER, pattern string) { // Check for exact match first. entry, matchExists := mux.exactEntries[path] if matchExists { return entry.handler, entry.pattern } // Check for longest valid match. mux.es contains all patterns // that end in / sorted from longest to shortest. for _, entry := range mux.sortedEntries { if strings.HasPrefix(path, entry.pattern) { return entry.handler, entry.pattern } } return mux.NotFound(original), "" } // Handle registers a handler on the specified pattern. If a pattern ends in // '/', all requests for URLS under the pattern will be directed to the handler, // as well as the pattern itself. Additionally, requests for the pattern without // the trailing slash will be redirected to the pattern with the trailing slash. func (mux *Mux[HANDLER]) Handle (pattern string, handler HANDLER) error { mux.mutex.Lock() defer mux.mutex.Unlock() if pattern == "" { return errors.New( "mux: invalid pattern " + pattern + ", cannot be empty.") } if pattern[0] == '/' { return errors.New( "mux: invalid pattern " + pattern + ", must be host specific.") } if _, exist := mux.exactEntries[pattern]; exist { return errors.New("mux: existing mount on " + pattern) } if mux.exactEntries == nil { mux.exactEntries = make(map[string] muxEntry[HANDLER]) } entry := muxEntry[HANDLER] { handler: handler, pattern: pattern } mux.exactEntries[pattern] = entry if pattern[len(pattern)-1] == '/' { mux.sortedEntries = appendSorted(mux.sortedEntries, entry) } return nil } // Unhandler removes the handler that was registered on the specified pattern. func (mux *Mux[HANDLER]) Unhandle (pattern string) error { mux.mutex.Lock() defer mux.mutex.Unlock() // delete from exact match list if _, registered := mux.exactEntries[pattern]; !registered { return errors.New( "mux: pattern " + pattern + " is not mounted") } delete(mux.exactEntries, pattern) // delete from sorted list, if its in there. newLen := 0 for index, entry := range mux.sortedEntries { if entry.pattern != pattern { mux.sortedEntries[newLen] = mux.sortedEntries[index] newLen++ } } mux.sortedEntries = mux.sortedEntries[:newLen] return nil } // OverHandlers calls a function for each registered handler. func (mux *Mux[HANDLER]) OverHandlers (callback func (pattern string, handler HANDLER) bool) { overSorted (mux.exactEntries, func (pattern string, entry muxEntry[HANDLER]) bool { return callback(pattern, entry.handler) }) } func appendSorted[HANDLER any] (entries []muxEntry[HANDLER], entry muxEntry[HANDLER]) []muxEntry[HANDLER] { entriesLen := len(entries) index := sort.Search(entriesLen, func(index int) bool { return len(entries[index].pattern) < len(entry.pattern) }) if index == entriesLen { return append(entries, entry) } // we now know that i points at where we want to insert // try to grow the slice in place, any entry works entries = append(entries, muxEntry[HANDLER] { }) copy(entries[index+1:], entries[index:]) // Move shorter entries down entries[index] = entry return entries } func overSorted[V any] (unsorted map[string] V, fn func (string, V) bool) { keys := sortKeys(unsorted) for _, key := range keys { if !fn(key, unsorted[key]) { break } } } func sortKeys[V any] (unsorted map[string] V) []string { sortedKeys := make([]string, len(unsorted)) index := 0 for key, _ := range unsorted { sortedKeys[index] = key index ++ } sort.Strings(sortedKeys) return sortedKeys }