diff --git a/generate/generate.go b/generate/generate.go index cf60739..6f3a1cd 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -4,6 +4,16 @@ import "io" import "fmt" import "bufio" +const send = +`// Send sends one message along a transaction. +func Send(trans hopp.Trans, message hopp.Message) error { + buffer, err := message.MarshalBinary() + if err != nil { return err } + return trans.Send(message.Method(), buffer) +} + +` + // ResolveType resolves a HOPP type name to a Go type. For now, it supports all // data types defined in TAPE. func (this *Protocol) ResolveType(hopp string) (string, error) { @@ -40,6 +50,9 @@ func (this *Protocol) Generate(writer io.Writer, packag string) error { fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n") fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n") + fmt.Fprintf(out, send) + this.receive(out) + for _, message := range this.Messages { err := this.defineMessage(out, message) if err != nil { return err } @@ -53,6 +66,25 @@ func (this *Protocol) Generate(writer io.Writer, packag string) error { return nil } +func (this *Protocol) receive(out io.Writer) error { + fmt.Fprintf(out, "// Receive receives one message from a transaction.\n") + fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n") + fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n") + fmt.Fprintf(out, "\tif err != nil { return nil, err }\n") + fmt.Fprintf(out, "\tswitch method {\n") + for _, message := range this.Messages { + fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method) + fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name) + fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n") + fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n") + fmt.Fprintf(out, "\t\treturn message, nil\n") + } + fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n") + fmt.Fprintf(out, "\t}\n") + fmt.Fprintf(out, "}\n\n") + return nil +} + func (this *Protocol) defineMessage(out io.Writer, message Message) error { fmt.Fprintf(out, "// (%d) %s\n", message.Method, message.Doc) fmt.Fprintf(out, "type Message%s struct {\n", message.Name) @@ -80,7 +112,7 @@ func (this *Protocol) defineMessage(out io.Writer, message Message) error { func (this *Protocol) marshalMessage(out io.Writer, message Message) error { fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n") - fmt.Fprintf(out, "func (msg Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) + fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) requiredCount := 0 for _, field := range message.Fields { if !field.Optional { requiredCount ++ } @@ -188,7 +220,7 @@ func (this *Protocol) marshalField(out io.Writer, field Field) error { func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error { fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n") fmt.Fprintf(out, - "func (msg Message%s) UnmarshalBinary(buffer []byte) error {\n", + "func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n", message.Name) fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n") fmt.Fprintf(out, "\tif err != nil { return err }\n")