internal/testutil: Add ConnRecorder which records net.Conn writes
This commit is contained in:
parent
ad930144cf
commit
3136dcbfdf
@ -57,9 +57,9 @@ func runTrans(conn hopp.Conn, trans hopp.Trans) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch message := message.(type) {
|
switch message := message.(type) {
|
||||||
case *ping.MessagePing:
|
case ping.MessagePing:
|
||||||
log.Printf("--> ping (%d) from %v", message, conn.RemoteAddr())
|
log.Printf("--> ping (%d) from %v", message, conn.RemoteAddr())
|
||||||
response := ping.MessagePong(*message)
|
response := ping.MessagePong(message)
|
||||||
_, err := ping.Send(trans, &response)
|
_, err := ping.Send(trans, &response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("XXX failed to send message: %v", err)
|
log.Printf("XXX failed to send message: %v", err)
|
||||||
|
|||||||
45
internal/testutil/conn-recorder.go
Normal file
45
internal/testutil/conn-recorder.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
import "fmt"
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
var _ net.Conn = new(ConnRecorder)
|
||||||
|
|
||||||
|
// ConnRecorder records write/flush actions performed on a net.Conn.
|
||||||
|
type ConnRecorder struct {
|
||||||
|
net.Conn
|
||||||
|
// A []byte means data was written, and untyped nil
|
||||||
|
// means data was flushed.
|
||||||
|
Log []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func RecordConn(underlying net.Conn) *ConnRecorder {
|
||||||
|
return &ConnRecorder {
|
||||||
|
Conn: underlying,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ConnRecorder) Write(data []byte) (n int, err error) {
|
||||||
|
this.Log = append(this.Log, data)
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ConnRecorder) Flush() error {
|
||||||
|
this.Log = append(this.Log, nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *ConnRecorder) Dump() string {
|
||||||
|
builder := strings.Builder { }
|
||||||
|
for index, item := range this.Log {
|
||||||
|
fmt.Fprintf(&builder, "%06d ", index)
|
||||||
|
switch item := item.(type) {
|
||||||
|
case nil:
|
||||||
|
fmt.Fprintln(&builder, "FLUSH")
|
||||||
|
case []byte:
|
||||||
|
fmt.Fprintln(&builder, HexBytes(item))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return builder.String()
|
||||||
|
}
|
||||||
44
internal/testutil/conn-recorder_test.go
Normal file
44
internal/testutil/conn-recorder_test.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestConnRecorder(test *testing.T) {
|
||||||
|
// server
|
||||||
|
listener, err := net.Listen("tcp", "localhost:9999")
|
||||||
|
if err != nil { test.Fatal(err) }
|
||||||
|
defer listener.Close()
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
defer conn.Close()
|
||||||
|
if err != nil { test.Fatal(err) }
|
||||||
|
buf := [16]byte { }
|
||||||
|
for {
|
||||||
|
_, err := conn.Read(buf[:])
|
||||||
|
if err != nil { break }
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// client
|
||||||
|
conn, err := net.Dial("tcp", "localhost:9999")
|
||||||
|
if err != nil { test.Fatal(err) }
|
||||||
|
defer conn.Close()
|
||||||
|
recorder := RecordConn(conn)
|
||||||
|
|
||||||
|
_, err = recorder.Write([]byte("hello"))
|
||||||
|
if err != nil { test.Fatal(err) }
|
||||||
|
_, err = recorder.Write([]byte("world!"))
|
||||||
|
if err != nil { test.Fatal(err) }
|
||||||
|
err = recorder.Flush()
|
||||||
|
if err != nil { test.Fatal(err) }
|
||||||
|
|
||||||
|
test.Log("GOT:\n" + recorder.Dump())
|
||||||
|
|
||||||
|
if len(recorder.Log) != 3 { test.Fatal("wrong length") }
|
||||||
|
if string(recorder.Log[0].([]byte)) != "hello" {
|
||||||
|
test.Fatal("not equal")
|
||||||
|
}
|
||||||
|
if string(recorder.Log[1].([]byte)) != "world!" {
|
||||||
|
test.Fatal("not equal")
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user