diff --git a/examples/auth/auth.go b/examples/auth/auth.go index c126de8..94b3016 100644 --- a/examples/auth/auth.go +++ b/examples/auth/auth.go @@ -46,11 +46,11 @@ func main() { handler := &gmi.ServeMux{} handler.HandleFunc("/", welcome) - handler.HandleFunc("/login/", login) - handler.HandleFunc("/login/password/", loginPassword) - handler.HandleFunc("/profile/", profile) - handler.HandleFunc("/admin/", admin) - handler.HandleFunc("/logout/", logout) + handler.HandleFunc("/login", login) + handler.HandleFunc("/login/password", loginPassword) + handler.HandleFunc("/profile", profile) + handler.HandleFunc("/admin", admin) + handler.HandleFunc("/logout", logout) server := &gmi.Server{ Certificate: cert, diff --git a/gemini_test.go b/gemini_test.go deleted file mode 100644 index 9bba3f7..0000000 --- a/gemini_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package gmi - -import ( - "math/rand" - "testing" - "time" -) - -func TestServeMuxEntryOrder(t *testing.T) { - expected := []string{ - "https://example.com/longpath", - "https://example.com/path", - "https://example.com", - "http://example.com/longpath", - "http://example.com/path", - "http://example.com", - "example.com/longpath", - "example.com/path", - "example.com", - "/longpath", - "/path", - "/", - } - - // Shuffle input - a := make([]string, len(expected)) - copy(a, expected) - rand.Seed(time.Now().UnixNano()) - rand.Shuffle(len(a), func(i, j int) { a[i], a[j] = a[j], a[i] }) - - mux := &ServeMux{} - for _, s := range a { - mux.Handle(s, NotFoundHandler()) - var es string - for i := range mux.es { - es += mux.es[i].u.String() + " " - } - t.Logf(es) - } - for i, e := range mux.es { - s := e.u.String() - if s != expected[i] { - t.Errorf("wrong order of mux entries: expected %s, got %s", expected[i], s) - } - } -} diff --git a/server.go b/server.go index 1874016..cea6f64 100644 --- a/server.go +++ b/server.go @@ -454,15 +454,15 @@ func isSlashRune(r rune) bool { return r == '/' || r == '\\' } // header, stripping the port number and redirecting any request containing . or // .. elements or repeated slashes to an equivalent, cleaner URL. type ServeMux struct { - mu sync.RWMutex - m map[string]muxEntry - es []muxEntry // slice of entries sorted from longest to shortest. + mu sync.RWMutex + m map[string]muxEntry + es []muxEntry // slice of entries sorted from longest to shortest. + hosts bool // whether any patterns contain hostnames } type muxEntry struct { h Handler pattern string - u *url.URL } // NewServeMux allocates and returns a new ServeMux. @@ -505,19 +505,17 @@ func stripHostPort(h string) string { // Find a handler on a handler map given a path string. // Most-specific (longest) pattern wins. -func (mux *ServeMux) match(url *url.URL) (h Handler, pattern string) { +func (mux *ServeMux) match(path string) (h Handler, pattern string) { // Check for exact match first. - v, ok := mux.m[url.String()] + v, ok := mux.m[path] if ok { return v.h, v.pattern } - // Check for longest valid match. mux.es contains all patterns + // Check for longest valid match. mux.es contains all patterns // that end in / sorted from longest to shortest. for _, e := range mux.es { - if (e.u.Scheme == "" || url.Scheme == e.u.Scheme) && - (e.u.Host == "" || url.Host == e.u.Host) && - strings.HasPrefix(url.Path, e.u.Path) { + if strings.HasPrefix(path, e.pattern) { return e.h, e.pattern } } @@ -572,6 +570,8 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { // to the canonical path. If the host contains a port, it is ignored // when matching handlers. // +// The path and host are used unchanged for CONNECT requests. +// // Handler also returns the registered pattern that matches the // request or, in the case of internally-generated redirects, // the pattern that will match after following the redirect. @@ -579,34 +579,40 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool { // If there is no registered handler that applies to the request, // Handler returns a ``page not found'' handler and an empty pattern. func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { - // All requests have any port stripped and path cleaned + // All other requests have any port stripped and path cleaned // before passing to mux.handler. - url := *r.URL - url.Host = stripHostPort(r.Host) - url.Path = cleanPath(r.URL.Path) + host := stripHostPort(r.Host) + path := cleanPath(r.URL.Path) // If the given path is /tree and its handler is not registered, // redirect for /tree/. - if u, ok := mux.redirectToPathSlash(url.Host, url.Path, r.URL); ok { - return PermanentRedirectHandler(u.String()), u.Path + if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok { + return RedirectHandler(u.String()), u.Path } - if url.Path != r.URL.Path { - _, pattern = mux.handler(&url) - red := *r.URL - red.Path = url.Path - return PermanentRedirectHandler(red.String()), pattern + if path != r.URL.Path { + _, pattern = mux.handler(host, path) + url := *r.URL + url.Path = path + return RedirectHandler(url.String()), pattern } - return mux.handler(&url) + return mux.handler(host, r.URL.Path) } // handler is the main implementation of Handler. -func (mux *ServeMux) handler(url *url.URL) (h Handler, pattern string) { +// The path is known to be in canonical form, except for CONNECT methods. +func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) { mux.mu.RLock() defer mux.mu.RUnlock() - h, pattern = mux.match(url) + // Host-specific pattern takes precedence over generic ones + if mux.hosts { + h, pattern = mux.match(host + path) + } + if h == nil { + h, pattern = mux.match(path) + } if h == nil { h, pattern = NotFoundHandler(), "" } @@ -639,45 +645,21 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { if mux.m == nil { mux.m = make(map[string]muxEntry) } - url, err := url.Parse(pattern) - if err != nil { - panic("gmi: invalid pattern") - } - e := muxEntry{h: handler, pattern: pattern, u: url} + e := muxEntry{h: handler, pattern: pattern} mux.m[pattern] = e - // if pattern[len(pattern)-1] == '/' { - mux.es = appendSorted(mux.es, e) - // } + if pattern[len(pattern)-1] == '/' { + mux.es = appendSorted(mux.es, e) + } + + if pattern[0] != '/' { + mux.hosts = true + } } func appendSorted(es []muxEntry, e muxEntry) []muxEntry { n := len(es) i := sort.Search(n, func(i int) bool { - // Sort entries by length. - // - Entries with a scheme take preference over entries without. - // - Entries with a host take preference over entries without. - // - Longer paths take preference over shorter paths. - if e.u.Scheme != "" { - if es[i].u.Scheme == "" { - return true - } - if es[i].u.Scheme != e.u.Scheme { - return len(es[i].u.Scheme) < len(e.u.Scheme) - } - } else if es[i].u.Scheme != "" { - return false - } - if e.u.Host != "" { - if es[i].u.Host == "" { - return true - } - if es[i].u.Host != e.u.Host { - return len(es[i].u.Scheme) < len(e.u.Scheme) - } - } else if es[i].u.Host != "" { - return false - } - return len(es[i].u.Path) < len(e.u.Path) + return len(es[i].pattern) < len(e.pattern) }) if i == n { return append(es, e)