diff --git a/generate/misc_test.go b/generate/misc_test.go index acf86f4..11e1de3 100644 --- a/generate/misc_test.go +++ b/generate/misc_test.go @@ -37,6 +37,7 @@ func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCa import "reflect" import "git.tebibyte.media/sashakoshka/hopp/tape" import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil" + import "git.tebibyte.media/sashakoshka/hopp/internal/testutil/snake" ` + imports setup := `log.Println("*** BEGIN TEST CASE OUTPUT ***")` teardown := `log.Println("--- END TEST CASE OUTPUT ---")` @@ -61,8 +62,9 @@ func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCa func testDecode(correct Message, data any) { var flat []byte switch data := data.(type) { - case []byte: flat = data - case tu.Snake: flat = data.Flatten() + case []byte: flat = data + case tu.Snake: flat = data.Flatten() + case snake.Snake: flat = data.Flatten() } message := reflect.New(reflect.ValueOf(correct).Elem().Type()).Interface().(Message) log.Println("before: ", message) @@ -79,9 +81,7 @@ func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCa } } - // TODO: possibly combine the two above functions into this one, - // also take a data parameter here (snake) - func testEncodeDecode(message Message, data tu.Snake) {buffer := bytes.Buffer { } + func testEncodeDecode(message Message, data any) {buffer := bytes.Buffer { } log.Println("encoding:") encoder := tape.NewEncoder(&buffer) n, err := message.Encode(encoder) @@ -93,13 +93,30 @@ func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCa if n != len(got) { log.Fatalf("n incorrect: %d != %d\n", n, len(got)) } - if ok, n := data.Check(got); !ok { - log.Fatalln("not equal at", n) + + var flat []byte + switch data := data.(type) { + case []byte: + flat = data + if ok, n := snake.Check(snake.L(data...), got); !ok { + log.Fatalln("not equal at", n) + } + case tu.Snake: + flat = data.Flatten() + if ok, n := data.Check(got); !ok { + log.Fatalln("not equal at", n) + } + case snake.Node: + flat = data.Flatten() + if ok, n := snake.Check(data, got); !ok { + log.Fatalln("not equal at", n) + } + default: + panic("AUSIAUGH AAAUUGUHGHGHH OUHGHGJDSGK") } log.Println("decoding:") destination := reflect.New(reflect.ValueOf(message).Elem().Type()).Interface().(Message) - flat := data.Flatten() log.Println("before: ", tu.Describe(destination)) decoder := tape.NewDecoder(bytes.NewBuffer(flat)) n, err = destination.Decode(decoder)