From 95e60f477ad2d72ccc67af35584000f83a2575e2 Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Sat, 14 Dec 2024 16:17:57 -0500 Subject: [PATCH] Fix handling of forwarded requests --- http/handler.go | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/http/handler.go b/http/handler.go index 41dfa55..ae02a44 100644 --- a/http/handler.go +++ b/http/handler.go @@ -49,16 +49,15 @@ func (this *Handler) Init (ctx context.Context) error { } func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { - remoteAddr := req.RemoteAddr remoteAddrStr := req.RemoteAddr if addr := req.Header.Get("CF-Connecting-IP"); addr != "" { if this.TrustCFConnectingIP { - remoteAddr = addr + req.RemoteAddr = addr remoteAddrStr = fmt.Sprintf("%s --CF-> %s", addr, req.RemoteAddr) } } else if addr := req.Header.Get("X-Forwarded-For"); addr != "" { if this.TrustXForwardedFor { - remoteAddr = addr + req.RemoteAddr = addr remoteAddrStr = fmt.Sprintf("%s --XF-> %s", addr, req.RemoteAddr) } } @@ -66,7 +65,7 @@ func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { filesystem := this.Environment.GetFS() // rate limit - err := this.rateLimit(req, remoteAddr, "", this.RateLimit) + err := this.rateLimit(req, "", this.RateLimit) if err != nil { this.serveError(res, req, http.StatusTooManyRequests, req.URL, false) return @@ -106,7 +105,7 @@ func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { info, err := statFile(filesystem, pathToName(currentPath)) if err != nil { continue } if info.IsDir() { continue } - this.serveFile(res, req, currentPath, remoteAddr) + this.serveFile(res, req, currentPath) return } @@ -115,11 +114,11 @@ func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { return } - this.serveDirectory(res, req, pat, remoteAddr) + this.serveDirectory(res, req, pat) return } - this.serveFile(res, req, pat, remoteAddr) + this.serveFile(res, req, pat) } func (this *Handler) Trim () { @@ -146,7 +145,6 @@ func (this *Handler) serveFile ( res http.ResponseWriter, req *http.Request, pat string, - remoteAddr string, ) { name := pathToName(pat) if !this.StepExt.Has(filepath.Ext(name)) { @@ -154,14 +152,13 @@ func (this *Handler) serveFile ( http.ServeFileFS(res, req, this.Environment.GetFS(), name) return } - this.serveDocument(res, req, name, remoteAddr) + this.serveDocument(res, req, name) } func (this *Handler) serveDirectory ( res http.ResponseWriter, req *http.Request, pat string, - remoteAddr string, ) { name := pathToName(pat) entries, err := readDir(this.Environment.GetFS(), name) @@ -170,7 +167,7 @@ func (this *Handler) serveDirectory ( return } if this.DirectoryDocument == "" { - this.serveFile(res, req, pat, remoteAddr) + this.serveFile(res, req, pat) return } document, err := this.Environment.Load(this.DirectoryDocument) @@ -194,7 +191,6 @@ func (this *Handler) serveDocument ( res http.ResponseWriter, req *http.Request, name string, - remoteAddr string, ) { // parse document, err := this.Environment.Load(name) @@ -205,7 +201,7 @@ func (this *Handler) serveDocument ( // rate limit (again) if rateLimit, err := strconv.ParseFloat(document.Meta.Get("rate-limit"), 64); err == nil { - err := this.rateLimit(req, remoteAddr, name, time.Duration(rateLimit * float64(time.Second))) + err := this.rateLimit(req, name, time.Duration(rateLimit * float64(time.Second))) if err != nil { this.serveError(res, req, http.StatusTooManyRequests, req.URL, false) return @@ -308,8 +304,8 @@ func (this *Handler) logErr (name string, err error) { log.Printf("ERR %s: %v\n", name, err) } -func (this *Handler) rateLimit (req *http.Request, remoteAddr string, resource string, duration time.Duration) error { - host, _, _ := net.SplitHostPort(remoteAddr) +func (this *Handler) rateLimit (req *http.Request, resource string, duration time.Duration) error { + host, _, _ := net.SplitHostPort(req.RemoteAddr) key := fmt.Sprintf("%s|%s", host, resource) table, done := this.rateLimitMap.Borrow() defer done()