package http

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io/fs"
	"log"
	"net"
	"net/http"
	"net/url"
	"path"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"time"

	"git.tebibyte.media/sashakoshka/go-util/container"
	"git.tebibyte.media/sashakoshka/go-util/sync"
	"git.tebibyte.media/sashakoshka/step"
)

// errRateLimit is returned by rateLimit when a client has made another
// request for a resource before its rate-limit window has elapsed.
var errRateLimit = errors.New("rate limit")

// rateLimitEntry records when a client was last allowed to access a resource
// and the rate-limit window that applied to that access.
type rateLimitEntry struct {
	time     time.Time
	duration time.Duration
}

// DirectoryData is the data passed to the directory document when a
// directory listing is rendered.
type DirectoryData struct {
	// Name is the directory's name relative to the filesystem root ("." for
	// the root itself).
	Name string
	// Entries are the directory's contents.
	Entries []fs.DirEntry
}

// Handler is an http.Handler that serves files from a step.Environment,
// executing step documents and optionally rendering directory listings.
type Handler struct {
	// Environment supplies the filesystem being served and loads step
	// documents.
	Environment *step.Environment
	// Directories enables directory listings for directories that have no
	// index file. When false, such requests get 403 Forbidden.
	Directories bool
	// StepExt holds the file extensions (such as ".step") of files that are
	// executed as step documents instead of being served verbatim.
	StepExt ucontainer.Set[string]
	// Index lists the file names tried, in order, when a directory is
	// requested.
	Index []string
	// ErrorDocument names a step document used to render error pages. It
	// receives an Error. When empty, errors are reported as plain text.
	ErrorDocument string
	// DirectoryDocument names a step document used to render directory
	// listings. It receives a DirectoryData. When empty, the directory is
	// handed to serveFile instead.
	DirectoryDocument string
	// DenyAll holds names, relative to the filesystem root, that are always
	// refused with 403 Forbidden.
	DenyAll ucontainer.Set[string]
	// RateLimit is the minimum time between two requests from the same
	// client. Zero (or a negative value) disables the global rate limit.
	RateLimit time.Duration
	// TrustXForwardedFor takes the client address from the X-Forwarded-For
	// header when it is present.
	TrustXForwardedFor bool
	// TrustCFConnectingIP takes the client address from the
	// CF-Connecting-IP header when it is present. It takes precedence over
	// X-Forwarded-For.
	TrustCFConnectingIP bool

	// rateLimitMap maps "host|resource" keys to the last allowed access.
	rateLimitMap usync.Monitor[map[string]rateLimitEntry]
}

// Init creates the handler's rate-limit table. Call it before ServeHTTP:
// rateLimit stores entries in the table created here.
func (this *Handler) Init(ctx context.Context) error {
	this.rateLimitMap.Set(make(map[string]rateLimitEntry))
	return nil
}

// Configure reads the handler's settings from config. Recognized keys:
//
//   - http.rate-limit: seconds (may be fractional) between requests from one
//     client; empty disables the global rate limit
//   - http.serve-directories: "true" to list directories without an index
//   - http.step-extension: extensions executed as step documents
//     (default ".step")
//   - http.index-file: index file names to try
//     (default "index.step", "index.html", "index")
//   - http.error-document: document used to render error pages
//   - http.directory-document: document used to render directory listings
//   - http.trust-x-forwarded-for: "true" to trust X-Forwarded-For
//   - http.trust-cf-connecting-ip: "true" to trust CF-Connecting-IP
//
// "step.meta" is always denied. An error is returned only if
// http.rate-limit is not a valid number.
func (this *Handler) Configure(config step.Meta) error {
	var rateLimit float64
	if rateLimitStr := config.Get("http.rate-limit"); rateLimitStr != "" {
		var err error
		rateLimit, err = strconv.ParseFloat(rateLimitStr, 64)
		if err != nil {
			return fmt.Errorf("http.rate-limit: %w", err)
		}
	}

	this.Directories = config.Get("http.serve-directories") == "true"
	this.StepExt = ucontainer.NewSet(slices.Clone(config["http.step-extension"])...)
	this.Index = slices.Clone(config["http.index-file"])
	this.ErrorDocument = config.Get("http.error-document")
	this.DirectoryDocument = config.Get("http.directory-document")
	this.DenyAll = ucontainer.NewSet("step.meta")
	this.RateLimit = time.Duration(rateLimit * float64(time.Second))
	this.TrustXForwardedFor = config.Get("http.trust-x-forwarded-for") == "true"
	this.TrustCFConnectingIP = config.Get("http.trust-cf-connecting-ip") == "true"

	if len(this.StepExt) == 0 {
		this.StepExt.Add(".step")
	}
	if len(this.Index) == 0 {
		this.Index = []string{"index.step", "index.html", "index"}
	}
	return nil
}

