// Package session implements a step provider that exposes session functions
// to templates.
package session

import "log"
import "time"
import "net/http"
import "html/template"
import "github.com/google/uuid"
import "git.tebibyte.media/sashakoshka/step"
import "git.tebibyte.media/sashakoshka/go-util/sync"
import "git.tebibyte.media/sashakoshka/go-util/container"
import shttp "git.tebibyte.media/sashakoshka/step/http"

const (
	// sessionIDCookieName is the cookie name used when insecure cookies are
	// enabled.
	sessionIDCookieName = "step-session-id"
	// secureSessionIDCookieName is the cookie name used for secure cookies
	// that carry no Domain attribute.
	secureSessionIDCookieName = "__Host-step-session-id"
	// secureDomainSessionIDCookieName is the cookie name used for secure
	// cookies that carry a Domain attribute. Browsers reject "__Host-"
	// cookies that specify a Domain, so the "__Secure-" prefix is used
	// instead.
	secureDomainSessionIDCookieName = "__Secure-step-session-id"
	// defaultLifetime is the session lifetime used when none is configured.
	defaultLifetime = 48 * time.Hour
)

var _ step.FuncProviderFor = new(Provider)
var _ step.Initializer = new(Provider)
var _ step.Configurable = new(Provider)
var _ step.Trimmer = new(Provider)

// Provider provides session functions.
type Provider struct {
	// Lifetime is how long a session lasts after it was last retrieved. If
	// it is zero, defaultLifetime is used.
	Lifetime time.Duration
	// InsecureCookie, if true, causes the session cookie to be sent under
	// an unprefixed name and without the Secure and HttpOnly attributes.
	InsecureCookie bool
	// CookieDomain is the value of the session cookie's Domain attribute.
	CookieDomain string

	sessions usync.RWMonitor[sessionMap]
}

// Package fulfills the step.Provider interface.
func (this *Provider) Package() string {
	return "session"
}

// Init fulfills the step.Initializer interface. It creates the empty session
// map.
func (this *Provider) Init() error {
	this.sessions = usync.NewRWMonitor(make(sessionMap))
	return nil
}

// Trim fulfills the step.Trimmer interface. It removes all expired sessions
// from the session map.
func (this *Provider) Trim() {
	sessions, done := this.sessions.Borrow()
	defer done()

	// collect the stale IDs first, then remove them
	stale := ucontainer.Set[uuid.UUID]{}
	for id, session := range sessions {
		if session.Expired() {
			stale.Add(id)
		}
	}
	for id := range stale {
		delete(sessions, id)
	}
}

// Configure fulfills the step.Configurable interface. It reads the following
// keys, all of which are optional:
//   - session.lifetime: a duration string as accepted by time.ParseDuration
//   - session.insecure-cookie: either "true" or "false"
//   - session.cookie-domain: the Domain attribute of the session cookie
//
// An unparseable lifetime returns the parse error, and any other value for
// session.insecure-cookie returns step.ErrTypeMismatch.
func (this *Provider) Configure(config step.Meta) error {
	if lifetimeStr := config.Get("session.lifetime"); lifetimeStr != "" {
		lifetime, err := time.ParseDuration(lifetimeStr)
		if err != nil {
			return err
		}
		this.Lifetime = lifetime
	}

	if insecureCookieStr := config.Get("session.insecure-cookie"); insecureCookieStr != "" {
		switch insecureCookieStr {
		case "true":
			this.InsecureCookie = true
		case "false":
			this.InsecureCookie = false
		default:
			return step.ErrTypeMismatch
		}
		if this.InsecureCookie {
			log.Println("!!! session.insecure-cookie is active, this is not recommended")
		}
	}

	if cookieDomainStr := config.Get("session.cookie-domain"); cookieDomainStr != "" {
		this.CookieDomain = cookieDomainStr
	}
	return nil
}

