package generate

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

// send is the source text of the Send helper. Generate copies it into the
// generated file right after the import lines. It is a raw string, so the
// line breaks and tabs below are part of the emitted code.
const send = `// Send sends one message along a transaction.
func Send(trans hopp.Trans, message hopp.Message) error {
	buffer, err := message.MarshalBinary()
	if err != nil { return err }
	return trans.Send(message.Method(), buffer)
}

`

// ResolveType resolves a HOPP type name to a Go type. For now, it supports all
// data types defined in TAPE: the integer types I8 through U64, String, and
// the array form of each of them (the same name followed by "Array"), which
// resolves to a slice of the element type. Any other name yields an error.
func (this *Protocol) ResolveType(hopp string) (string, error) {
	// An array type is its element type behind a single "Array" suffix. Only
	// one suffix is stripped, so nested names such as "I8ArrayArray" fall
	// through to the default case below.
	element, isArray := hopp, false
	if strings.HasSuffix(hopp, "Array") {
		element, isArray = strings.TrimSuffix(hopp, "Array"), true
	}

	var resolved string
	switch element {
	case "I8":
		resolved = "int8"
	case "I16":
		resolved = "int16"
	case "I32":
		resolved = "int32"
	case "I64":
		resolved = "int64"
	case "U8":
		resolved = "uint8"
	case "U16":
		resolved = "uint16"
	case "U32":
		resolved = "uint32"
	case "U64":
		resolved = "uint64"
	case "String":
		resolved = "string"
	default:
		return "", fmt.Errorf("unknown type: %s", hopp)
	}

	if isArray {
		return "[]" + resolved, nil
	}
	return resolved, nil
}