// ServeHTTP serves the file, step document, or directory named by the
// request path. It:
//
//  1. determines the client address, honoring trusted proxy headers;
//  2. applies the global rate limit;
//  3. cleans the path and refuses names in DenyAll;
//  4. serves regular files directly (see serveFile);
//  5. redirects directories to their trailing-slash form;
//  6. serves the first existing index file; otherwise
//  7. lists the directory if Directories is set, or returns 403.
func (this *Handler) ServeHTTP(res http.ResponseWriter, req *http.Request) {
	filesystem := this.Environment.GetFS()

	// Get the real client IP. Each header is honored only when its proxy is
	// trusted, so an untrusted CF-Connecting-IP sent by a client cannot
	// shadow a trusted X-Forwarded-For.
	// NOTE: X-Forwarded-For may hold a comma-separated chain; the whole
	// header value is used as the client address.
	remoteAddrStr := req.RemoteAddr
	realRemoteAddr := req.RemoteAddr
	cfAddr := req.Header.Get("CF-Connecting-IP")
	xfAddr := req.Header.Get("X-Forwarded-For")
	switch {
	case this.TrustCFConnectingIP && cfAddr != "":
		realRemoteAddr = cfAddr
		remoteAddrStr = fmt.Sprintf("%s --CF-> %s", cfAddr, req.RemoteAddr)
	case this.TrustXForwardedFor && xfAddr != "":
		realRemoteAddr = xfAddr
		remoteAddrStr = fmt.Sprintf("%s --XF-> %s", xfAddr, req.RemoteAddr)
	}

	// Make a clone of req with the correct IP, because modifying req
	// directly is forbidden.
	req = req.Clone(req.Context())
	req.RemoteAddr = realRemoteAddr
	log.Println("(i)", req.Method, req.URL, "from", remoteAddrStr)

	// rate limit
	if err := this.rateLimit(req, "", this.RateLimit); err != nil {
		this.serveError(res, req, http.StatusTooManyRequests, req.URL, false)
		return
	}

	// normalize path
	pat := req.URL.Path
	if !strings.HasPrefix(pat, "/") {
		pat = "/" + pat
		req.URL.Path = pat
	}
	hasTrailingSlash := strings.HasSuffix(pat, "/")
	pat = path.Clean(pat)
	name := pathToName(pat)

	// access control
	if this.DenyAll.Has(name) {
		this.serveError(res, req, http.StatusForbidden, req.URL, false)
		return
	}

	info, err := statFile(filesystem, name)
	if err != nil {
		this.serveError(res, req, http.StatusNotFound, req.URL, false)
		return
	}
	if !info.IsDir() {
		this.serveFile(res, req, pat)
		return
	}

	// Ensure the path ends with a /. Build the target with url.URL so the
	// (already decoded) path is escaped and the query string is kept, as
	// net/http's FileServer does.
	if !hasTrailingSlash {
		target := url.URL{Path: pat + "/", RawQuery: req.URL.RawQuery}
		http.Redirect(res, req, target.String(), http.StatusMovedPermanently)
		return
	}

	// try to find an index
	for _, base := range this.Index {
		currentPath := path.Join(pat, base)
		info, err := statFile(filesystem, pathToName(currentPath))
		if err != nil || info.IsDir() {
			continue
		}
		this.serveFile(res, req, currentPath)
		return
	}

	if !this.Directories {
		this.serveError(res, req, http.StatusForbidden, req.URL, false)
		return
	}
	this.serveDirectory(res, req, pat)
}

