internal/testutil: Add ConnRecorder which records net.Conn writes

This commit is contained in:
Sasha Koshka 2025-11-19 13:10:26 -05:00
parent ad930144cf
commit 3136dcbfdf
3 changed files with 91 additions and 2 deletions

View File

@ -57,9 +57,9 @@ func runTrans(conn hopp.Conn, trans hopp.Trans) {
return return
} }
switch message := message.(type) { switch message := message.(type) {
case *ping.MessagePing: case ping.MessagePing:
log.Printf("--> ping (%d) from %v", message, conn.RemoteAddr()) log.Printf("--> ping (%d) from %v", message, conn.RemoteAddr())
response := ping.MessagePong(*message) response := ping.MessagePong(message)
_, err := ping.Send(trans, &response) _, err := ping.Send(trans, &response)
if err != nil { if err != nil {
log.Printf("XXX failed to send message: %v", err) log.Printf("XXX failed to send message: %v", err)

View File

@ -0,0 +1,45 @@
package testutil
import "net"
import "fmt"
import "strings"
var _ net.Conn = new(ConnRecorder)
// ConnRecorder records write/flush actions performed on a net.Conn.
type ConnRecorder struct {
net.Conn
// A []byte means data was written, and untyped nil
// means data was flushed.
Log []any
}
func RecordConn(underlying net.Conn) *ConnRecorder {
return &ConnRecorder {
Conn: underlying,
}
}
func (this *ConnRecorder) Write(data []byte) (n int, err error) {
this.Log = append(this.Log, data)
return len(data), nil
}
func (this *ConnRecorder) Flush() error {
this.Log = append(this.Log, nil)
return nil
}
func (this *ConnRecorder) Dump() string {
builder := strings.Builder { }
for index, item := range this.Log {
fmt.Fprintf(&builder, "%06d ", index)
switch item := item.(type) {
case nil:
fmt.Fprintln(&builder, "FLUSH")
case []byte:
fmt.Fprintln(&builder, HexBytes(item))
}
}
return builder.String()
}

View File

@ -0,0 +1,44 @@
package testutil
import "net"
import "testing"
func TestConnRecorder(test *testing.T) {
// server
listener, err := net.Listen("tcp", "localhost:9999")
if err != nil { test.Fatal(err) }
defer listener.Close()
go func() {
conn, err := listener.Accept()
defer conn.Close()
if err != nil { test.Fatal(err) }
buf := [16]byte { }
for {
_, err := conn.Read(buf[:])
if err != nil { break }
}
}()
// client
conn, err := net.Dial("tcp", "localhost:9999")
if err != nil { test.Fatal(err) }
defer conn.Close()
recorder := RecordConn(conn)
_, err = recorder.Write([]byte("hello"))
if err != nil { test.Fatal(err) }
_, err = recorder.Write([]byte("world!"))
if err != nil { test.Fatal(err) }
err = recorder.Flush()
if err != nil { test.Fatal(err) }
test.Log("GOT:\n" + recorder.Dump())
if len(recorder.Log) != 3 { test.Fatal("wrong length") }
if string(recorder.Log[0].([]byte)) != "hello" {
test.Fatal("not equal")
}
if string(recorder.Log[1].([]byte)) != "world!" {
test.Fatal("not equal")
}
}