diff --git a/generate/generate.go b/generate/generate.go index e5ed365..72f3822 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -1,6 +1,258 @@ -package protocol +package generate -// Generate turns this protocol into code. -func (this *Protocol) Generate(writer io.Writer) error { - // TODO +import "io" +import "fmt" +import "bufio" + +// ResolveType resolves a HOPP type name to a Go type. For now, it supports all +// data types defined in TAPE. +func (this *Protocol) ResolveType(hopp string) (string, error) { + switch hopp { + case "I8": return "int8", nil + case "I16": return "int16", nil + case "I32": return "int32", nil + case "I64": return "int64", nil + case "U8": return "uint8", nil + case "U16": return "uint16", nil + case "U32": return "uint32", nil + case "U64": return "uint64", nil + case "I8Array": return "[]int8", nil + case "I16Array": return "[]int16", nil + case "I32Array": return "[]int32", nil + case "I64Array": return "[]int64", nil + case "U8Array": return "[]uint8", nil + case "U16Array": return "[]uint16", nil + case "U32Array": return "[]uint32", nil + case "U64Array": return "[]uint64", nil + case "String": return "string", nil + case "StringArray": return "[]string", nil + default: return "", fmt.Errorf("unknown type: %s", hopp) + } +} + +// Generate turns this protocol into code. The package name for the generated +// code must be specified. +func (this *Protocol) Generate(writer io.Writer, packag string) error { + out := bufio.NewWriter(writer) + defer out.Flush() + + fmt.Fprintf(out, "package %s\n\n", packag) + fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n") + fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n") + + for _, message := range this.Messages { + err := this.defineMessage(out, message) + if err != nil { return err } + err = this.marshalMessage(out, message) + if err != nil { return err } + err = this.unmarshalMessage(out, message) + if err != nil { return err } + } + + // TODO + return nil +} + +func (this *Protocol) defineMessage(out io.Writer, message Message) error { + fmt.Fprintf(out, "// (%d) %s\n", message.Method, message.Doc) + fmt.Fprintf(out, "type Message%s struct {\n", message.Name) + for _, field := range message.Fields { + typ, err := this.ResolveType(field.Type) + if err != nil { return err } + if field.Doc != "" { + fmt.Fprintf(out, "\t// %s\n", field.Doc) + } + if field.Optional { + typ = fmt.Sprintf("hopp.Option[%s]", typ) + } + fmt.Fprintf( + out, "\t/* %d */ %s %s\n", + field.Tag, field.Name, typ) + } + fmt.Fprintf(out, "}\n\n") + + fmt.Fprintf(out, "// Method returns the method number of the message.\n") + fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name) + fmt.Fprintf(out, "\treturn %d\n", message.Method) + fmt.Fprintf(out, "}\n\n") + return nil +} + +func (this *Protocol) marshalMessage(out io.Writer, message Message) error { + fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n") + fmt.Fprintf(out, "func (msg Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) + requiredCount := 0 + for _, field := range message.Fields { + if !field.Optional { requiredCount ++ } + } + fmt.Fprintf(out, "\tsize := 0\n") + fmt.Fprintf(out, "\tcount := %d\n", requiredCount) + for _, field := range message.Fields { + fmt.Fprintf(out, "\toffset%s := size\n", field.Name) + if field.Optional { + fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) + fmt.Fprintf(out, "\t\tcount ++\n") + fmt.Fprintf(out, "\t\t") + err := this.marshalSizeOf(out, field) + if err != nil { return err } + fmt.Fprintf(out, " }\n") + } else { + fmt.Fprintf(out, "\t{") + fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) + fmt.Fprintf(out, "\t\t") + err := this.marshalSizeOf(out, field) + if err != nil { return err } + fmt.Fprintf(out, " }\n") + } + } + fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") + fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") + fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n") + fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n") + for _, field := range message.Fields { + if field.Optional { + fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) + fmt.Fprintf(out, "\t\t") + err := this.marshalField(out, field) + if err != nil { return err } + fmt.Fprintf(out, "}\n") + } else { + fmt.Fprintf(out, "\t{") + fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) + fmt.Fprintf(out, "\t\t") + err := this.marshalField(out, field) + if err != nil { return err } + fmt.Fprintf(out, "}\n") + } + } + fmt.Fprintf(out, "\treturn buffer, nil\n") + fmt.Fprintf(out, "}\n\n") + return nil +} + +func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error { + switch field.Type { + case "I8": fmt.Fprintf(out, "size += 1; _ = value") + case "I16": fmt.Fprintf(out, "size += 2; _ = value") + case "I32": fmt.Fprintf(out, "size += 4; _ = value") + case "I64": fmt.Fprintf(out, "size += 8; _ = value") + case "U8": fmt.Fprintf(out, "size += 1; _ = value") + case "U16": fmt.Fprintf(out, "size += 2; _ = value") + case "U32": fmt.Fprintf(out, "size += 4; _ = value") + case "U64": fmt.Fprintf(out, "size += 8; _ = value") + case "I8Array": fmt.Fprintf(out, "size += len(value)") + case "I16Array": fmt.Fprintf(out, "size += len(value) * 2") + case "I32Array": fmt.Fprintf(out, "size += len(value) * 4") + case "I64Array": fmt.Fprintf(out, "size += len(value) * 8") + case "U8Array": fmt.Fprintf(out, "size += len(value)") + case "U16Array": fmt.Fprintf(out, "size += len(value) * 2") + case "U32Array": fmt.Fprintf(out, "size += len(value) * 4") + case "U64Array": fmt.Fprintf(out, "size += len(value) * 8") + case "String": fmt.Fprintf(out, "size += len(value)") + case "StringArray": + fmt.Fprintf( + out, + "for _, el := range value { size += 2 + len(el) }") + default: + return fmt.Errorf("unknown type: %s", field.Type) + } + return nil +} + +func (this *Protocol) marshalField(out io.Writer, field Field) error { + switch field.Type { + case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) + case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) + case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) + case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) + case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) + case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) + case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) + case "U64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) + case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) + case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) + case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) + case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) + case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) + case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) + case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) + case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) + case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name) + case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name) + default: + return fmt.Errorf("unknown type: %s", field.Type) + } + return nil +} + +func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error { + fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n") + fmt.Fprintf(out, + "func (msg Message%s) UnmarshalBinary(buffer []byte) error {\n", + message.Name) + fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n") + fmt.Fprintf(out, "\tif err != nil { return err }\n") + requiredTotal := 0 + for _, field := range message.Fields { + if field.Optional { + requiredTotal ++ + } + } + if requiredTotal > 0 { + fmt.Fprintf(out, "\tfoundRequired := 0\n") + } + fmt.Fprintf(out, "\tfor tag, data := range pairs {\n") + fmt.Fprintf(out, "\t\tswitch tag {\n") + for _, field := range message.Fields { + fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag) + fmt.Fprintf(out, "\t\t\t") + err := this.unmarshalField(out, field) + if err != nil { return err } + fmt.Fprintf(out, "\n") + fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n") + if field.Optional { + fmt.Fprintf(out, "\t\t\tmsg.%s.Set(value)\n", field.Name) + } else { + fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name) + fmt.Fprintf(out, "\t\t\tfoundRequired ++\n") + } + } + fmt.Fprintf(out, "\t\t}\n") + fmt.Fprintf(out, "\t}\n") + if requiredTotal > 0 { + fmt.Fprintf(out, + "\tif foundRequired != %d { return hopp.ErrPairMissing }\n", + requiredTotal) + } + fmt.Fprintf(out, "\treturn nil\n") + fmt.Fprintf(out, "}\n\n") + return nil +} + +func (this *Protocol) unmarshalField(out io.Writer, field Field) error { + typ, err := this.ResolveType(field.Type) + if err != nil { return err } + switch field.Type { + case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) + case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) + case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) + case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) + case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) + case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) + case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) + case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) + case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) + case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) + case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) + case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) + case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) + case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) + case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) + case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) + case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ) + case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ) + default: + return fmt.Errorf("unknown type: %s", field.Type) + } + return nil }