From 3136dcbfdfede9611a04e394a8e7e5878773003f Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Wed, 19 Nov 2025 13:10:26 -0500 Subject: [PATCH] internal/testutil: Add ConnRecorder which records net.Conn writes --- examples/ping/server/main.go | 4 +-- internal/testutil/conn-recorder.go | 45 +++++++++++++++++++++++++ internal/testutil/conn-recorder_test.go | 44 ++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 internal/testutil/conn-recorder.go create mode 100644 internal/testutil/conn-recorder_test.go diff --git a/examples/ping/server/main.go b/examples/ping/server/main.go index d85cd96..222fc5a 100644 --- a/examples/ping/server/main.go +++ b/examples/ping/server/main.go @@ -57,9 +57,9 @@ func runTrans(conn hopp.Conn, trans hopp.Trans) { return } switch message := message.(type) { - case *ping.MessagePing: + case ping.MessagePing: log.Printf("--> ping (%d) from %v", message, conn.RemoteAddr()) - response := ping.MessagePong(*message) + response := ping.MessagePong(message) _, err := ping.Send(trans, &response) if err != nil { log.Printf("XXX failed to send message: %v", err) diff --git a/internal/testutil/conn-recorder.go b/internal/testutil/conn-recorder.go new file mode 100644 index 0000000..55186b3 --- /dev/null +++ b/internal/testutil/conn-recorder.go @@ -0,0 +1,45 @@ +package testutil + +import "net" +import "fmt" +import "strings" + +var _ net.Conn = new(ConnRecorder) + +// ConnRecorder records write/flush actions performed on a net.Conn. +type ConnRecorder struct { + net.Conn + // A []byte means data was written, and untyped nil + // means data was flushed. + Log []any +} + +func RecordConn(underlying net.Conn) *ConnRecorder { + return &ConnRecorder { + Conn: underlying, + } +} + +func (this *ConnRecorder) Write(data []byte) (n int, err error) { + this.Log = append(this.Log, data) + return len(data), nil +} + +func (this *ConnRecorder) Flush() error { + this.Log = append(this.Log, nil) + return nil +} + +func (this *ConnRecorder) Dump() string { + builder := strings.Builder { } + for index, item := range this.Log { + fmt.Fprintf(&builder, "%06d ", index) + switch item := item.(type) { + case nil: + fmt.Fprintln(&builder, "FLUSH") + case []byte: + fmt.Fprintln(&builder, HexBytes(item)) + } + } + return builder.String() +} diff --git a/internal/testutil/conn-recorder_test.go b/internal/testutil/conn-recorder_test.go new file mode 100644 index 0000000..8dc2575 --- /dev/null +++ b/internal/testutil/conn-recorder_test.go @@ -0,0 +1,44 @@ +package testutil + +import "net" +import "testing" + +func TestConnRecorder(test *testing.T) { + // server + listener, err := net.Listen("tcp", "localhost:9999") + if err != nil { test.Fatal(err) } + defer listener.Close() + go func() { + conn, err := listener.Accept() + defer conn.Close() + if err != nil { test.Fatal(err) } + buf := [16]byte { } + for { + _, err := conn.Read(buf[:]) + if err != nil { break } + } + }() + + // client + conn, err := net.Dial("tcp", "localhost:9999") + if err != nil { test.Fatal(err) } + defer conn.Close() + recorder := RecordConn(conn) + + _, err = recorder.Write([]byte("hello")) + if err != nil { test.Fatal(err) } + _, err = recorder.Write([]byte("world!")) + if err != nil { test.Fatal(err) } + err = recorder.Flush() + if err != nil { test.Fatal(err) } + + test.Log("GOT:\n" + recorder.Dump()) + + if len(recorder.Log) != 3 { test.Fatal("wrong length") } + if string(recorder.Log[0].([]byte)) != "hello" { + test.Fatal("not equal") + } + if string(recorder.Log[1].([]byte)) != "world!" { + test.Fatal("not equal") + } +}