diff --git a/sync/select.go b/sync/select.go new file mode 100644 index 0000000..eb0c8d1 --- /dev/null +++ b/sync/select.go @@ -0,0 +1,78 @@ +package usync + +import "slices" +import "context" +import "reflect" + +// A type-safe wrapper around reflect.Select. Taken from: +// https://stackoverflow.com/questions/19992334 +func Select[T any] (ctx context.Context, channels ...chan T) (int, T, bool) { + var zero T + + // add all channels as select cases + cases := make([]reflect.SelectCase, len(channels) + 1) + for i, ch := range channels { + cases[i] = reflect.SelectCase { + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ch), + } + } + + // add ctx.Done() as another select case to stop listening when the + // context is closed + cases[len(channels)] = reflect.SelectCase { + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ctx.Done()), + } + + // read from the channel + chosen, value, ok := reflect.Select(cases) + if !ok { + if ctx.Err() != nil { + return -1, zero, false + } + return chosen, zero, false + } + + // cast return value + if ret, ok := value.Interface().(T); ok { + return chosen, ret, true + } + return chosen, zero, false +} + +// Combine returns a channel that continuously returns the result of the select +// until all input sources are exhauste, or the context is canceled. +func Combine[T any] (ctx context.Context, channels ...chan T) <- chan T { + channel := make(chan T) + // our silly slection routine + go func () { + for { + if len(channels) < 2 { + // only the context is left, stop everything + close(channel) + return + } + // read new value + chosen, value, ok := Select(ctx, channels...) + if ok { + // we have a value + channel <- value + } else { + // a channel has been closed and we need to do + // something about it + if chosen == len(channels) - 1 { + // the context has expired, stop + // everything + close(channel) + return + } else { + // a normal channel has closed, remove + // it from the list + channels = slices.Delete(channels, chosen, chosen + 1) + } + } + } + } () + return channel +} diff --git a/sync/select_test.go b/sync/select_test.go new file mode 100644 index 0000000..9167c79 --- /dev/null +++ b/sync/select_test.go @@ -0,0 +1,24 @@ +package usync + +import "time" +import "testing" +import "context" + +func TestSelect (test *testing.T) { + // https://stackoverflow.com/questions/19992334 + c1 := make(chan int) + c2 := make(chan int) + c3 := make(chan int) + chs := []chan int { c1, c2, c3 } + go func () { + time.Sleep(time.Second) + c2 <- 42 + } () + ctx, done := context.WithTimeout(context.Background(), 5 * time.Second) + defer done() + + chosen, val, ok := Select(ctx, chs...) + if !ok { test.Fatal("not ok") } + if 1 != chosen { test.Fatal("expected 1, got", chosen) } + if 42 != val { test.Fatal("expected 42, got", val) } +}