From 4f0364bb4932f95a6521f6325044e2a8006168fe Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Fri, 20 Dec 2024 22:37:00 -0500 Subject: [PATCH] providers/session: "Ensafen" values when they enter/exit the session --- providers/session/session.go | 39 ++++++++++++++++--- providers/session/session_test.go | 62 +++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 providers/session/session_test.go diff --git a/providers/session/session.go b/providers/session/session.go index c044ca9..7cbc41a 100644 --- a/providers/session/session.go +++ b/providers/session/session.go @@ -1,4 +1,4 @@ -package os +package session import "time" import "net/http" @@ -160,18 +160,20 @@ func (this *Session) ID () uuid.UUID { return this.id } -func (this *Session) Get (name string) any { +func (this *Session) Get (name string) (any, error) { data, done := this.data.RBorrow() defer done() - return data[name] + return ensafenValue(data[name]) } -func (this *Session) Set (name string, value any) string { +func (this *Session) Set (name string, value any) (string, error) { + value, err := ensafenValue(value) + if err != nil { return "", err } data, done := this.data.Borrow() defer done() data[name] = value this.addSelf() - return "" + return "", nil } func (this *Session) Del (name string) string { @@ -224,3 +226,30 @@ func (this *Session) setExpiration (expires time.Time) string { func (this *Session) Expired () bool { return this.Expiration().Before(time.Now()) } + +func ensafenValue (value any) (any, error) { + switch value := value.(type) { + case + int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + bool, float32, float64, string, nil: + return value, nil + case []any: + list := make([]any, len(value)) + for index, item := range value { + item, err := ensafenValue(item) + if err != nil { return nil, err } + list[index] = item + } + return list, nil + case map[string] any: + dict := make(map[string] any, len(value)) + for key, item := range value { + item, err := ensafenValue(item) + if err != nil { return nil, err } + dict[key] = item + } + return dict, nil + } + return nil, step.ErrTypeMismatch +} diff --git a/providers/session/session_test.go b/providers/session/session_test.go new file mode 100644 index 0000000..3bc0531 --- /dev/null +++ b/providers/session/session_test.go @@ -0,0 +1,62 @@ +package session + +import "io" +import "reflect" +import "testing" +import "git.tebibyte.media/sashakoshka/step" + +func TestEnsafenValue (test *testing.T) { + items := []any { + "hello", + 123, + 934.3298, + 'o', + []any { + "asljdkasd", + "90iur3e", + }, + map[string] any { + "asdkiasd": 34, + "jdjjdfj": '-', + }, + nil, + } + for index, item := range items { + safe, err := ensafenValue(item) + if err != nil { test.Fatal(index, err) } + test.Logf("%d: %v --> %v", index, item, safe) + if !reflect.DeepEqual(item, safe) { + test.Fatal("not equal") + } + switch item := item.(type) { + case []any, map[string] any: + if reflect.ValueOf(item).Pointer() == reflect.ValueOf(safe).Pointer() { + test.Fatal("memory wasn't duplicated") + } + } + } +} + +func TestEnsafenValueErrTypeMismatch (test *testing.T) { + items := []any { + []any { + "asljdkasd", + &struct { } { }, + "90iur3e", + }, + io.EOF, + []string { + "hello", + }, + map[string] any { + "asdkiasd": 34, + "jdjjdfj": struct { } { }, + }, + } + for index, item := range items { + safe, err := ensafenValue(item) + if err != step.ErrTypeMismatch { + test.Fatalf("%d: no error, produced %v", index, safe) + } + } +}