sync: Add Select, Combine

This commit is contained in:
Sasha Koshka 2024-12-02 14:34:39 -05:00
parent 9fd40a37b8
commit 09aa19e7c1
2 changed files with 102 additions and 0 deletions

78
sync/select.go Normal file
View File

@ -0,0 +1,78 @@
package usync
import "slices"
import "context"
import "reflect"
// A type-safe wrapper around reflect.Select. Taken from:
// https://stackoverflow.com/questions/19992334
func Select[T any] (ctx context.Context, channels ...chan T) (int, T, bool) {
var zero T
// add all channels as select cases
cases := make([]reflect.SelectCase, len(channels) + 1)
for i, ch := range channels {
cases[i] = reflect.SelectCase {
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(ch),
}
}
// add ctx.Done() as another select case to stop listening when the
// context is closed
cases[len(channels)] = reflect.SelectCase {
Dir: reflect.SelectRecv,
Chan: reflect.ValueOf(ctx.Done()),
}
// read from the channel
chosen, value, ok := reflect.Select(cases)
if !ok {
if ctx.Err() != nil {
return -1, zero, false
}
return chosen, zero, false
}
// cast return value
if ret, ok := value.Interface().(T); ok {
return chosen, ret, true
}
return chosen, zero, false
}
// Combine returns a channel that continuously returns the result of the select
// until all input sources are exhauste, or the context is canceled.
func Combine[T any] (ctx context.Context, channels ...chan T) <- chan T {
channel := make(chan T)
// our silly slection routine
go func () {
for {
if len(channels) < 2 {
// only the context is left, stop everything
close(channel)
return
}
// read new value
chosen, value, ok := Select(ctx, channels...)
if ok {
// we have a value
channel <- value
} else {
// a channel has been closed and we need to do
// something about it
if chosen == len(channels) - 1 {
// the context has expired, stop
// everything
close(channel)
return
} else {
// a normal channel has closed, remove
// it from the list
channels = slices.Delete(channels, chosen, chosen + 1)
}
}
}
} ()
return channel
}

24
sync/select_test.go Normal file
View File

@ -0,0 +1,24 @@
package usync
import "time"
import "testing"
import "context"
func TestSelect (test *testing.T) {
// https://stackoverflow.com/questions/19992334
c1 := make(chan int)
c2 := make(chan int)
c3 := make(chan int)
chs := []chan int { c1, c2, c3 }
go func () {
time.Sleep(time.Second)
c2 <- 42
} ()
ctx, done := context.WithTimeout(context.Background(), 5 * time.Second)
defer done()
chosen, val, ok := Select(ctx, chs...)
if !ok { test.Fatal("not ok") }
if 1 != chosen { test.Fatal("expected 1, got", chosen) }
if 42 != val { test.Fatal("expected 42, got", val) }
}