// Trim removes expired entries from the rate-limit table. It does nothing
// while the table holds fewer than 128 entries. An entry expires once its
// recorded window (or one hour, if none was recorded) has passed.
func (this *Handler) Trim() {
	table, done := this.rateLimitMap.Borrow()
	defer done()
	if len(table) < 128 {
		return
	}
	// deleting from a map while ranging over it is permitted by the spec
	for key, entry := range table {
		duration := entry.duration
		if duration == 0 {
			duration = time.Hour
		}
		if time.Since(entry.time) > duration {
			delete(table, key)
		}
	}
}

// serveFile serves the file at the slash-separated path pat. Files whose
// extension is in StepExt are executed as step documents. Everything else is
// handed to http.ServeFileFS.
func (this *Handler) serveFile(res http.ResponseWriter, req *http.Request, pat string) {
	name := pathToName(pat)
	if this.StepExt.Has(filepath.Ext(name)) {
		this.serveDocument(res, req, name)
		return
	}
	// just a normal file
	http.ServeFileFS(res, req, this.Environment.GetFS(), name)
}

// serveDirectory renders a listing of the directory at the slash-separated
// path pat using DirectoryDocument. When no directory document is
// configured, the directory is passed to serveFile instead.
func (this *Handler) serveDirectory(res http.ResponseWriter, req *http.Request, pat string) {
	if this.DirectoryDocument == "" {
		this.serveFile(res, req, pat)
		return
	}

	name := pathToName(pat)
	entries, err := readDir(this.Environment.GetFS(), name)
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, false)
		return
	}
	document, err := this.Environment.Load(this.DirectoryDocument)
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, false)
		return
	}

	// Render into a buffer so that an execution failure still produces a
	// clean error response instead of a half-written listing.
	var buf bytes.Buffer
	err = document.Execute(&buf, step.ExecutionData{
		Data: DirectoryData{
			Name:    name,
			Entries: entries,
		},
	})
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, false)
		return
	}
	if contentType := document.Meta.Get("content-type"); contentType != "" {
		res.Header().Set("Content-Type", contentType)
	}
	if _, err := buf.WriteTo(res); err != nil {
		this.logErr(this.DirectoryDocument, err)
	}
}

// serveDocument loads and executes the step document called name. It
// applies the document's own "rate-limit", "content-type" and "status"
// metadata. Error and Redirect values returned by execution become error
// pages and redirects respectively.
func (this *Handler) serveDocument(res http.ResponseWriter, req *http.Request, name string) {
	// parse
	document, err := this.Environment.Load(name)
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, false)
		return
	}

	// rate limit (again), this time per document, in seconds
	if rateLimit, err := strconv.ParseFloat(document.Meta.Get("rate-limit"), 64); err == nil {
		err := this.rateLimit(req, name, time.Duration(rateLimit*float64(time.Second)))
		if err != nil {
			this.serveError(res, req, http.StatusTooManyRequests, req.URL, false)
			return
		}
	}

	// Record the response so that it can be discarded in favor of an error
	// page or redirect if execution fails.
	recorder := HTTPResponseRecorder{}
	defer recorder.Reset()
	resetRecorder := func() {
		recorder.Reset()
		recorder.Head = res.Header().Clone()
	}
	if contentType := document.Meta.Get("content-type"); contentType != "" {
		recorder.Header().Set("Content-Type", contentType)
	}
	if status := document.Meta.Get("status"); status != "" {
		if status, err := strconv.Atoi(status); err == nil {
			recorder.Status = status
		}
	}

	// execute document
	data := HTTPData{}
	data.Res = WrappedResponseWriter{
		responseWriter: &recorder,
		resetFunc:      resetRecorder,
		Header: WrappedHeader{
			Header: recorder.Header(),
		},
	}
	data.Req = req
	err = document.Execute(&recorder, step.ExecutionData{
		Data: data,
	})
	if errors.Is(err, step.ErrExecutionCanceled) {
		err = nil
	}

	var httpError Error
	if errors.As(err, &httpError) {
		this.serveError(res, req, httpError.Status, httpError.Message, false)
		return
	}
	var httpRedirect Redirect
	if errors.As(err, &httpRedirect) {
		recorder.WriteHeaderValues(res)
		http.Redirect(res, req, httpRedirect.Location, httpRedirect.Status)
		return
	}
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, false)
		return
	}

	// play back recorded response
	recorder.WriteHeaderValues(res)
	recorder.WriteHeaderStatus(res)
	if _, err := recorder.WriteBody(res); err != nil {
		this.logErr(name, err)
	}
}

