459 lines
13 KiB
Go
459 lines
13 KiB
Go
package hopp
|
|
|
|
import "io"
|
|
import "net"
|
|
import "sync"
|
|
import "bytes"
|
|
import "errors"
|
|
import "slices"
|
|
import "testing"
|
|
import "context"
|
|
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
|
|
|
|
// Some of these tests spawn goroutines that can signal a failure. When writing
// or changing them, abide by the documentation for testing.T
// (https://pkg.go.dev/testing#T):
//
//	A test ends when its Test function returns or calls any of the methods
//	FailNow, Fatal, Fatalf, SkipNow, Skip, or Skipf. Those methods, as well as
//	the Parallel method, must be called only from the goroutine running the
//	Test function.
//
//	The other reporting methods, such as the variations of Log and Error, may
//	be called simultaneously from multiple goroutines.
|
|
|
|
func TestConnA(test *testing.T) {
|
|
payloads := []string {
|
|
"hello",
|
|
"world",
|
|
"When the impostor is sus!",
|
|
}
|
|
|
|
clientFunc := func(a Conn) {
|
|
test.Log("CLIENT accepting transaction")
|
|
trans, err := a.AcceptTrans()
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
test.Log("CLIENT accepted transaction")
|
|
test.Cleanup(func() { trans.Close() })
|
|
for method, payload := range payloads {
|
|
test.Log("CLIENT waiting...")
|
|
gotMethod, gotPayloadBytes, err := trans.Receive()
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
gotPayload := string(gotPayloadBytes)
|
|
test.Log("CLIENT m:", gotMethod, "p:", gotPayload)
|
|
if int(gotMethod) != method {
|
|
test.Errorf("CLIENT method not equal")
|
|
}
|
|
if gotPayload != payload {
|
|
test.Errorf("CLIENT payload not equal")
|
|
}
|
|
}
|
|
test.Log("CLIENT waiting for transaction close...")
|
|
gotMethod, gotPayload, err := trans.Receive()
|
|
if !errors.Is(err, io.EOF) {
|
|
test.Error("CLIENT wrong error:", err)
|
|
test.Error("CLIENT method:", gotMethod)
|
|
test.Error("CLIENT payload:", gotPayload)
|
|
test.Fatal("CLIENT ok byeeeeeeeeeeeee")
|
|
}
|
|
test.Log("CLIENT transaction has closed")
|
|
}
|
|
|
|
serverFunc := func(a Conn) {
|
|
trans, err := a.OpenTrans()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Cleanup(func() { trans.Close() })
|
|
for method, payload := range payloads {
|
|
test.Log("SERVER m:", method, "p:", payload)
|
|
err := trans.Send(uint16(method), []byte(payload))
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
}
|
|
test.Log("SERVER closing connection")
|
|
}
|
|
|
|
clientServerEnvironment(test, clientFunc, serverFunc)
|
|
}
|
|
|
|
func TestTransOpenCloseA(test *testing.T) {
|
|
clientFunc := func(conn Conn) {
|
|
// 10
|
|
trans, err := conn.OpenTrans()
|
|
if err != nil { test.Error("CLIENT", err); return }
|
|
test.Log("CLIENT sending 10")
|
|
trans.Send(10, []byte("hi"))
|
|
trans.Close()
|
|
|
|
// 20
|
|
test.Log("CLIENT awaiting 20")
|
|
trans, err = conn.AcceptTrans()
|
|
if err != nil { test.Error("CLIENT", err); return }
|
|
test.Cleanup(func() { trans.Close() })
|
|
gotMethod, gotPayload, err := trans.Receive()
|
|
if err != nil { test.Error("CLIENT", err); return }
|
|
test.Logf("CLIENT m: %d p: %s", gotMethod, gotPayload)
|
|
if gotMethod != 20 { test.Error("CLIENT wrong method")}
|
|
|
|
// 30
|
|
trans, err = conn.OpenTrans()
|
|
if err != nil { test.Error("CLIENT", err); return }
|
|
test.Log("CLIENT sending 30")
|
|
trans.Send(30, []byte("good"))
|
|
trans.Close()
|
|
}
|
|
|
|
serverFunc := func(conn Conn) {
|
|
// 10
|
|
test.Log("SERVER awaiting 10")
|
|
trans, err := conn.AcceptTrans()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Cleanup(func() { trans.Close() })
|
|
gotMethod, gotPayload, err := trans.Receive()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload)
|
|
if gotMethod != 10 { test.Error("SERVER wrong method")}
|
|
|
|
// 20
|
|
trans, err = conn.OpenTrans()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Log("SERVER sending 20")
|
|
trans.Send(20, []byte("hi how r u"))
|
|
trans.Close()
|
|
|
|
// 30
|
|
test.Log("SERVER awaiting 30")
|
|
trans, err = conn.AcceptTrans()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Cleanup(func() { trans.Close() })
|
|
gotMethod, gotPayload, err = trans.Receive()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload)
|
|
if gotMethod != 30 { test.Error("SERVER wrong method")}
|
|
}
|
|
|
|
clientServerEnvironment(test, clientFunc, serverFunc)
|
|
}
|
|
|
|
func TestReadWriteA(test *testing.T) {
|
|
payloads := []string {
|
|
"hello",
|
|
"world",
|
|
"When the impostor is sus!",
|
|
}
|
|
|
|
clientFunc := func(a Conn) {
|
|
test.Log("CLIENT accepting transaction")
|
|
trans, err := a.AcceptTrans()
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
test.Log("CLIENT accepted transaction")
|
|
test.Cleanup(func() { trans.Close() })
|
|
for method, payload := range payloads {
|
|
test.Log("CLIENT waiting...")
|
|
gotMethod, gotReader, err := trans.ReceiveReader()
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
gotPayloadBytes, err := io.ReadAll(gotReader)
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
gotPayload := string(gotPayloadBytes)
|
|
test.Log("CLIENT m:", gotMethod, "p:", gotPayload)
|
|
if int(gotMethod) != method {
|
|
test.Errorf("CLIENT method not equal")
|
|
}
|
|
if gotPayload != payload {
|
|
test.Errorf("CLIENT payload not equal")
|
|
}
|
|
}
|
|
test.Log("CLIENT waiting for transaction close...")
|
|
gotMethod, gotPayload, err := trans.Receive()
|
|
if !errors.Is(err, io.EOF) {
|
|
test.Error("CLIENT wrong error:", err)
|
|
test.Error("CLIENT method:", gotMethod)
|
|
test.Error("CLIENT payload:", gotPayload)
|
|
test.Fatal("CLIENT ok byeeeeeeeeeeeee")
|
|
}
|
|
test.Log("CLIENT transaction has closed")
|
|
}
|
|
|
|
serverFunc := func(a Conn) {
|
|
defer test.Log("SERVER closing connection")
|
|
trans, err := a.OpenTrans()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
test.Cleanup(func() { trans.Close() })
|
|
for method, payload := range payloads {
|
|
test.Log("SERVER m:", method, "p:", payload)
|
|
func() {
|
|
writer, err := trans.SendWriter(uint16(method))
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
defer writer.Close()
|
|
_, err = writer.Write([]byte(payload))
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
}()
|
|
}
|
|
}
|
|
|
|
clientServerEnvironment(test, clientFunc, serverFunc)
|
|
}
|
|
|
|
func TestEncodeMessageA(test *testing.T) {
|
|
buffer := new(bytes.Buffer)
|
|
payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
|
|
err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, 0, payload)
|
|
correct := []byte {
|
|
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
|
|
0x6B, 0x12,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
|
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
|
|
}
|
|
if err != nil {
|
|
test.Fatal(err)
|
|
}
|
|
if got, correct := buffer.Bytes(), correct; !slices.Equal(got, correct) {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
}
|
|
|
|
func TestEncodeMessageAErr(test *testing.T) {
|
|
buffer := new(bytes.Buffer)
|
|
payload := make([]byte, 0x10000)
|
|
err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, 0, payload)
|
|
if !errors.Is(err, ErrPayloadTooLarge) {
|
|
test.Fatalf("wrong error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDecodeMessageA(test *testing.T) {
|
|
transID, method, _, payload, err := decodeMessageA(bytes.NewReader([]byte {
|
|
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
|
|
0x6B, 0x12,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
|
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
|
|
}), defaultSizeLimit)
|
|
if err != nil {
|
|
test.Fatal(err)
|
|
}
|
|
if got, correct := transID, int64(0x5800FEABC3104F04); got != correct {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
if got, correct := method, uint16(0x6B12); got != correct {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
correctPayload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
|
|
if got, correct := payload, correctPayload; !slices.Equal(got, correct) {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
}
|
|
|
|
func TestDecodeMessageAErr(test *testing.T) {
|
|
_, _, _, _, err := decodeMessageA(bytes.NewReader([]byte {
|
|
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
|
|
0x6B, 0x12,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x06,
|
|
0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
|
|
}), defaultSizeLimit)
|
|
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
|
test.Fatalf("wrong error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestEncodeDecodeMessageA(test *testing.T) {
|
|
correctTransID := int64(2)
|
|
correctMethod := uint16(30)
|
|
correctPayload := []byte("good")
|
|
buffer := bytes.Buffer { }
|
|
err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, 0, correctPayload)
|
|
if err != nil { test.Fatal(err) }
|
|
transID, method, chunked, payload, err := decodeMessageA(&buffer, defaultSizeLimit)
|
|
if got, correct := transID, int64(2); got != correct {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
if got, correct := method, uint16(30); got != correct {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
if chunked {
|
|
test.Fatalf("message should not be chunked")
|
|
}
|
|
if got, correct := payload, correctPayload; !slices.Equal(got, correct) {
|
|
test.Fatalf("not equal: %v %v", got, correct)
|
|
}
|
|
}
|
|
|
|
func TestConsecutiveWrite(test *testing.T) {
|
|
packets := [][]byte {
|
|
[]byte {
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
|
0x00, 0x00,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
|
|
0x43, 0x00, 0x00, 0x00, 0x07 },
|
|
|
|
[]byte {
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
|
0x00, 0x00,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
|
|
0x43, 0x00, 0x00, 0x00, 0x08 },
|
|
}
|
|
payloads := [][]byte {
|
|
[]byte { 0x43, 0x00, 0x00, 0x00, 0x07 },
|
|
[]byte { 0x43, 0x00, 0x00, 0x00, 0x08 },
|
|
}
|
|
|
|
var group sync.WaitGroup
|
|
group.Add(2)
|
|
|
|
// server
|
|
listener, err := net.Listen("tcp", "localhost:9999")
|
|
if err != nil { test.Fatal("SERVER", err) }
|
|
go func() {
|
|
defer group.Done()
|
|
defer listener.Close()
|
|
conn, err := listener.Accept()
|
|
if err != nil { test.Fatal("SERVER", err) }
|
|
defer conn.Close()
|
|
|
|
buf := [16]byte { }
|
|
for {
|
|
_, err := conn.Read(buf[:])
|
|
if err != nil { break }
|
|
}
|
|
}()
|
|
|
|
// client
|
|
go func() {
|
|
defer group.Done()
|
|
conn, err := net.Dial("tcp", "localhost:9999")
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
defer conn.Close()
|
|
recorder := tu.RecordConn(conn)
|
|
|
|
a := AdaptA(recorder, ClientSide)
|
|
trans, err := a.OpenTrans()
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
|
|
for _, payload := range payloads {
|
|
err := trans.Send(0x0000, payload)
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
}
|
|
|
|
test.Log("CLIENT recorded output:\n" + recorder.Dump())
|
|
if len(recorder.Log) != 2 { test.Fatal("wrong length") }
|
|
if !slices.Equal(recorder.Log[0].([]byte), packets[0]) {
|
|
test.Fatal("not equal")
|
|
}
|
|
if !slices.Equal(recorder.Log[1].([]byte), packets[1]) {
|
|
test.Fatal("not equal")
|
|
}
|
|
}()
|
|
|
|
group.Wait()
|
|
}
|
|
|
|
func TestConsecutiveRead(test *testing.T) {
|
|
stream := []byte {
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
|
0x00, 0x00,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
|
|
0x43, 0x00, 0x00, 0x00, 0x07,
|
|
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
|
|
0x00, 0x00,
|
|
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05,
|
|
0x43, 0x00, 0x00, 0x00, 0x08,
|
|
}
|
|
payloads := [][]byte {
|
|
[]byte { 0x43, 0x00, 0x00, 0x00, 0x07 },
|
|
[]byte { 0x43, 0x00, 0x00, 0x00, 0x08 },
|
|
}
|
|
|
|
var group sync.WaitGroup
|
|
group.Add(2)
|
|
|
|
// server
|
|
listener, err := net.Listen("tcp", "localhost:9999")
|
|
if err != nil { test.Fatal("SERVER", err) }
|
|
go func() {
|
|
defer group.Done()
|
|
defer listener.Close()
|
|
conn, err := listener.Accept()
|
|
if err != nil { test.Fatal("SERVER", err) }
|
|
defer conn.Close()
|
|
|
|
a := AdaptA(conn, ServerSide)
|
|
trans, err := a.AcceptTrans()
|
|
if err != nil { test.Fatal("SERVER", err) }
|
|
index := 0
|
|
for {
|
|
method, data, err := trans.Receive()
|
|
if err != nil {
|
|
if !errors.Is(err, io.EOF) {
|
|
test.Fatal("SERVER", err)
|
|
}
|
|
break
|
|
}
|
|
test.Logf("SERVER GOT: M%04X %s", method, tu.HexBytes(data))
|
|
if index >= len(payloads) {
|
|
test.Fatalf(
|
|
"SERVER we weren't supposed to receive %d messages",
|
|
index + 1)
|
|
}
|
|
if method != 0 {
|
|
test.Fatal("SERVER", "method not equal")
|
|
}
|
|
if !slices.Equal(data, payloads[index]) {
|
|
test.Fatal("SERVER", "data not equal")
|
|
}
|
|
index ++
|
|
}
|
|
if index != len(payloads) {
|
|
test.Fatalf(
|
|
"SERVER we weren't supposed to receive %d messages",
|
|
index + 1)
|
|
}
|
|
}()
|
|
|
|
// client
|
|
go func() {
|
|
defer group.Done()
|
|
conn, err := net.Dial("tcp", "localhost:9999")
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
defer conn.Close()
|
|
_, err = conn.Write(stream)
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
}()
|
|
|
|
group.Wait()
|
|
}
|
|
|
|
func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), serverFunc func(conn Conn)) {
|
|
network := "tcp"
|
|
addr := "localhost:7959"
|
|
|
|
// server
|
|
listener, err := Listen(network, addr, nil)
|
|
test.Cleanup(func() { listener.Close() })
|
|
go func() {
|
|
test.Log("SERVER listening")
|
|
conn, err := listener.Accept()
|
|
if err != nil { test.Error("SERVER", err); return }
|
|
|
|
defer conn.Close()
|
|
test.Cleanup(func() { conn.Close() })
|
|
|
|
serverFunc(conn)
|
|
test.Log("SERVER closing")
|
|
}()
|
|
|
|
// client
|
|
test.Log("CLIENT dialing")
|
|
conn, err := Dial(context.Background(), network, addr, nil)
|
|
if err != nil { test.Fatal("CLIENT", err) }
|
|
test.Cleanup(func() { conn.Close() })
|
|
test.Log("CLIENT dialed")
|
|
|
|
clientFunc(conn)
|
|
|
|
test.Log("CLIENT waiting for connection close...")
|
|
trans, err := conn.AcceptTrans()
|
|
if !errors.Is(err, io.EOF) {
|
|
test.Error("CLIENT wrong error:", err)
|
|
test.Fatal("CLIENT trans:", trans)
|
|
}
|
|
test.Log("CLIENT DONE")
|
|
conn.Close()
|
|
}
|