297 lines
12 KiB
Go
297 lines
12 KiB
Go
package generate
|
|
|
|
import "io"
|
|
import "fmt"
|
|
import "bufio"
|
|
import "strings"
|
|
|
|
// send holds the source text of the generated Send function, which Generate
// writes into every generated file right after the package and import
// clauses. The generated Send marshals a hopp.Message with MarshalBinary and
// hands the resulting bytes to trans.Send together with message.Method().
const send =
|
|
`// Send sends one message along a transaction.
|
|
func Send(trans hopp.Trans, message hopp.Message) error {
|
|
buffer, err := message.MarshalBinary()
|
|
if err != nil { return err }
|
|
return trans.Send(message.Method(), buffer)
|
|
}
|
|
|
|
`
|
|
|
|
// ResolveType resolves a HOPP type name to a Go type. For now, it supports all
|
|
// data types defined in TAPE.
|
|
func (this *Protocol) ResolveType(hopp string) (string, error) {
|
|
switch hopp {
|
|
case "I8": return "int8", nil
|
|
case "I16": return "int16", nil
|
|
case "I32": return "int32", nil
|
|
case "I64": return "int64", nil
|
|
case "U8": return "uint8", nil
|
|
case "U16": return "uint16", nil
|
|
case "U32": return "uint32", nil
|
|
case "U64": return "uint64", nil
|
|
case "I8Array": return "[]int8", nil
|
|
case "I16Array": return "[]int16", nil
|
|
case "I32Array": return "[]int32", nil
|
|
case "I64Array": return "[]int64", nil
|
|
case "U8Array": return "[]uint8", nil
|
|
case "U16Array": return "[]uint16", nil
|
|
case "U32Array": return "[]uint32", nil
|
|
case "U64Array": return "[]uint64", nil
|
|
case "String": return "string", nil
|
|
case "StringArray": return "[]string", nil
|
|
default: return "", fmt.Errorf("unknown type: %s", hopp)
|
|
}
|
|
}
|
|
|
|
// Generate turns this protocol into code. The package name for the generated
|
|
// code must be specified.
|
|
func (this *Protocol) Generate(writer io.Writer, packag string) error {
|
|
out := bufio.NewWriter(writer)
|
|
defer out.Flush()
|
|
|
|
fmt.Fprintf(out, "package %s\n\n", packag)
|
|
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n")
|
|
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n")
|
|
|
|
fmt.Fprintf(out, send)
|
|
this.receive(out)
|
|
|
|
for _, message := range this.Messages {
|
|
err := this.defineMessage(out, message)
|
|
if err != nil { return err }
|
|
err = this.marshalMessage(out, message)
|
|
if err != nil { return err }
|
|
err = this.unmarshalMessage(out, message)
|
|
if err != nil { return err }
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) receive(out io.Writer) error {
|
|
fmt.Fprintf(out, "// Receive receives one message from a transaction.\n")
|
|
fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n")
|
|
fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n")
|
|
fmt.Fprintf(out, "\tif err != nil { return nil, err }\n")
|
|
fmt.Fprintf(out, "\tswitch method {\n")
|
|
for _, message := range this.Messages {
|
|
fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method)
|
|
fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name)
|
|
fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n")
|
|
fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n")
|
|
fmt.Fprintf(out, "\t\treturn message, nil\n")
|
|
}
|
|
fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n")
|
|
fmt.Fprintf(out, "\t}\n")
|
|
fmt.Fprintf(out, "}\n\n")
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) defineMessage(out io.Writer, message Message) error {
|
|
fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc)))
|
|
fmt.Fprintf(out, "type Message%s struct {\n", message.Name)
|
|
for _, field := range message.Fields {
|
|
typ, err := this.ResolveType(field.Type)
|
|
if err != nil { return err }
|
|
if field.Doc != "" {
|
|
fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc))
|
|
}
|
|
if field.Optional {
|
|
typ = fmt.Sprintf("hopp.Option[%s]", typ)
|
|
}
|
|
fmt.Fprintf(
|
|
out, "\t/* %d */ %s %s\n",
|
|
field.Tag, field.Name, typ)
|
|
}
|
|
fmt.Fprintf(out, "}\n\n")
|
|
|
|
fmt.Fprintf(out, "// Method returns the method number of the message.\n")
|
|
fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name)
|
|
fmt.Fprintf(out, "\treturn %d\n", message.Method)
|
|
fmt.Fprintf(out, "}\n\n")
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) marshalMessage(out io.Writer, message Message) error {
|
|
fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n")
|
|
fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name)
|
|
requiredCount := 0
|
|
for _, field := range message.Fields {
|
|
if !field.Optional { requiredCount ++ }
|
|
}
|
|
fmt.Fprintf(out, "\tsize := 0\n")
|
|
fmt.Fprintf(out, "\tcount := %d\n", requiredCount)
|
|
for _, field := range message.Fields {
|
|
fmt.Fprintf(out, "\toffset%s := size\n", field.Name)
|
|
if field.Optional {
|
|
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name)
|
|
fmt.Fprintf(out, "\t\tcount ++\n")
|
|
fmt.Fprintf(out, "\t\t")
|
|
err := this.marshalSizeOf(out, field)
|
|
if err != nil { return err }
|
|
fmt.Fprintf(out, " }\n")
|
|
} else {
|
|
fmt.Fprintf(out, "\t{")
|
|
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name)
|
|
fmt.Fprintf(out, "\t\t")
|
|
err := this.marshalSizeOf(out, field)
|
|
if err != nil { return err }
|
|
fmt.Fprintf(out, " }\n")
|
|
}
|
|
}
|
|
fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
|
|
fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
|
|
fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n")
|
|
fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n")
|
|
for _, field := range message.Fields {
|
|
if field.Optional {
|
|
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name)
|
|
fmt.Fprintf(out, "\t\t")
|
|
err := this.marshalField(out, field)
|
|
if err != nil { return err }
|
|
fmt.Fprintf(out, "}\n")
|
|
} else {
|
|
fmt.Fprintf(out, "\t{")
|
|
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name)
|
|
fmt.Fprintf(out, "\t\t")
|
|
err := this.marshalField(out, field)
|
|
if err != nil { return err }
|
|
fmt.Fprintf(out, "}\n")
|
|
}
|
|
}
|
|
fmt.Fprintf(out, "\treturn buffer, nil\n")
|
|
fmt.Fprintf(out, "}\n\n")
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error {
|
|
switch field.Type {
|
|
case "I8": fmt.Fprintf(out, "size += 1; _ = value")
|
|
case "I16": fmt.Fprintf(out, "size += 2; _ = value")
|
|
case "I32": fmt.Fprintf(out, "size += 4; _ = value")
|
|
case "I64": fmt.Fprintf(out, "size += 8; _ = value")
|
|
case "U8": fmt.Fprintf(out, "size += 1; _ = value")
|
|
case "U16": fmt.Fprintf(out, "size += 2; _ = value")
|
|
case "U32": fmt.Fprintf(out, "size += 4; _ = value")
|
|
case "U64": fmt.Fprintf(out, "size += 8; _ = value")
|
|
case "I8Array": fmt.Fprintf(out, "size += len(value)")
|
|
case "I16Array": fmt.Fprintf(out, "size += len(value) * 2")
|
|
case "I32Array": fmt.Fprintf(out, "size += len(value) * 4")
|
|
case "I64Array": fmt.Fprintf(out, "size += len(value) * 8")
|
|
case "U8Array": fmt.Fprintf(out, "size += len(value)")
|
|
case "U16Array": fmt.Fprintf(out, "size += len(value) * 2")
|
|
case "U32Array": fmt.Fprintf(out, "size += len(value) * 4")
|
|
case "U64Array": fmt.Fprintf(out, "size += len(value) * 8")
|
|
case "String": fmt.Fprintf(out, "size += len(value)")
|
|
case "StringArray":
|
|
fmt.Fprintf(
|
|
out,
|
|
"for _, el := range value { size += 2 + len(el) }")
|
|
default:
|
|
return fmt.Errorf("unknown type: %s", field.Type)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) marshalField(out io.Writer, field Field) error {
|
|
switch field.Type {
|
|
case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name)
|
|
case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name)
|
|
case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name)
|
|
case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name)
|
|
case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name)
|
|
case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name)
|
|
case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name)
|
|
case "U64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name)
|
|
case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name)
|
|
case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name)
|
|
case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name)
|
|
case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name)
|
|
case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name)
|
|
case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name)
|
|
case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name)
|
|
case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name)
|
|
case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name)
|
|
case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name)
|
|
default:
|
|
return fmt.Errorf("unknown type: %s", field.Type)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error {
|
|
fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n")
|
|
fmt.Fprintf(out,
|
|
"func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n",
|
|
message.Name)
|
|
fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n")
|
|
fmt.Fprintf(out, "\tif err != nil { return err }\n")
|
|
requiredTotal := 0
|
|
for _, field := range message.Fields {
|
|
if field.Optional {
|
|
requiredTotal ++
|
|
}
|
|
}
|
|
if requiredTotal > 0 {
|
|
fmt.Fprintf(out, "\tfoundRequired := 0\n")
|
|
}
|
|
fmt.Fprintf(out, "\tfor tag, data := range pairs {\n")
|
|
fmt.Fprintf(out, "\t\tswitch tag {\n")
|
|
for _, field := range message.Fields {
|
|
fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag)
|
|
fmt.Fprintf(out, "\t\t\t")
|
|
err := this.unmarshalField(out, field)
|
|
if err != nil { return err }
|
|
fmt.Fprintf(out, "\n")
|
|
fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n")
|
|
if field.Optional {
|
|
fmt.Fprintf(out, "\t\t\tmsg.%s.Set(value)\n", field.Name)
|
|
} else {
|
|
fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name)
|
|
if requiredTotal > 0 {
|
|
fmt.Fprintf(out, "\t\t\tfoundRequired ++\n")
|
|
}
|
|
}
|
|
}
|
|
fmt.Fprintf(out, "\t\t}\n")
|
|
fmt.Fprintf(out, "\t}\n")
|
|
if requiredTotal > 0 {
|
|
fmt.Fprintf(out,
|
|
"\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n",
|
|
requiredTotal)
|
|
}
|
|
fmt.Fprintf(out, "\treturn nil\n")
|
|
fmt.Fprintf(out, "}\n\n")
|
|
return nil
|
|
}
|
|
|
|
func (this *Protocol) unmarshalField(out io.Writer, field Field) error {
|
|
typ, err := this.ResolveType(field.Type)
|
|
if err != nil { return err }
|
|
switch field.Type {
|
|
case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
|
|
case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
|
|
case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
|
|
case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
|
|
case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
|
|
case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
|
|
case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
|
|
case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
|
|
case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
|
|
case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
|
|
case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
|
|
case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
|
|
case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
|
|
case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
|
|
case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
|
|
case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
|
|
case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ)
|
|
case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ)
|
|
default:
|
|
return fmt.Errorf("unknown type: %s", field.Type)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// comment formats text as a Go comment. Leading and trailing whitespace is
// trimmed, and each remaining line is prefixed with prefix and a single space.
// Lines are joined with "\n", with no trailing newline. Empty text yields
// prefix + " ".
func comment(prefix, text string) string {
	lines := strings.Split(strings.TrimSpace(text), "\n")
	for i, line := range lines {
		lines[i] = prefix + " " + line
	}
	return strings.Join(lines, "\n")
}
|