// serveError logs the error and responds with status. When safeMode is set
// or no ErrorDocument is configured, a plain-text body is written.
// Otherwise the error document is rendered with an Error. If that fails, a
// plain-text 500 is sent instead.
func (this *Handler) serveError(res http.ResponseWriter, req *http.Request, status int, message any, safeMode bool) {
	log.Printf("ERR %d %s: %v\n", status, http.StatusText(status), message)
	if safeMode || this.ErrorDocument == "" {
		res.Header().Set("Content-Type", "text/plain; charset=utf-8")
		res.WriteHeader(status)
		if message == nil {
			fmt.Fprintf(res, "%d %s\n", status, http.StatusText(status))
		} else {
			fmt.Fprintf(res, "%d %s: %v\n", status, http.StatusText(status), message)
		}
		return
	}

	document, err := this.Environment.Load(this.ErrorDocument)
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, true)
		return
	}

	// Execute into a buffer before committing the status. Writing the status
	// first would turn a failure here into a superfluous WriteHeader call
	// followed by a mangled body.
	var buf bytes.Buffer
	err = document.Execute(&buf, step.ExecutionData{
		Data: Error{
			Status:  status,
			Message: message,
		},
	})
	if err != nil {
		this.serveError(res, req, http.StatusInternalServerError, err, true)
		return
	}
	if contentType := document.Meta.Get("content-type"); contentType != "" {
		res.Header().Set("Content-Type", contentType)
	}
	res.WriteHeader(status)
	if _, err := buf.WriteTo(res); err != nil {
		this.logErr(this.ErrorDocument, err)
	}
}

// logErr logs err as having occurred while serving name.
func (this *Handler) logErr(name string, err error) {
	log.Printf("ERR %s: %v\n", name, err)
}

// rateLimit returns errRateLimit if the client at req.RemoteAddr accessed
// resource less than duration ago. Otherwise it records this access and
// returns nil. A zero or negative duration disables the check.
func (this *Handler) rateLimit(req *http.Request, resource string, duration time.Duration) error {
	if duration <= 0 {
		return nil
	}
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		// Addresses taken from proxy headers carry no port. Use them whole
		// rather than collapsing every such client onto an empty key.
		host = req.RemoteAddr
	}
	key := host + "|" + resource

	table, done := this.rateLimitMap.Borrow()
	defer done()
	if entry, ok := table[key]; ok && time.Since(entry.time) < duration {
		return errRateLimit
	}
	table[key] = rateLimitEntry{
		time:     time.Now(),
		duration: duration,
	}
	return nil
}

// statFile returns information about the file called name in filesystem.
func statFile(filesystem fs.FS, name string) (fs.FileInfo, error) {
	return fs.Stat(filesystem, name)
}

// pathToName converts a slash-rooted URL path into an fs.FS name by
// dropping the leading slash. The root maps to ".".
func pathToName(pat string) string {
	name := strings.TrimPrefix(pat, "/")
	if name == "" {
		return "."
	}
	return name
}

// readDir lists the directory called name in filesystem. It reports an error
// if the directory cannot be listed.
func readDir(filesystem fs.FS, name string) ([]fs.DirEntry, error) {
	return fs.ReadDir(filesystem, name)
}