// FuncMapFor fulfills the step.FuncProviderFor interface.
func (this *Provider) FuncMapFor(document *step.Document) template.FuncMap {
	stat := &state{
		document: document,
		sessions: &this.sessions,
		// the configured lifetime must be handed over here, otherwise
		// the session.lifetime setting has no effect
		lifetime:       this.Lifetime,
		insecureCookie: this.InsecureCookie,
		cookieDomain:   this.CookieDomain,
	}
	return template.FuncMap{
		"sessionHTTP": stat.funcSessionHTTP,
		"session":     stat.funcSession,
	}
}

// state holds the per-document configuration that the template functions
// operate on.
type state struct {
	document       *step.Document
	sessions       *usync.RWMonitor[sessionMap]
	lifetime       time.Duration
	insecureCookie bool
	cookieDomain   string
}

// effectiveLifetime returns the configured session lifetime, or
// defaultLifetime if none (or a non-positive one) was configured.
func (this *state) effectiveLifetime() time.Duration {
	if this.lifetime <= 0 {
		return defaultLifetime
	}
	return this.lifetime
}

// cookieName returns the name of the session ID cookie that matches the
// cookie attributes this state will send.
func (this *state) cookieName() string {
	switch {
	case this.insecureCookie:
		return sessionIDCookieName
	case this.cookieDomain != "":
		// "__Host-" forbids the Domain attribute
		return secureDomainSessionIDCookieName
	default:
		return secureSessionIDCookieName
	}
}

// funcSessionHTTP returns the session identified by the session ID cookie of
// req, and sets that cookie on res with a renewed expiration time. If the
// cookie is missing or unparseable, or refers to a session that is unknown or
// expired, a new session is returned instead. A new session reuses the ID from
// the cookie if it parsed successfully, and is given a random ID otherwise.
func (this *state) funcSessionHTTP(
	res shttp.WrappedResponseWriter,
	req *http.Request,
) (
	*Session,
	error,
) {
	cookieName := this.cookieName()

	// a missing or malformed cookie leaves id as the zero UUID
	var id uuid.UUID
	if cookie, err := req.Cookie(cookieName); err == nil {
		if parsed, err := uuid.Parse(cookie.Value); err == nil {
			id = parsed
		}
	}

	expiration := time.Now().Add(this.effectiveLifetime())

	result, ok := this.checkSession(id)
	if ok {
		// keep the stored expiration in step with the cookie's
		result.setExpiration(expiration)
	} else {
		if id == (uuid.UUID{}) {
			id = uuid.New()
		}
		session, err := this.newSession(id, expiration)
		if err != nil {
			return nil, err
		}
		result = session
	}

	cookie := &http.Cookie{
		Name:     cookieName,
		Value:    result.ID().String(),
		Expires:  expiration,
		SameSite: http.SameSiteStrictMode,
		Path:     "/",
		Domain:   this.cookieDomain,
	}
	if !this.insecureCookie {
		cookie.Secure = true
		// NOTE(review): HttpOnly does not depend on HTTPS, so it could
		// be set for insecure cookies as well. Confirm whether leaving
		// it off there is intentional.
		cookie.HttpOnly = true
	}
	underlyingRes := shttp.UnderlyingResponseWriter(res)
	http.SetCookie(underlyingRes, cookie)
	return result, nil
}

// funcSession returns the session with the given ID and renews its expiration
// time. If there is no such session or it has expired, a new session with that
// ID is returned.
func (this *state) funcSession(id uuid.UUID) (*Session, error) {
	expires := time.Now().Add(this.effectiveLifetime())
	if session, ok := this.checkSession(id); ok {
		session.setExpiration(expires)
		return session, nil
	}
	return this.newSession(id, expires)
}

// checkSession looks up the session with the given ID. It returns false if
// there is no such session in the session map, or if it has expired.
func (this *state) checkSession(id uuid.UUID) (*Session, bool) {
	sessions, done := this.sessions.RBorrow()
	defer done()
	session, ok := sessions[id]
	if !ok || session.Expired() {
		return nil, false
	}
	return session, true
}

