From 67d4f8a10a2cb59684a123bd7005975be1f6a397 Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Thu, 12 Dec 2024 03:10:30 -0500 Subject: [PATCH] Add support for HTTP rate limiting Closes #16 --- cmd/stepd/main.go | 56 ++++++++++++++++++---- http/handler.go | 116 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 148 insertions(+), 24 deletions(-) diff --git a/cmd/stepd/main.go b/cmd/stepd/main.go index 149b2cd..4ddbc5b 100644 --- a/cmd/stepd/main.go +++ b/cmd/stepd/main.go @@ -51,6 +51,10 @@ func main () { 0, "http-directory-document", "The document to use for displaying directory listings over http", "", cli.ValString) + flagHTTPRateLimit := cli.NewInputFlag ( + 0, "http-rate-limit", + "Seconds an HTTP client must wait per request", + "", cli.ValString) flagDirectories := cli.NewFlag ( 'd', "directories", "Serve the contents of directories") @@ -149,6 +153,9 @@ func main () { if flagHTTPErrorDocument.Value != "" { config.Set("http.error-document", flagHTTPErrorDocument.Value) } + if flagHTTPRateLimit.Value != "" { + config.Set("http.rate-limit", flagHTTPRateLimit.Value) + } if flagDirectories.Value != "" { config.Set("http.serve-directories", flagDirectories.Value) } @@ -181,17 +188,25 @@ func main () { // initialize the environment err = environment.Init(context.Background()) - if err != nil { log.Fatal(err) } + if err != nil { log.Fatal("XXX:", err) } // set up the HTTP handler + rateLimit := 0.0 + if rateLimitStr := config.Get("http.rate-limit"); rateLimitStr != "" { + rateLimit, err = strconv.ParseFloat(rateLimitStr, 64) + if err != nil { log.Fatal("XXX bad value for rate limit", err) } + } handler := stephttp.Handler { - Environment: &environment, - Directories: config.Get("http.serve-directories") == "true", - StepExt: ucontainer.NewSet(slices.Clone(config["http.step-extension"])...), - Index: slices.Clone(config["http.index-file"]), - ErrorDocument: config.Get("http.error-document"), - DirectoryDocument: config.Get("http.directory-document"), - DenyAll: ucontainer.NewSet(configFileName), + Environment: &environment, + Directories: config.Get("http.serve-directories") == "true", + StepExt: ucontainer.NewSet(slices.Clone(config["http.step-extension"])...), + Index: slices.Clone(config["http.index-file"]), + ErrorDocument: config.Get("http.error-document"), + DirectoryDocument: config.Get("http.directory-document"), + DenyAll: ucontainer.NewSet(configFileName), + RateLimit: time.Duration(rateLimit * float64(time.Second)), + TrustXForwardedFor: config.Get("http.trust-x-forwarded-for") == "true", + TrustCFConnectingIP: config.Get("http.trust-cf-connecting-ip") == "true", } if len(handler.StepExt) == 0 { handler.StepExt.Add(".step") @@ -199,6 +214,8 @@ func main () { if len(handler.Index) == 0 { handler.Index = []string { "index.step", "index.html", "index" } } + err = handler.Init(ctx) + if err != nil { log.Println("XXX", err) } // set up the HTTP server httpServer := httpServerRoutine { @@ -206,10 +223,16 @@ func main () { Handler: &handler, } + // set up the trimming routine + trimmer := trimmerRoutine { + HTTPHandler: &handler, + } + // set up the routine manager manager := routines.Manager { Routines: []routines.Routine { &httpServer, + &trimmer, }, } @@ -242,6 +265,23 @@ func (this *httpServerRoutine) Run (ctx context.Context) error { return err } +type trimmerRoutine struct { + HTTPHandler *stephttp.Handler +} + +func (this *trimmerRoutine) Run (ctx context.Context) error { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <- ticker.C: + this.HTTPHandler.Trim() + case <- ctx.Done(): + return ctx.Err() + } + } +} + func logProviders (providers []step.Provider) { output := "providers: " x := utf8.RuneCountInString(output) diff --git a/http/handler.go b/http/handler.go index d5bb723..870b645 100644 --- a/http/handler.go +++ b/http/handler.go @@ -1,42 +1,77 @@ package http +import "net" import "log" import "fmt" +import "time" import "path" import "io/fs" import "errors" import "strings" import "strconv" +import "context" import "net/http" import "path/filepath" import "git.tebibyte.media/sashakoshka/step" +import "git.tebibyte.media/sashakoshka/goutil/sync" import "git.tebibyte.media/sashakoshka/goutil/container" +var errRateLimit = errors.New("rate limit") + +type rateLimitEntry struct { + time time.Time + duration time.Duration +} + type DirectoryData struct { Name string Entries []fs.DirEntry } type Handler struct { - Environment *step.Environment - Directories bool - StepExt ucontainer.Set[string] - Index []string - ErrorDocument string - DirectoryDocument string - DenyAll ucontainer.Set[string] + Environment *step.Environment + Directories bool + StepExt ucontainer.Set[string] + Index []string + ErrorDocument string + DirectoryDocument string + DenyAll ucontainer.Set[string] + RateLimit time.Duration + TrustXForwardedFor bool + TrustCFConnectingIP bool + + rateLimitMap usync.Locker[map[string] rateLimitEntry] +} + +func (this *Handler) Init (ctx context.Context) error { + this.rateLimitMap.Set(make(map[string] rateLimitEntry)) + return nil } func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { - remoteAddr := req.RemoteAddr + remoteAddr := req.RemoteAddr + remoteAddrStr := req.RemoteAddr if addr := req.Header.Get("CF-Connecting-IP"); addr != "" { - remoteAddr = fmt.Sprintf("%s --CF-> %s", addr, req.RemoteAddr) + remoteAddr = addr + if this.TrustCFConnectingIP { + remoteAddrStr = fmt.Sprintf("%s --CF-> %s", addr, req.RemoteAddr) + } } else if addr := req.Header.Get("X-Forwarded-For"); addr != "" { - remoteAddr = fmt.Sprintf("%s --??-> %s", addr, req.RemoteAddr) + remoteAddr = addr + if this.TrustXForwardedFor { + remoteAddrStr = fmt.Sprintf("%s --??-> %s", addr, req.RemoteAddr) + } } - log.Println("(i)", req.Method, req.URL, "from", remoteAddr) + log.Println("(i)", req.Method, req.URL, "from", remoteAddrStr) filesystem := this.Environment.GetFS() + // rate limit + err := this.rateLimit(req, remoteAddr, "", this.RateLimit) + if err != nil { + this.serveError(res, req, http.StatusTooManyRequests, req.URL, false) + return + } + // normalize path pat := req.URL.Path if !strings.HasPrefix(pat, "/") { @@ -71,7 +106,7 @@ func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { info, err := statFile(filesystem, pathToName(currentPath)) if err != nil { continue } if info.IsDir() { continue } - this.serveFile(res, req, currentPath) + this.serveFile(res, req, currentPath, remoteAddr) return } @@ -80,17 +115,38 @@ func (this *Handler) ServeHTTP (res http.ResponseWriter, req *http.Request) { return } - this.serveDirectory(res, req, pat) + this.serveDirectory(res, req, pat, remoteAddr) return } - this.serveFile(res, req, pat) + this.serveFile(res, req, pat, remoteAddr) +} + +func (this *Handler) Trim () { + table, done := this.rateLimitMap.Borrow() + defer done() + if len(table) < 128 { return } + keys := make([]string, 0, len(table) / 3) + for key, entry := range table { + duration := time.Hour + if entry.duration != 0 { + duration = entry.duration + } + if time.Since(entry.time) > duration { + keys = append(keys, key) + } + + } + for _, key := range keys { + delete(table, key) + } } func (this *Handler) serveFile ( res http.ResponseWriter, req *http.Request, pat string, + remoteAddr string, ) { name := pathToName(pat) if !this.StepExt.Has(filepath.Ext(name)) { @@ -98,13 +154,14 @@ func (this *Handler) serveFile ( http.ServeFileFS(res, req, this.Environment.GetFS(), name) return } - this.serveDocument(res, req, name) + this.serveDocument(res, req, name, remoteAddr) } func (this *Handler) serveDirectory ( res http.ResponseWriter, req *http.Request, pat string, + remoteAddr string, ) { name := pathToName(pat) entries, err := readDir(this.Environment.GetFS(), name) @@ -113,7 +170,7 @@ func (this *Handler) serveDirectory ( return } if this.DirectoryDocument == "" { - this.serveFile(res, req, pat) + this.serveFile(res, req, pat, remoteAddr) return } document, err := this.Environment.Load(this.DirectoryDocument) @@ -137,6 +194,7 @@ func (this *Handler) serveDocument ( res http.ResponseWriter, req *http.Request, name string, + remoteAddr string, ) { // parse document, err := this.Environment.Load(name) @@ -145,6 +203,15 @@ func (this *Handler) serveDocument ( return } + // rate limit (again) + if rateLimit, err := strconv.ParseFloat(document.Meta.Get("rate-limit"), 64); err == nil { + err := this.rateLimit(req, remoteAddr, name, time.Duration(rateLimit * float64(time.Second))) + if err != nil { + this.serveError(res, req, http.StatusTooManyRequests, req.URL, false) + return + } + } + // set up HTTP response recorder recorder := HTTPResponseRecorder { } resetRecorder := func () { @@ -241,6 +308,23 @@ func (this *Handler) logErr (name string, err error) { log.Printf("ERR %s: %v\n", name, err) } +func (this *Handler) rateLimit (req *http.Request, remoteAddr string, resource string, duration time.Duration) error { + host, _, _ := net.SplitHostPort(remoteAddr) + key := fmt.Sprintf("%s|%s", host, resource) + table, done := this.rateLimitMap.Borrow() + defer done() + if entry, ok := table[key]; ok { + if time.Since(entry.time) < duration { + return errRateLimit + } + } + table[key] = rateLimitEntry { + time: time.Now(), + duration: duration, + } + return nil +} + func statFile (filesystem fs.FS, name string) (fs.FileInfo, error) { if filesystem, ok := filesystem.(fs.StatFS); ok { return filesystem.Stat(name)