step/providers/session/session.go

256 lines
6.1 KiB
Go

package session
import "time"
import "net/http"
import "html/template"
import "github.com/google/uuid"
import "git.tebibyte.media/sashakoshka/step"
import "git.tebibyte.media/sashakoshka/go-util/sync"
import "git.tebibyte.media/sashakoshka/go-util/container"
import shttp "git.tebibyte.media/sashakoshka/step/http"
const (
	// sessionIDCookieName is the name of the cookie that carries a client's
	// session ID.
	sessionIDCookieName = "step-session-id"
	// defaultLifetime is the fallback session lifetime used when the
	// configured lifetime is zero.
	defaultLifetime = 48 * time.Hour
)

// Compile-time checks that *Provider implements these step interfaces.
var (
	_ step.FuncProviderFor = (*Provider)(nil)
	_ step.Initializer     = (*Provider)(nil)
	_ step.Configurable    = (*Provider)(nil)
	_ step.Trimmer         = (*Provider)(nil)
)
// Provider provides session functions to templates via FuncMapFor. Sessions
// live in an in-memory map keyed by UUID; Trim removes expired ones.
type Provider struct {
// Lifetime is the configured session lifetime. Configure sets it from the
// "session.lifetime" key using time.ParseDuration.
Lifetime time.Duration
// sessions maps session IDs to sessions. It is created by Init and accessed
// through Borrow/RBorrow.
sessions usync.RWLocker[sessionMap]
}
// Package fulfills the step.Provider interface. The session functions are
// registered under the package name "session".
func (this *Provider) Package () string { return "session" }
// Init fulfills the step.Initializer interface by creating an empty session
// map. It never fails.
func (this *Provider) Init () error {
	initial := sessionMap { }
	this.sessions = usync.NewRWLocker(initial)
	return nil
}
// Trim fulfills the step.Trimmer interface. It removes every session whose
// expiration time has passed from the session map, holding the map's write
// lock for the whole sweep.
func (this *Provider) Trim () {
	sessions, done := this.sessions.Borrow()
	defer done()

	// Collect first, then delete, so the sweep is a clean two-phase pass.
	expired := ucontainer.Set[uuid.UUID] { }
	for sessionID, candidate := range sessions {
		if !candidate.Expired() {
			continue
		}
		expired.Add(sessionID)
	}
	for sessionID := range expired {
		delete(sessions, sessionID)
	}
}
// Configure fulfills the step.Configurable interface. If the
// "session.lifetime" key is present, its value is parsed with
// time.ParseDuration (e.g. "12h", "90m") and stored in Lifetime. A parse
// error is returned as-is and leaves Lifetime unchanged; a missing key is
// not an error.
func (this *Provider) Configure (config step.Meta) error {
	raw := config.Get("session.lifetime")
	if raw == "" {
		return nil
	}
	parsed, err := time.ParseDuration(raw)
	if err != nil {
		return err
	}
	this.Lifetime = parsed
	return nil
}
// FuncMapFor fulfills the step.FuncProviderFor interface. It returns two
// template functions backed by a fresh per-document state:
//
//   - sessionHTTP: returns the session named by the request's session
//     cookie (creating one if needed) and writes the session cookie.
//   - session: returns the session with a given ID, creating one if needed.
//
// Both share this provider's session map and its configured Lifetime.
func (this *Provider) FuncMapFor (document *step.Document) template.FuncMap {
	stat := &state {
		document: document,
		sessions: &this.sessions,
		// Carry the configured lifetime through; without it the
		// "session.lifetime" setting is silently ignored.
		lifetime: this.Lifetime,
	}
	return template.FuncMap {
		"sessionHTTP": stat.funcSessionHTTP,
		"session":     stat.funcSession,
	}
}
// state is created by FuncMapFor for each document and backs the template
// functions it returns.
type state struct {
// document is the document the functions were created for.
// NOTE(review): not read by any code visible here — confirm it is needed.
document *step.Document
// sessions points at the owning Provider's session map.
sessions *usync.RWLocker[sessionMap]
// lifetime is how far past "now" a session's expiration is set when it is
// created or refreshed; funcSessionHTTP substitutes defaultLifetime when it
// is zero.
lifetime time.Duration
}
// funcSessionHTTP implements the "sessionHTTP" template function.
//
// It looks up the live session named by the request's session cookie. If
// there is none (no cookie, malformed or nil ID, unknown or expired
// session), it creates a new session under a freshly generated ID. In both
// cases the session's expiration is set to now plus the state's lifetime
// (defaultLifetime when that is zero). The session cookie is then (re)written
// with the same expiration, so the cookie and the server-side session stay in
// step.
//
// The cookie is Secure, HttpOnly, SameSite=Strict and scoped to "/". It is
// written to the underlying response writer of res.
func (this *state) funcSessionHTTP (
	res shttp.WrappedResponseWriter,
	req *http.Request,
) (
	*Session,
	error,
) {
	lifetime := this.lifetime
	if lifetime == 0 {
		lifetime = defaultLifetime
	}
	expiration := time.Now().Add(lifetime)

	// Look for an existing live session. The nil UUID is never treated as
	// a real session ID.
	var result *Session
	if cookie, err := req.Cookie(sessionIDCookieName); err == nil {
		if id, err := uuid.Parse(cookie.Value); err == nil && id != uuid.Nil {
			if session, ok := this.checkSession(id); ok {
				result = session
			}
		}
	}

	if result != nil {
		// Sliding expiration: keep the stored session in sync with the
		// cookie written below.
		result.setExpiration(expiration)
	} else {
		// Always use a server-generated ID for a new session. Adopting an
		// unknown ID from the cookie would let anyone able to plant that
		// cookie pick the victim's session ID in advance (session
		// fixation).
		session, err := this.newSession(uuid.New(), expiration)
		if err != nil { return nil, err }
		result = session
	}

	cookie := &http.Cookie {
		Name:     sessionIDCookieName,
		Value:    result.ID().String(),
		Expires:  expiration,
		Secure:   true,
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
		Path:     "/",
	}
	underlyingRes := shttp.UnderlyingResponseWriter(res)
	http.SetCookie(underlyingRes, cookie)
	return result, nil
}
// funcSession implements the "session" template function. It returns the
// live session with the given ID after pushing its expiration out to now plus
// the state's lifetime. If no live session has that ID, it returns a new
// session with that ID. The new session is not stored until a value is Set on
// it. A zero lifetime falls back to defaultLifetime, as in funcSessionHTTP.
func (this *state) funcSession (id uuid.UUID) (*Session, error) {
	lifetime := this.lifetime
	if lifetime == 0 {
		// Without this fallback, expires would be "now": new sessions
		// would be born expired, and refreshing would kill live ones.
		lifetime = defaultLifetime
	}
	expires := time.Now().Add(lifetime)
	if session, ok := this.checkSession(id); ok {
		session.setExpiration(expires)
		return session, nil
	}
	return this.newSession(id, expires)
}
// checkSession returns the stored session with the given ID and true. It
// returns (nil, false) when no session is stored under that ID or when the
// stored one has expired. It holds the session map's read lock while
// checking.
func (this *state) checkSession (id uuid.UUID) (*Session, bool) {
	sessions, release := this.sessions.RBorrow()
	defer release()
	if found, present := sessions[id]; present && !found.Expired() {
		return found, true
	}
	return nil, false
}
// newSession builds a session with the given ID and expiration, owned by this
// state and holding an empty value map. It does NOT add the session to the
// session map; Set does that when the first value is stored. The error
// result is currently always nil.
func (this *state) newSession (id uuid.UUID, expires time.Time) (*Session, error) {
	created := &Session {
		id:      id,
		parent:  this,
		expires: usync.NewRWLocker(expires),
		data:    usync.NewRWLocker(map[string] any { }),
	}
	return created, nil
}
// sessionMap indexes stored sessions by their ID.
type sessionMap map[uuid.UUID] *Session
// Session is a collection of named values tied to one session ID. The ID
// and parent are fixed at creation. The value map and the expiration time
// each sit behind their own lock and may change. Stored values are limited
// to the types ensafenValue accepts.
type Session struct {
parent *state // immutable
id uuid.UUID // immutable
data usync.RWLocker[map[string] any] // mutable
expires usync.RWLocker[time.Time] // mutable
}
// ID returns the session's identifier, which never changes.
func (this *Session) ID () uuid.UUID { return this.id }
// Get returns a copy of the value stored under name, or nil if nothing is
// stored there. Lists and maps are copied recursively by ensafenValue, so
// callers cannot mutate the stored value through the result.
func (this *Session) Get (name string) (any, error) {
	data, release := this.data.RBorrow()
	defer release()
	stored, present := data[name]
	if !present {
		return nil, nil
	}
	return ensafenValue(stored)
}
// Set stores a sanitized copy of value under name and makes sure the session
// is present in the session map. Values of unsupported types are rejected
// with the error from ensafenValue. The empty string result lets the call
// sit in a template; presumably this is so it renders nothing.
func (this *Session) Set (name string, value any) (string, error) {
	sanitized, err := ensafenValue(value)
	if err != nil {
		return "", err
	}
	data, release := this.data.Borrow()
	defer release()
	data[name] = sanitized
	// Registration happens while the data lock is still held.
	this.addSelf()
	return "", nil
}
// Del removes the value stored under name; removing a missing name is a
// no-op. It always returns the empty string (presumably so a template call
// renders nothing).
func (this *Session) Del (name string) string {
	data, release := this.data.Borrow()
	defer release()
	if _, present := data[name]; present {
		delete(data, name)
	}
	return ""
}
// Clear removes every value from the session and drops the session from the
// session map. The session map is updated while the data lock is still
// held. It always returns the empty string.
func (this *Session) Clear () string {
	data, release := this.data.Borrow()
	defer release()
	clear(data)
	_, _ = this.delSelf() // delSelf never returns a non-nil error
	return ""
}
// String returns the session ID in its canonical textual form.
func (this *Session) String () string {
	id := this.id
	return id.String()
}
// addSelf stores this session in its parent's session map under its ID,
// replacing any session already there. It never fails.
func (this *Session) addSelf () (string, error) {
	registry, release := this.parent.sessions.Borrow()
	defer release()
	registry[this.id] = this
	return "", nil
}
// delSelf removes whatever session is stored under this session's ID from
// its parent's session map. It never fails.
func (this *Session) delSelf () (string, error) {
	registry, release := this.parent.sessions.Borrow()
	defer release()
	delete(registry, this.id)
	return "", nil
}
// Expiration returns the time at which the session stops being valid.
func (this *Session) Expiration () time.Time {
	deadline, release := this.expires.RBorrow()
	defer release()
	// time.Time is a value type, so the caller receives an independent copy.
	return deadline
}
// setExpiration replaces the session's server-side expiration time.
//
// It is unexported because it cannot touch the client's session cookie,
// which lives on the user's machine; only the HTTP path can rewrite that.
func (this *Session) setExpiration (expires time.Time) string {
	this.expires.Set(expires)
	return ""
}
// Expired reports whether the session's expiration time lies strictly in the
// past.
func (this *Session) Expired () bool {
	deadline := this.Expiration()
	return time.Now().After(deadline)
}
// ensafenValue returns a copy of value that is safe to store in, or hand out
// from, a session.
//
// Accepted values are nil, bool, string, the built-in integer and float
// types, []any and map[string]any. Lists and maps are copied recursively, so
// the result shares no mutable state with the input. Any other type, at any
// depth, yields step.ErrTypeMismatch.
func ensafenValue (value any) (any, error) {
	switch typed := value.(type) {
	case nil, bool, string,
		int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64:
		return typed, nil
	case []any:
		copied := make([]any, 0, len(typed))
		for _, element := range typed {
			safe, err := ensafenValue(element)
			if err != nil {
				return nil, err
			}
			copied = append(copied, safe)
		}
		return copied, nil
	case map[string] any:
		copied := make(map[string] any, len(typed))
		for key, element := range typed {
			safe, err := ensafenValue(element)
			if err != nil {
				return nil, err
			}
			copied[key] = safe
		}
		return copied, nil
	default:
		return nil, step.ErrTypeMismatch
	}
}