// Generate turns this protocol into code. The package name for the generated
// code must be specified.
func (this *Protocol) Generate(writer io.Writer, packag string) error { out := bufio.NewWriter(writer) defer out.Flush() fmt.Fprintf(out, "package %s\n\n", packag) fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n") fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n") fmt.Fprintf(out, send) this.receive(out) for _, message := range this.Messages { err := this.defineMessage(out, message) if err != nil { return err } err = this.marshalMessage(out, message) if err != nil { return err } err = this.unmarshalMessage(out, message) if err != nil { return err } } return nil } func (this *Protocol) receive(out io.Writer) error { fmt.Fprintf(out, "// Receive receives one message from a transaction.\n") fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n") fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n") fmt.Fprintf(out, "\tif err != nil { return nil, err }\n") fmt.Fprintf(out, "\tswitch method {\n") for _, message := range this.Messages { fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method) fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name) fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n") fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n") fmt.Fprintf(out, "\t\treturn message, nil\n") } fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n") fmt.Fprintf(out, "\t}\n") fmt.Fprintf(out, "}\n\n") return nil } func (this *Protocol) defineMessage(out io.Writer, message Message) error { fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc))) fmt.Fprintf(out, "type Message%s struct {\n", message.Name) for _, field := range message.Fields { typ, err := this.ResolveType(field.Type) if err != nil { return err } if field.Doc != "" { fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc)) } if field.Optional { typ = fmt.Sprintf("hopp.Option[%s]", typ) } fmt.Fprintf( out, "\t/* %d */ %s %s\n", field.Tag, field.Name, typ) } 
fmt.Fprintf(out, "}\n\n") fmt.Fprintf(out, "// Method returns the method number of the message.\n") fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name) fmt.Fprintf(out, "\treturn %d\n", message.Method) fmt.Fprintf(out, "}\n\n") return nil } func (this *Protocol) marshalMessage(out io.Writer, message Message) error { fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n") fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) requiredCount := 0 for _, field := range message.Fields { if !field.Optional { requiredCount ++ } } fmt.Fprintf(out, "\tsize := 0\n") fmt.Fprintf(out, "\tcount := %d\n", requiredCount) for _, field := range message.Fields { fmt.Fprintf(out, "\toffset%s := size\n", field.Name) if field.Optional { fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) fmt.Fprintf(out, "\t\tcount ++\n") fmt.Fprintf(out, "\t\t") err := this.marshalSizeOf(out, field) if err != nil { return err } fmt.Fprintf(out, " }\n") } else { fmt.Fprintf(out, "\t{") fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) fmt.Fprintf(out, "\t\t") err := this.marshalSizeOf(out, field) if err != nil { return err } fmt.Fprintf(out, " }\n") } } fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n") fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n") for _, field := range message.Fields { if field.Optional { fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) fmt.Fprintf(out, "\t\t") err := this.marshalField(out, field) if err != nil { return err } fmt.Fprintf(out, "}\n") } else { fmt.Fprintf(out, "\t{") fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) fmt.Fprintf(out, "\t\t") err := this.marshalField(out, field) if err != nil { return err } fmt.Fprintf(out, "}\n") } } 
fmt.Fprintf(out, "\treturn buffer, nil\n") fmt.Fprintf(out, "}\n\n") return nil } func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error { switch field.Type { case "I8": fmt.Fprintf(out, "size += 1; _ = value") case "I16": fmt.Fprintf(out, "size += 2; _ = value") case "I32": fmt.Fprintf(out, "size += 4; _ = value") case "I64": fmt.Fprintf(out, "size += 8; _ = value") case "U8": fmt.Fprintf(out, "size += 1; _ = value") case "U16": fmt.Fprintf(out, "size += 2; _ = value") case "U32": fmt.Fprintf(out, "size += 4; _ = value") case "U64": fmt.Fprintf(out, "size += 8; _ = value") case "I8Array": fmt.Fprintf(out, "size += len(value)") case "I16Array": fmt.Fprintf(out, "size += len(value) * 2") case "I32Array": fmt.Fprintf(out, "size += len(value) * 4") case "I64Array": fmt.Fprintf(out, "size += len(value) * 8") case "U8Array": fmt.Fprintf(out, "size += len(value)") case "U16Array": fmt.Fprintf(out, "size += len(value) * 2") case "U32Array": fmt.Fprintf(out, "size += len(value) * 4") case "U64Array": fmt.Fprintf(out, "size += len(value) * 8") case "String": fmt.Fprintf(out, "size += len(value)") case "StringArray": fmt.Fprintf( out, "for _, el := range value { size += 2 + len(el) }") default: return fmt.Errorf("unknown type: %s", field.Type) } return nil } func (this *Protocol) marshalField(out io.Writer, field Field) error { switch field.Type { case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) case "U64": fmt.Fprintf(out, 
"tape.EncodeI64(buffer[offset%s:], value)", field.Name) case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name) case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name) default: return fmt.Errorf("unknown type: %s", field.Type) } return nil } func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error { fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n") fmt.Fprintf(out, "func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n", message.Name) if len(message.Fields) < 1 { fmt.Fprintf(out, "\t// no fields\n") fmt.Fprintf(out, "\treturn nil\n") fmt.Fprintf(out, "}\n\n") return nil } fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n") fmt.Fprintf(out, "\tif err != nil { return err }\n") requiredTotal := 0 for _, field := range message.Fields { if field.Optional { requiredTotal ++ } } if requiredTotal > 0 { fmt.Fprintf(out, "\tfoundRequired := 0\n") } fmt.Fprintf(out, "\tfor tag, data := range pairs {\n") fmt.Fprintf(out, "\t\tswitch tag {\n") for _, field := range message.Fields { fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag) fmt.Fprintf(out, "\t\t\t") err := 
this.unmarshalField(out, field) if err != nil { return err } fmt.Fprintf(out, "\n") fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n") if field.Optional { fmt.Fprintf(out, "\t\t\tmsg.%s = hopp.O(value)\n", field.Name) } else { fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name) if requiredTotal > 0 { fmt.Fprintf(out, "\t\t\tfoundRequired ++\n") } } } fmt.Fprintf(out, "\t\t}\n") fmt.Fprintf(out, "\t}\n") if requiredTotal > 0 { fmt.Fprintf(out, "\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n", requiredTotal) } fmt.Fprintf(out, "\treturn nil\n") fmt.Fprintf(out, "}\n\n") return nil } func (this *Protocol) unmarshalField(out io.Writer, field Field) error { typ, err := this.ResolveType(field.Type) if err != nil { return err } switch field.Type { case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) case "U64Array": 
fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ) case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ) default: return fmt.Errorf("unknown type: %s", field.Type) } return nil } func comment(prefix, text string) string { return prefix + " " + strings.ReplaceAll(strings.TrimSpace(text), "\n", "\n" + prefix + " ") }