132 lines
4.0 KiB
Go
132 lines
4.0 KiB
Go
package generate
|
|
|
|
import "os"
|
|
import "fmt"
|
|
import "os/exec"
|
|
import "testing"
|
|
import "path/filepath"
|
|
|
|
func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCase string) {
|
|
// reset data directory
|
|
dir := filepath.Join("test", title)
|
|
err := os.RemoveAll(dir)
|
|
if err != nil { test.Fatal(err) }
|
|
err = os.MkdirAll(dir, 0750)
|
|
if err != nil { test.Fatal(err) }
|
|
|
|
// open files
|
|
sourceFile, err := os.Create(filepath.Join(dir, "protocol.go"))
|
|
if err != nil { test.Fatal(err) }
|
|
defer sourceFile.Close()
|
|
mainFile, err := os.Create(filepath.Join(dir, "main.go"))
|
|
if err != nil { test.Fatal(err) }
|
|
defer mainFile.Close()
|
|
|
|
// generate protocol
|
|
generator := Generator {
|
|
Output: sourceFile,
|
|
PackageName: "main",
|
|
}
|
|
_, err = generator.Generate(protocol)
|
|
if err != nil { test.Fatal(err) }
|
|
|
|
// build static source files
|
|
imports = `
|
|
import "log"
|
|
import "bytes"
|
|
import "reflect"
|
|
import "git.tebibyte.media/sashakoshka/hopp/tape"
|
|
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
|
|
` + imports
|
|
setup := `log.Println("*** BEGIN TEST CASE OUTPUT ***")`
|
|
teardown := `log.Println("--- END TEST CASE OUTPUT ---")`
|
|
static := `
|
|
func testEncode(message Message, correct tu.Snake) {
|
|
buffer := bytes.Buffer { }
|
|
encoder := tape.NewEncoder(&buffer)
|
|
n, err := message.Encode(encoder)
|
|
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
|
|
encoder.Flush()
|
|
got := buffer.Bytes()
|
|
log.Printf("got: [%s]", tu.HexBytes(got))
|
|
log.Println("correct:", correct)
|
|
if n != len(got) {
|
|
log.Fatalf("n incorrect: %d != %d\n", n, len(got))
|
|
}
|
|
if ok, n := correct.Check(got); !ok {
|
|
log.Fatalln("not equal at", n)
|
|
}
|
|
}
|
|
|
|
func testDecode(correct Message, data any) {
|
|
var flat []byte
|
|
switch data := data.(type) {
|
|
case []byte: flat = data
|
|
case tu.Snake: flat = data.Flatten()
|
|
}
|
|
message := reflect.New(reflect.ValueOf(correct).Elem().Type()).Interface().(Message)
|
|
log.Println("before: ", message)
|
|
decoder := tape.NewDecoder(bytes.NewBuffer(flat))
|
|
n, err := message.Decode(decoder)
|
|
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
|
|
log.Println("got: ", message)
|
|
log.Println("correct:", correct)
|
|
if n != len(flat) {
|
|
log.Fatalf("n incorrect: %d != %d\n", n, len(flat))
|
|
}
|
|
if !reflect.DeepEqual(message, correct) {
|
|
log.Fatalln("not equal")
|
|
}
|
|
}
|
|
|
|
// TODO: possibly combine the two above functions into this one,
|
|
// also take a data parameter here (snake)
|
|
func testEncodeDecode(message Message, data tu.Snake) {buffer := bytes.Buffer { }
|
|
log.Println("encoding:")
|
|
encoder := tape.NewEncoder(&buffer)
|
|
n, err := message.Encode(encoder)
|
|
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
|
|
encoder.Flush()
|
|
got := buffer.Bytes()
|
|
log.Printf("got: [%s]", tu.HexBytes(got))
|
|
log.Println("correct:", data)
|
|
if n != len(got) {
|
|
log.Fatalf("n incorrect: %d != %d\n", n, len(got))
|
|
}
|
|
if ok, n := data.Check(got); !ok {
|
|
log.Fatalln("not equal at", n)
|
|
}
|
|
|
|
log.Println("decoding:")
|
|
destination := reflect.New(reflect.ValueOf(message).Elem().Type()).Interface().(Message)
|
|
flat := data.Flatten()
|
|
log.Println("before: ", destination)
|
|
decoder := tape.NewDecoder(bytes.NewBuffer(flat))
|
|
n, err = destination.Decode(decoder)
|
|
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
|
|
log.Println("got: ", destination)
|
|
log.Println("correct:", message)
|
|
if n != len(flat) {
|
|
log.Fatalf("n incorrect: %d != %d\n", n, len(flat))
|
|
}
|
|
if !reflect.DeepEqual(destination, message) {
|
|
log.Fatalln("not equal")
|
|
}
|
|
|
|
}
|
|
`
|
|
fmt.Fprintf(
|
|
mainFile, "package main\n%s\nfunc main() {\n%s\n%s\n%s\n}\n%s",
|
|
imports, setup, testCase, teardown, static)
|
|
|
|
// build and run test
|
|
command := exec.Command("go", "run", "./" + filepath.Join("generate", dir))
|
|
workingDirAbs, err := filepath.Abs("..")
|
|
if err != nil { test.Fatal(err) }
|
|
command.Dir = workingDirAbs
|
|
command.Env = os.Environ()
|
|
output, err := command.CombinedOutput()
|
|
test.Logf("output of %v:\n%s", command, output)
|
|
if err != nil { test.Fatal(err) }
|
|
}
|