Remove support for matching schemes

This commit is contained in:
adnano 2020-09-28 16:02:32 -04:00
parent 5a07b49ef5
commit a49b0cdaac
3 changed files with 44 additions and 108 deletions

View File

@ -46,11 +46,11 @@ func main() {
handler := &gmi.ServeMux{} handler := &gmi.ServeMux{}
handler.HandleFunc("/", welcome) handler.HandleFunc("/", welcome)
handler.HandleFunc("/login/", login) handler.HandleFunc("/login", login)
handler.HandleFunc("/login/password/", loginPassword) handler.HandleFunc("/login/password", loginPassword)
handler.HandleFunc("/profile/", profile) handler.HandleFunc("/profile", profile)
handler.HandleFunc("/admin/", admin) handler.HandleFunc("/admin", admin)
handler.HandleFunc("/logout/", logout) handler.HandleFunc("/logout", logout)
server := &gmi.Server{ server := &gmi.Server{
Certificate: cert, Certificate: cert,

View File

@ -1,46 +0,0 @@
package gmi
import (
"math/rand"
"testing"
"time"
)
func TestServeMuxEntryOrder(t *testing.T) {
expected := []string{
"https://example.com/longpath",
"https://example.com/path",
"https://example.com",
"http://example.com/longpath",
"http://example.com/path",
"http://example.com",
"example.com/longpath",
"example.com/path",
"example.com",
"/longpath",
"/path",
"/",
}
// Shuffle input
a := make([]string, len(expected))
copy(a, expected)
rand.Seed(time.Now().UnixNano())
rand.Shuffle(len(a), func(i, j int) { a[i], a[j] = a[j], a[i] })
mux := &ServeMux{}
for _, s := range a {
mux.Handle(s, NotFoundHandler())
var es string
for i := range mux.es {
es += mux.es[i].u.String() + " "
}
t.Logf(es)
}
for i, e := range mux.es {
s := e.u.String()
if s != expected[i] {
t.Errorf("wrong order of mux entries: expected %s, got %s", expected[i], s)
}
}
}

View File

@ -454,15 +454,15 @@ func isSlashRune(r rune) bool { return r == '/' || r == '\\' }
// header, stripping the port number and redirecting any request containing . or // header, stripping the port number and redirecting any request containing . or
// .. elements or repeated slashes to an equivalent, cleaner URL. // .. elements or repeated slashes to an equivalent, cleaner URL.
type ServeMux struct { type ServeMux struct {
mu sync.RWMutex mu sync.RWMutex
m map[string]muxEntry m map[string]muxEntry
es []muxEntry // slice of entries sorted from longest to shortest. es []muxEntry // slice of entries sorted from longest to shortest.
hosts bool // whether any patterns contain hostnames
} }
type muxEntry struct { type muxEntry struct {
h Handler h Handler
pattern string pattern string
u *url.URL
} }
// NewServeMux allocates and returns a new ServeMux. // NewServeMux allocates and returns a new ServeMux.
@ -505,19 +505,17 @@ func stripHostPort(h string) string {
// Find a handler on a handler map given a path string. // Find a handler on a handler map given a path string.
// Most-specific (longest) pattern wins. // Most-specific (longest) pattern wins.
func (mux *ServeMux) match(url *url.URL) (h Handler, pattern string) { func (mux *ServeMux) match(path string) (h Handler, pattern string) {
// Check for exact match first. // Check for exact match first.
v, ok := mux.m[url.String()] v, ok := mux.m[path]
if ok { if ok {
return v.h, v.pattern return v.h, v.pattern
} }
// Check for longest valid match. mux.es contains all patterns // Check for longest valid match. mux.es contains all patterns
// that end in / sorted from longest to shortest. // that end in / sorted from longest to shortest.
for _, e := range mux.es { for _, e := range mux.es {
if (e.u.Scheme == "" || url.Scheme == e.u.Scheme) && if strings.HasPrefix(path, e.pattern) {
(e.u.Host == "" || url.Host == e.u.Host) &&
strings.HasPrefix(url.Path, e.u.Path) {
return e.h, e.pattern return e.h, e.pattern
} }
} }
@ -572,6 +570,8 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
// to the canonical path. If the host contains a port, it is ignored // to the canonical path. If the host contains a port, it is ignored
// when matching handlers. // when matching handlers.
// //
// The path and host are used unchanged for CONNECT requests.
//
// Handler also returns the registered pattern that matches the // Handler also returns the registered pattern that matches the
// request or, in the case of internally-generated redirects, // request or, in the case of internally-generated redirects,
// the pattern that will match after following the redirect. // the pattern that will match after following the redirect.
@ -579,34 +579,40 @@ func (mux *ServeMux) shouldRedirectRLocked(host, path string) bool {
// If there is no registered handler that applies to the request, // If there is no registered handler that applies to the request,
// Handler returns a ``page not found'' handler and an empty pattern. // Handler returns a ``page not found'' handler and an empty pattern.
func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) { func (mux *ServeMux) Handler(r *Request) (h Handler, pattern string) {
// All requests have any port stripped and path cleaned // All other requests have any port stripped and path cleaned
// before passing to mux.handler. // before passing to mux.handler.
url := *r.URL host := stripHostPort(r.Host)
url.Host = stripHostPort(r.Host) path := cleanPath(r.URL.Path)
url.Path = cleanPath(r.URL.Path)
// If the given path is /tree and its handler is not registered, // If the given path is /tree and its handler is not registered,
// redirect for /tree/. // redirect for /tree/.
if u, ok := mux.redirectToPathSlash(url.Host, url.Path, r.URL); ok { if u, ok := mux.redirectToPathSlash(host, path, r.URL); ok {
return PermanentRedirectHandler(u.String()), u.Path return RedirectHandler(u.String()), u.Path
} }
if url.Path != r.URL.Path { if path != r.URL.Path {
_, pattern = mux.handler(&url) _, pattern = mux.handler(host, path)
red := *r.URL url := *r.URL
red.Path = url.Path url.Path = path
return PermanentRedirectHandler(red.String()), pattern return RedirectHandler(url.String()), pattern
} }
return mux.handler(&url) return mux.handler(host, r.URL.Path)
} }
// handler is the main implementation of Handler. // handler is the main implementation of Handler.
func (mux *ServeMux) handler(url *url.URL) (h Handler, pattern string) { // The path is known to be in canonical form, except for CONNECT methods.
func (mux *ServeMux) handler(host, path string) (h Handler, pattern string) {
mux.mu.RLock() mux.mu.RLock()
defer mux.mu.RUnlock() defer mux.mu.RUnlock()
h, pattern = mux.match(url) // Host-specific pattern takes precedence over generic ones
if mux.hosts {
h, pattern = mux.match(host + path)
}
if h == nil {
h, pattern = mux.match(path)
}
if h == nil { if h == nil {
h, pattern = NotFoundHandler(), "" h, pattern = NotFoundHandler(), ""
} }
@ -639,45 +645,21 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) {
if mux.m == nil { if mux.m == nil {
mux.m = make(map[string]muxEntry) mux.m = make(map[string]muxEntry)
} }
url, err := url.Parse(pattern) e := muxEntry{h: handler, pattern: pattern}
if err != nil {
panic("gmi: invalid pattern")
}
e := muxEntry{h: handler, pattern: pattern, u: url}
mux.m[pattern] = e mux.m[pattern] = e
// if pattern[len(pattern)-1] == '/' { if pattern[len(pattern)-1] == '/' {
mux.es = appendSorted(mux.es, e) mux.es = appendSorted(mux.es, e)
// } }
if pattern[0] != '/' {
mux.hosts = true
}
} }
func appendSorted(es []muxEntry, e muxEntry) []muxEntry { func appendSorted(es []muxEntry, e muxEntry) []muxEntry {
n := len(es) n := len(es)
i := sort.Search(n, func(i int) bool { i := sort.Search(n, func(i int) bool {
// Sort entries by length. return len(es[i].pattern) < len(e.pattern)
// - Entries with a scheme take preference over entries without.
// - Entries with a host take preference over entries without.
// - Longer paths take preference over shorter paths.
if e.u.Scheme != "" {
if es[i].u.Scheme == "" {
return true
}
if es[i].u.Scheme != e.u.Scheme {
return len(es[i].u.Scheme) < len(e.u.Scheme)
}
} else if es[i].u.Scheme != "" {
return false
}
if e.u.Host != "" {
if es[i].u.Host == "" {
return true
}
if es[i].u.Host != e.u.Host {
return len(es[i].u.Scheme) < len(e.u.Scheme)
}
} else if es[i].u.Host != "" {
return false
}
return len(es[i].u.Path) < len(e.u.Path)
}) })
if i == n { if i == n {
return append(es, e) return append(es, e)