// newSession creates a session with the given ID and expiration time. The
// session is not placed in the session map here; Session.Set does that once
// data is stored in it.
func (this *state) newSession(id uuid.UUID, expires time.Time) (*Session, error) {
	return &Session{
		parent:  this,
		id:      id,
		data:    usync.NewRWMonitor(make(map[string]any)),
		expires: usync.NewRWMonitor(expires),
	}, nil
}

// sessionMap maps session IDs to sessions.
type sessionMap map[uuid.UUID]*Session

// Session stores named values on behalf of a single client.
type Session struct {
	parent  *state                          // immutable
	id      uuid.UUID                       // immutable
	data    usync.RWMonitor[map[string]any] // mutable
	expires usync.RWMonitor[time.Time]      // mutable
}

// ID returns the ID of the session.
func (this *Session) ID() uuid.UUID {
	return this.id
}

// Get returns a copy of the value stored under name. A name that has no value
// yields nil.
func (this *Session) Get(name string) (any, error) {
	data, done := this.data.RBorrow()
	defer done()
	return ensafenValue(data[name])
}

// Set stores a copy of value under name, and places the session in the
// session map. It returns step.ErrTypeMismatch if the value contains a type
// that sessions cannot store. The returned string is always empty.
func (this *Session) Set(name string, value any) (string, error) {
	value, err := ensafenValue(value)
	if err != nil {
		return "", err
	}
	data, done := this.data.Borrow()
	defer done()
	data[name] = value
	this.addSelf()
	return "", nil
}

// Del removes the value stored under name. It always returns an empty string.
func (this *Session) Del(name string) string {
	data, done := this.data.Borrow()
	defer done()
	delete(data, name)
	return ""
}

// Clear removes all values from the session, and removes the session from the
// session map. It always returns an empty string.
func (this *Session) Clear() string {
	data, done := this.data.Borrow()
	defer done()
	clear(data)
	this.delSelf()
	return ""
}

// String returns the ID of the session as a string.
func (this *Session) String() string {
	return this.id.String()
}

// addSelf places the session in the session map of its parent.
func (this *Session) addSelf() (string, error) {
	sessions, done := this.parent.sessions.Borrow()
	defer done()
	sessions[this.id] = this
	return "", nil
}

// delSelf removes the session from the session map of its parent.
func (this *Session) delSelf() (string, error) {
	sessions, done := this.parent.sessions.Borrow()
	defer done()
	delete(sessions, this.id)
	return "", nil
}

// Expiration returns the time at which the session expires.
func (this *Session) Expiration() time.Time {
	expiration, done := this.expires.RBorrow()
	defer done()
	return expiration // ok because it's a value, not a pointer
}

// setExpiration sets the time at which the session expires. It always returns
// an empty string.
func (this *Session) setExpiration(expires time.Time) string {
	// not public because this only sets the expiration time of the session,
	// and not the session cookie it represents because that would require
	// breaking into the user's computer.
	this.expires.Set(expires)
	return ""
}

// Expired returns whether the expiration time of the session has passed.
func (this *Session) Expired() bool {
	return this.Expiration().Before(time.Now())
}

// ensafenValue returns a copy of value that shares no slices or maps with it.
// Integers, floats, booleans, strings, and nil are returned as they are.
// Values of type []any and map[string]any are copied recursively. Anything
// else, at any depth, results in step.ErrTypeMismatch.
func ensafenValue(value any) (any, error) {
	switch value := value.(type) {
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		bool, float32, float64, string, nil:
		return value, nil
	case []any:
		list := make([]any, len(value))
		for index, item := range value {
			item, err := ensafenValue(item)
			if err != nil {
				return nil, err
			}
			list[index] = item
		}
		return list, nil
	case map[string]any:
		dict := make(map[string]any, len(value))
		for key, item := range value {
			item, err := ensafenValue(item)
			if err != nil {
				return nil, err
			}
			dict[key] = item
		}
		return dict, nil
	}
	return nil, step.ErrTypeMismatch
}