sync: Add Select, Combine
This commit is contained in:
parent
9fd40a37b8
commit
09aa19e7c1
78
sync/select.go
Normal file
78
sync/select.go
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
package usync
|
||||||
|
|
||||||
|
import "slices"
|
||||||
|
import "context"
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
// A type-safe wrapper around reflect.Select. Taken from:
|
||||||
|
// https://stackoverflow.com/questions/19992334
|
||||||
|
func Select[T any] (ctx context.Context, channels ...chan T) (int, T, bool) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
// add all channels as select cases
|
||||||
|
cases := make([]reflect.SelectCase, len(channels) + 1)
|
||||||
|
for i, ch := range channels {
|
||||||
|
cases[i] = reflect.SelectCase {
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(ch),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add ctx.Done() as another select case to stop listening when the
|
||||||
|
// context is closed
|
||||||
|
cases[len(channels)] = reflect.SelectCase {
|
||||||
|
Dir: reflect.SelectRecv,
|
||||||
|
Chan: reflect.ValueOf(ctx.Done()),
|
||||||
|
}
|
||||||
|
|
||||||
|
// read from the channel
|
||||||
|
chosen, value, ok := reflect.Select(cases)
|
||||||
|
if !ok {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return -1, zero, false
|
||||||
|
}
|
||||||
|
return chosen, zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// cast return value
|
||||||
|
if ret, ok := value.Interface().(T); ok {
|
||||||
|
return chosen, ret, true
|
||||||
|
}
|
||||||
|
return chosen, zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine returns a channel that continuously returns the result of the select
|
||||||
|
// until all input sources are exhauste, or the context is canceled.
|
||||||
|
func Combine[T any] (ctx context.Context, channels ...chan T) <- chan T {
|
||||||
|
channel := make(chan T)
|
||||||
|
// our silly slection routine
|
||||||
|
go func () {
|
||||||
|
for {
|
||||||
|
if len(channels) < 2 {
|
||||||
|
// only the context is left, stop everything
|
||||||
|
close(channel)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// read new value
|
||||||
|
chosen, value, ok := Select(ctx, channels...)
|
||||||
|
if ok {
|
||||||
|
// we have a value
|
||||||
|
channel <- value
|
||||||
|
} else {
|
||||||
|
// a channel has been closed and we need to do
|
||||||
|
// something about it
|
||||||
|
if chosen == len(channels) - 1 {
|
||||||
|
// the context has expired, stop
|
||||||
|
// everything
|
||||||
|
close(channel)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// a normal channel has closed, remove
|
||||||
|
// it from the list
|
||||||
|
channels = slices.Delete(channels, chosen, chosen + 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} ()
|
||||||
|
return channel
|
||||||
|
}
|
24
sync/select_test.go
Normal file
24
sync/select_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package usync
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
import "testing"
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
func TestSelect (test *testing.T) {
|
||||||
|
// https://stackoverflow.com/questions/19992334
|
||||||
|
c1 := make(chan int)
|
||||||
|
c2 := make(chan int)
|
||||||
|
c3 := make(chan int)
|
||||||
|
chs := []chan int { c1, c2, c3 }
|
||||||
|
go func () {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
c2 <- 42
|
||||||
|
} ()
|
||||||
|
ctx, done := context.WithTimeout(context.Background(), 5 * time.Second)
|
||||||
|
defer done()
|
||||||
|
|
||||||
|
chosen, val, ok := Select(ctx, chs...)
|
||||||
|
if !ok { test.Fatal("not ok") }
|
||||||
|
if 1 != chosen { test.Fatal("expected 1, got", chosen) }
|
||||||
|
if 42 != val { test.Fatal("expected 42, got", val) }
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user