diff --git a/generate/misc_test.go b/generate/misc_test.go index 463d13f..c4e293f 100644 --- a/generate/misc_test.go +++ b/generate/misc_test.go @@ -6,9 +6,9 @@ import "os/exec" import "testing" import "path/filepath" -func testGenerateRun(test *testing.T, protocol *Protocol, imports string, testCase string) { +func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCase string) { // reset data directory - dir := "test/generate-run" + dir := filepath.Join("test", title) err := os.RemoveAll(dir) if err != nil { test.Fatal(err) } err = os.MkdirAll(dir, 0750) @@ -58,14 +58,21 @@ func testGenerateRun(test *testing.T, protocol *Protocol, imports string, testCa } } - func testDecode(data []byte, message Message, correct Message) { - decoder := tape.NewDecoder(bytes.NewBuffer(data)) + func testDecode(correct Message, data any) { + var flat []byte + switch data := data.(type) { + case []byte: flat = data + case tu.Snake: flat = data.Flatten() + } + message := reflect.New(reflect.ValueOf(correct).Elem().Type()).Interface().(Message) + log.Println("before: ", message) + decoder := tape.NewDecoder(bytes.NewBuffer(flat)) n, err := message.Decode(decoder) if err != nil { log.Fatalf("at %d: %v\n", n, err) } log.Println("got: ", message) log.Println("correct:", correct) - if n != len(data) { - log.Fatalf("n incorrect: %d != %d\n", n, len(data)) + if n != len(flat) { + log.Fatalf("n incorrect: %d != %d\n", n, len(flat)) } if !reflect.DeepEqual(message, correct) { log.Fatalln("not equal") @@ -77,7 +84,7 @@ func testGenerateRun(test *testing.T, protocol *Protocol, imports string, testCa imports, setup, testCase, teardown, static) // build and run test - command := exec.Command("go", "run", "./generate/test/generate-run") + command := exec.Command("go", "run", "./" + filepath.Join("generate", dir)) workingDirAbs, err := filepath.Abs("..") if err != nil { test.Fatal(err) } command.Dir = workingDirAbs