sync: Add Select, Combine
This commit is contained in:
		
							parent
							
								
									9fd40a37b8
								
							
						
					
					
						commit
						09aa19e7c1
					
				
							
								
								
									
										78
									
								
								sync/select.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								sync/select.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,78 @@ | |||||||
|  | package usync | ||||||
|  | 
 | ||||||
|  | import "slices" | ||||||
|  | import "context" | ||||||
|  | import "reflect" | ||||||
|  | 
 | ||||||
|  | // A type-safe wrapper around reflect.Select. Taken from: | ||||||
|  | // https://stackoverflow.com/questions/19992334 | ||||||
|  | func Select[T any] (ctx context.Context, channels ...chan T) (int, T, bool) { | ||||||
|  | 	var zero T | ||||||
|  | 
 | ||||||
|  | 	// add all channels as select cases | ||||||
|  | 	cases := make([]reflect.SelectCase, len(channels) + 1) | ||||||
|  | 	for i, ch := range channels { | ||||||
|  | 		cases[i] = reflect.SelectCase { | ||||||
|  | 			Dir:  reflect.SelectRecv, | ||||||
|  | 			Chan: reflect.ValueOf(ch), | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// add ctx.Done() as another select case to stop listening when the | ||||||
|  | 	// context is closed | ||||||
|  | 	cases[len(channels)] = reflect.SelectCase { | ||||||
|  | 		Dir:  reflect.SelectRecv, | ||||||
|  | 		Chan: reflect.ValueOf(ctx.Done()), | ||||||
|  | 	} | ||||||
|  | 	 | ||||||
|  | 	// read from the channel | ||||||
|  | 	chosen, value, ok := reflect.Select(cases) | ||||||
|  | 	if !ok { | ||||||
|  | 		if ctx.Err() != nil { | ||||||
|  | 			return -1, zero, false | ||||||
|  | 		} | ||||||
|  | 		return chosen, zero, false | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// cast return value | ||||||
|  | 	if ret, ok := value.Interface().(T); ok { | ||||||
|  | 		return chosen, ret, true | ||||||
|  | 	} | ||||||
|  | 	return chosen, zero, false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Combine returns a channel that continuously returns the result of the select | ||||||
|  | // until all input sources are exhauste, or the context is canceled. | ||||||
|  | func Combine[T any] (ctx context.Context, channels ...chan T) <- chan T { | ||||||
|  | 	channel := make(chan T) | ||||||
|  | 	// our silly slection routine | ||||||
|  | 	go func () { | ||||||
|  | 		for { | ||||||
|  | 			if len(channels) < 2 { | ||||||
|  | 				// only the context is left, stop everything | ||||||
|  | 				close(channel) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			// read new value | ||||||
|  | 			chosen, value, ok := Select(ctx, channels...) | ||||||
|  | 			if ok { | ||||||
|  | 				// we have a value | ||||||
|  | 				channel <- value | ||||||
|  | 			} else { | ||||||
|  | 				// a channel has been closed and we need to do | ||||||
|  | 				// something about it | ||||||
|  | 				if chosen == len(channels) - 1 { | ||||||
|  | 					// the context has expired, stop | ||||||
|  | 					// everything | ||||||
|  | 					close(channel) | ||||||
|  | 					return | ||||||
|  | 				} else { | ||||||
|  | 					// a normal channel has closed, remove | ||||||
|  | 					// it from the list | ||||||
|  | 					channels = slices.Delete(channels, chosen, chosen + 1) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} () | ||||||
|  | 	return channel | ||||||
|  | } | ||||||
							
								
								
									
										24
									
								
								sync/select_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								sync/select_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,24 @@ | |||||||
|  | package usync | ||||||
|  | 
 | ||||||
|  | import "time" | ||||||
|  | import "testing" | ||||||
|  | import "context" | ||||||
|  | 
 | ||||||
|  | func TestSelect (test *testing.T) { | ||||||
|  | 	// https://stackoverflow.com/questions/19992334 | ||||||
|  | 	c1 := make(chan int) | ||||||
|  | 	c2 := make(chan int) | ||||||
|  | 	c3 := make(chan int) | ||||||
|  | 	chs := []chan int { c1, c2, c3 } | ||||||
|  | 	go func () { | ||||||
|  | 		time.Sleep(time.Second) | ||||||
|  | 		c2 <- 42 | ||||||
|  | 	} () | ||||||
|  | 	ctx, done := context.WithTimeout(context.Background(), 5 * time.Second) | ||||||
|  | 	defer done() | ||||||
|  | 
 | ||||||
|  | 	chosen, val, ok := Select(ctx, chs...) | ||||||
|  | 	if !ok          { test.Fatal("not ok")                  } | ||||||
|  | 	if 1  != chosen { test.Fatal("expected 1, got", chosen) } | ||||||
|  | 	if 42 != val    { test.Fatal("expected 42, got", val)   } | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user