hopp/generate/generate.go

303 lines
12 KiB
Go
Raw Permalink Normal View History

2025-01-21 14:22:48 -07:00
package generate
import "io"
import "fmt"
import "bufio"
2025-01-22 14:56:05 -07:00
import "strings"
2025-01-21 14:22:48 -07:00
// send is the source text of the generated Send function. Generate writes it
// into the generated file right after the imports. The generated Send
// marshals a message with MarshalBinary and passes the resulting bytes to
// trans.Send together with the message's method number.
const send =
`// Send sends one message along a transaction.
func Send(trans hopp.Trans, message hopp.Message) error {
buffer, err := message.MarshalBinary()
if err != nil { return err }
return trans.Send(message.Method(), buffer)
}
`
2025-01-21 14:22:48 -07:00
// ResolveType maps a HOPP type name to the Go type used to represent it.
// The integer types I8..I64 and U8..U64 become int8..int64 and uint8..uint64,
// and String becomes string. Appending "Array" to any of these names yields
// the corresponding Go slice type (for example U16Array becomes []uint16).
// Any other name produces an "unknown type" error.
func (this *Protocol) ResolveType(hopp string) (string, error) {
	isArray := strings.HasSuffix(hopp, "Array")
	base := strings.TrimSuffix(hopp, "Array")

	var goType string
	switch base {
	case "I8":
		goType = "int8"
	case "I16":
		goType = "int16"
	case "I32":
		goType = "int32"
	case "I64":
		goType = "int64"
	case "U8":
		goType = "uint8"
	case "U16":
		goType = "uint16"
	case "U32":
		goType = "uint32"
	case "U64":
		goType = "uint64"
	case "String":
		goType = "string"
	default:
		// Report the name exactly as given, suffix included.
		return "", fmt.Errorf("unknown type: %s", hopp)
	}

	if isArray {
		goType = "[]" + goType
	}
	return goType, nil
}
// Generate turns this protocol into code. The package name for the generated
// code must be specified.
func (this *Protocol) Generate(writer io.Writer, packag string) error {
out := bufio.NewWriter(writer)
defer out.Flush()
fmt.Fprintf(out, "package %s\n\n", packag)
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n")
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n")
fmt.Fprintf(out, send)
this.receive(out)
2025-01-21 14:22:48 -07:00
for _, message := range this.Messages {
err := this.defineMessage(out, message)
if err != nil { return err }
err = this.marshalMessage(out, message)
if err != nil { return err }
err = this.unmarshalMessage(out, message)
if err != nil { return err }
}
2025-01-09 00:31:15 -07:00
2025-01-21 14:22:48 -07:00
return nil
}
// receive writes the generated Receive function. The generated function reads
// one method number and payload from a transaction, then switches on the
// method number. For each message in the protocol, it allocates the matching
// Message type and fills it with UnmarshalBinary. An unrecognized method
// number yields hopp.ErrUnknownMethod. receive itself always returns nil.
func (this *Protocol) receive(out io.Writer) error {
	const prologue = "// Receive receives one message from a transaction.\n" +
		"func Receive(trans hopp.Trans) (hopp.Message, error) {\n" +
		"\tmethod, data, err := trans.Receive()\n" +
		"\tif err != nil { return nil, err }\n" +
		"\tswitch method {\n"
	const caseTemplate = "\tcase 0x%04X:\n" +
		"\t\tmessage := &Message%s { }\n" +
		"\t\terr := message.UnmarshalBinary(data)\n" +
		"\t\tif err != nil { return nil, err }\n" +
		"\t\treturn message, nil\n"
	const epilogue = "\tdefault: return nil, hopp.ErrUnknownMethod\n" +
		"\t}\n" +
		"}\n\n"

	fmt.Fprint(out, prologue)
	for _, msg := range this.Messages {
		fmt.Fprintf(out, caseTemplate, msg.Method, msg.Name)
	}
	fmt.Fprint(out, epilogue)
	return nil
}
2025-01-21 14:22:48 -07:00
func (this *Protocol) defineMessage(out io.Writer, message Message) error {
2025-01-22 14:56:05 -07:00
fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc)))
2025-01-21 14:22:48 -07:00
fmt.Fprintf(out, "type Message%s struct {\n", message.Name)
for _, field := range message.Fields {
typ, err := this.ResolveType(field.Type)
if err != nil { return err }
if field.Doc != "" {
2025-01-22 14:56:05 -07:00
fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc))
2025-01-21 14:22:48 -07:00
}
if field.Optional {
typ = fmt.Sprintf("hopp.Option[%s]", typ)
}
fmt.Fprintf(
out, "\t/* %d */ %s %s\n",
field.Tag, field.Name, typ)
}
fmt.Fprintf(out, "}\n\n")
fmt.Fprintf(out, "// Method returns the method number of the message.\n")
fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name)
fmt.Fprintf(out, "\treturn %d\n", message.Method)
fmt.Fprintf(out, "}\n\n")
return nil
}
func (this *Protocol) marshalMessage(out io.Writer, message Message) error {
fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n")
fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name)
2025-01-21 14:22:48 -07:00
requiredCount := 0
for _, field := range message.Fields {
if !field.Optional { requiredCount ++ }
}
fmt.Fprintf(out, "\tsize := 0\n")
fmt.Fprintf(out, "\tcount := %d\n", requiredCount)
for _, field := range message.Fields {
fmt.Fprintf(out, "\toffset%s := size\n", field.Name)
if field.Optional {
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name)
fmt.Fprintf(out, "\t\tcount ++\n")
fmt.Fprintf(out, "\t\t")
err := this.marshalSizeOf(out, field)
if err != nil { return err }
fmt.Fprintf(out, " }\n")
} else {
fmt.Fprintf(out, "\t{")
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name)
fmt.Fprintf(out, "\t\t")
err := this.marshalSizeOf(out, field)
if err != nil { return err }
fmt.Fprintf(out, " }\n")
}
}
fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n")
fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n")
for _, field := range message.Fields {
if field.Optional {
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name)
fmt.Fprintf(out, "\t\t")
err := this.marshalField(out, field)
if err != nil { return err }
fmt.Fprintf(out, "}\n")
} else {
fmt.Fprintf(out, "\t{")
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name)
fmt.Fprintf(out, "\t\t")
err := this.marshalField(out, field)
if err != nil { return err }
fmt.Fprintf(out, "}\n")
}
}
fmt.Fprintf(out, "\treturn buffer, nil\n")
fmt.Fprintf(out, "}\n\n")
return nil
}
// marshalSizeOf writes one statement that adds the encoded length of the
// generated variable value to the generated variable size, chosen by
// field.Type:
//   - fixed-width integers add their width and reference value, so a scalar
//     that is only measured is not reported as an unused variable;
//   - integer arrays add element count times element width;
//   - String adds its byte length;
//   - StringArray adds each element's byte length plus 2 bytes per element
//     (presumably a length prefix).
//
// An unrecognized type yields an "unknown type" error.
func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error {
	var statement string
	switch field.Type {
	case "I8", "U8":
		statement = "size += 1; _ = value"
	case "I16", "U16":
		statement = "size += 2; _ = value"
	case "I32", "U32":
		statement = "size += 4; _ = value"
	case "I64", "U64":
		statement = "size += 8; _ = value"
	case "I8Array", "U8Array", "String":
		statement = "size += len(value)"
	case "I16Array", "U16Array":
		statement = "size += len(value) * 2"
	case "I32Array", "U32Array":
		statement = "size += len(value) * 4"
	case "I64Array", "U64Array":
		statement = "size += len(value) * 8"
	case "StringArray":
		statement = "for _, el := range value { size += 2 + len(el) }"
	default:
		return fmt.Errorf("unknown type: %s", field.Type)
	}
	fmt.Fprint(out, statement)
	return nil
}
// marshalField writes one call that encodes the generated variable value into
// buffer, starting at this field's generated offset variable (offset<Name>).
// It uses the tape encoder that matches field.Type; signed and unsigned types
// of the same width share an encoder. An unrecognized type yields an
// "unknown type" error.
func (this *Protocol) marshalField(out io.Writer, field Field) error {
	var encoder string
	switch field.Type {
	case "I8", "U8":
		encoder = "EncodeI8"
	case "I16", "U16":
		encoder = "EncodeI16"
	case "I32", "U32":
		encoder = "EncodeI32"
	case "I64", "U64":
		encoder = "EncodeI64"
	case "I8Array", "U8Array":
		encoder = "EncodeI8Array"
	case "I16Array", "U16Array":
		encoder = "EncodeI16Array"
	case "I32Array", "U32Array":
		encoder = "EncodeI32Array"
	case "I64Array", "U64Array":
		encoder = "EncodeI64Array"
	case "String":
		encoder = "EncodeString"
	case "StringArray":
		encoder = "EncodeStringArray"
	default:
		return fmt.Errorf("unknown type: %s", field.Type)
	}
	fmt.Fprintf(out, "tape.%s(buffer[offset%s:], value)", encoder, field.Name)
	return nil
}
func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error {
fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n")
fmt.Fprintf(out,
"func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n",
2025-01-21 14:22:48 -07:00
message.Name)
2025-01-22 15:40:43 -07:00
if len(message.Fields) < 1 {
fmt.Fprintf(out, "\t// no fields\n")
fmt.Fprintf(out, "\treturn nil\n")
fmt.Fprintf(out, "}\n\n")
return nil
}
2025-01-21 14:22:48 -07:00
fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n")
fmt.Fprintf(out, "\tif err != nil { return err }\n")
requiredTotal := 0
for _, field := range message.Fields {
if field.Optional {
requiredTotal ++
}
}
if requiredTotal > 0 {
fmt.Fprintf(out, "\tfoundRequired := 0\n")
}
fmt.Fprintf(out, "\tfor tag, data := range pairs {\n")
fmt.Fprintf(out, "\t\tswitch tag {\n")
for _, field := range message.Fields {
fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag)
fmt.Fprintf(out, "\t\t\t")
err := this.unmarshalField(out, field)
if err != nil { return err }
fmt.Fprintf(out, "\n")
fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n")
if field.Optional {
2025-01-22 15:40:43 -07:00
fmt.Fprintf(out, "\t\t\tmsg.%s = hopp.O(value)\n", field.Name)
2025-01-21 14:22:48 -07:00
} else {
fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name)
2025-01-22 14:56:05 -07:00
if requiredTotal > 0 {
fmt.Fprintf(out, "\t\t\tfoundRequired ++\n")
}
2025-01-21 14:22:48 -07:00
}
}
fmt.Fprintf(out, "\t\t}\n")
fmt.Fprintf(out, "\t}\n")
if requiredTotal > 0 {
fmt.Fprintf(out,
2025-01-21 14:37:50 -07:00
"\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n",
2025-01-21 14:22:48 -07:00
requiredTotal)
}
fmt.Fprintf(out, "\treturn nil\n")
fmt.Fprintf(out, "}\n\n")
return nil
}
func (this *Protocol) unmarshalField(out io.Writer, field Field) error {
typ, err := this.ResolveType(field.Type)
if err != nil { return err }
switch field.Type {
case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ)
case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ)
default:
return fmt.Errorf("unknown type: %s", field.Type)
}
return nil
2025-01-09 00:31:15 -07:00
}
2025-01-22 14:56:05 -07:00
// comment turns text into a block of line comments. Leading and trailing
// whitespace is trimmed from text, then every remaining line (including empty
// ones) is prefixed with prefix and a single space. The lines are joined with
// "\n" and there is no trailing newline.
func comment(prefix, text string) string {
	lines := strings.Split(strings.TrimSpace(text), "\n")
	for i, line := range lines {
		lines[i] = prefix + " " + line
	}
	return strings.Join(lines, "\n")
}