From d67a4fb9f2acb2749452075b20b817053728c44b Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Wed, 4 Jun 2025 10:37:03 -0400 Subject: [PATCH] generate: Clean slate --- generate/generate.go | 302 ------------------------------------------- generate/protocol.go | 244 ---------------------------------- 2 files changed, 546 deletions(-) delete mode 100644 generate/generate.go delete mode 100644 generate/protocol.go diff --git a/generate/generate.go b/generate/generate.go deleted file mode 100644 index 65cbe15..0000000 --- a/generate/generate.go +++ /dev/null @@ -1,302 +0,0 @@ -package generate - -import "io" -import "fmt" -import "bufio" -import "strings" - -const send = -`// Send sends one message along a transaction. -func Send(trans hopp.Trans, message hopp.Message) error { - buffer, err := message.MarshalBinary() - if err != nil { return err } - return trans.Send(message.Method(), buffer) -} - -` - -// ResolveType resolves a HOPP type name to a Go type. For now, it supports all -// data types defined in TAPE. -func (this *Protocol) ResolveType(hopp string) (string, error) { - switch hopp { - case "I8": return "int8", nil - case "I16": return "int16", nil - case "I32": return "int32", nil - case "I64": return "int64", nil - case "U8": return "uint8", nil - case "U16": return "uint16", nil - case "U32": return "uint32", nil - case "U64": return "uint64", nil - case "I8Array": return "[]int8", nil - case "I16Array": return "[]int16", nil - case "I32Array": return "[]int32", nil - case "I64Array": return "[]int64", nil - case "U8Array": return "[]uint8", nil - case "U16Array": return "[]uint16", nil - case "U32Array": return "[]uint32", nil - case "U64Array": return "[]uint64", nil - case "String": return "string", nil - case "StringArray": return "[]string", nil - default: return "", fmt.Errorf("unknown type: %s", hopp) - } -} - -// Generate turns this protocol into code. The package name for the generated -// code must be specified. -func (this *Protocol) Generate(writer io.Writer, packag string) error { - out := bufio.NewWriter(writer) - defer out.Flush() - - fmt.Fprintf(out, "package %s\n\n", packag) - fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n") - fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n") - - fmt.Fprintf(out, send) - this.receive(out) - - for _, message := range this.Messages { - err := this.defineMessage(out, message) - if err != nil { return err } - err = this.marshalMessage(out, message) - if err != nil { return err } - err = this.unmarshalMessage(out, message) - if err != nil { return err } - } - - return nil -} - -func (this *Protocol) receive(out io.Writer) error { - fmt.Fprintf(out, "// Receive receives one message from a transaction.\n") - fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n") - fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n") - fmt.Fprintf(out, "\tif err != nil { return nil, err }\n") - fmt.Fprintf(out, "\tswitch method {\n") - for _, message := range this.Messages { - fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method) - fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name) - fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n") - fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n") - fmt.Fprintf(out, "\t\treturn message, nil\n") - } - fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n") - fmt.Fprintf(out, "\t}\n") - fmt.Fprintf(out, "}\n\n") - return nil -} - -func (this *Protocol) defineMessage(out io.Writer, message Message) error { - fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc))) - fmt.Fprintf(out, "type Message%s struct {\n", message.Name) - for _, field := range message.Fields { - typ, err := this.ResolveType(field.Type) - if err != nil { return err } - if field.Doc != "" { - fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc)) - } - if field.Optional { - typ = fmt.Sprintf("hopp.Option[%s]", typ) - } - fmt.Fprintf( - out, "\t/* %d */ %s %s\n", - field.Tag, field.Name, typ) - } - fmt.Fprintf(out, "}\n\n") - - fmt.Fprintf(out, "// Method returns the method number of the message.\n") - fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name) - fmt.Fprintf(out, "\treturn %d\n", message.Method) - fmt.Fprintf(out, "}\n\n") - return nil -} - -func (this *Protocol) marshalMessage(out io.Writer, message Message) error { - fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n") - fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) - requiredCount := 0 - for _, field := range message.Fields { - if !field.Optional { requiredCount ++ } - } - fmt.Fprintf(out, "\tsize := 0\n") - fmt.Fprintf(out, "\tcount := %d\n", requiredCount) - for _, field := range message.Fields { - fmt.Fprintf(out, "\toffset%s := size\n", field.Name) - if field.Optional { - fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) - fmt.Fprintf(out, "\t\tcount ++\n") - fmt.Fprintf(out, "\t\t") - err := this.marshalSizeOf(out, field) - if err != nil { return err } - fmt.Fprintf(out, " }\n") - } else { - fmt.Fprintf(out, "\t{") - fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) - fmt.Fprintf(out, "\t\t") - err := this.marshalSizeOf(out, field) - if err != nil { return err } - fmt.Fprintf(out, " }\n") - } - } - fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") - fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") - fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n") - fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n") - for _, field := range message.Fields { - if field.Optional { - fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) - fmt.Fprintf(out, "\t\t") - err := this.marshalField(out, field) - if err != nil { return err } - fmt.Fprintf(out, "}\n") - } else { - fmt.Fprintf(out, "\t{") - fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) - fmt.Fprintf(out, "\t\t") - err := this.marshalField(out, field) - if err != nil { return err } - fmt.Fprintf(out, "}\n") - } - } - fmt.Fprintf(out, "\treturn buffer, nil\n") - fmt.Fprintf(out, "}\n\n") - return nil -} - -func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error { - switch field.Type { - case "I8": fmt.Fprintf(out, "size += 1; _ = value") - case "I16": fmt.Fprintf(out, "size += 2; _ = value") - case "I32": fmt.Fprintf(out, "size += 4; _ = value") - case "I64": fmt.Fprintf(out, "size += 8; _ = value") - case "U8": fmt.Fprintf(out, "size += 1; _ = value") - case "U16": fmt.Fprintf(out, "size += 2; _ = value") - case "U32": fmt.Fprintf(out, "size += 4; _ = value") - case "U64": fmt.Fprintf(out, "size += 8; _ = value") - case "I8Array": fmt.Fprintf(out, "size += len(value)") - case "I16Array": fmt.Fprintf(out, "size += len(value) * 2") - case "I32Array": fmt.Fprintf(out, "size += len(value) * 4") - case "I64Array": fmt.Fprintf(out, "size += len(value) * 8") - case "U8Array": fmt.Fprintf(out, "size += len(value)") - case "U16Array": fmt.Fprintf(out, "size += len(value) * 2") - case "U32Array": fmt.Fprintf(out, "size += len(value) * 4") - case "U64Array": fmt.Fprintf(out, "size += len(value) * 8") - case "String": fmt.Fprintf(out, "size += len(value)") - case "StringArray": - fmt.Fprintf( - out, - "for _, el := range value { size += 2 + len(el) }") - default: - return fmt.Errorf("unknown type: %s", field.Type) - } - return nil -} - -func (this *Protocol) marshalField(out io.Writer, field Field) error { - switch field.Type { - case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) - case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) - case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) - case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) - case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) - case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) - case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) - case "U64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) - case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) - case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) - case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) - case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) - case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) - case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) - case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) - case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) - case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name) - case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name) - default: - return fmt.Errorf("unknown type: %s", field.Type) - } - return nil -} - -func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error { - fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n") - fmt.Fprintf(out, - "func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n", - message.Name) - if len(message.Fields) < 1 { - fmt.Fprintf(out, "\t// no fields\n") - fmt.Fprintf(out, "\treturn nil\n") - fmt.Fprintf(out, "}\n\n") - return nil - } - fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n") - fmt.Fprintf(out, "\tif err != nil { return err }\n") - requiredTotal := 0 - for _, field := range message.Fields { - if field.Optional { - requiredTotal ++ - } - } - if requiredTotal > 0 { - fmt.Fprintf(out, "\tfoundRequired := 0\n") - } - fmt.Fprintf(out, "\tfor tag, data := range pairs {\n") - fmt.Fprintf(out, "\t\tswitch tag {\n") - for _, field := range message.Fields { - fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag) - fmt.Fprintf(out, "\t\t\t") - err := this.unmarshalField(out, field) - if err != nil { return err } - fmt.Fprintf(out, "\n") - fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n") - if field.Optional { - fmt.Fprintf(out, "\t\t\tmsg.%s = hopp.O(value)\n", field.Name) - } else { - fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name) - if requiredTotal > 0 { - fmt.Fprintf(out, "\t\t\tfoundRequired ++\n") - } - } - } - fmt.Fprintf(out, "\t\t}\n") - fmt.Fprintf(out, "\t}\n") - if requiredTotal > 0 { - fmt.Fprintf(out, - "\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n", - requiredTotal) - } - fmt.Fprintf(out, "\treturn nil\n") - fmt.Fprintf(out, "}\n\n") - return nil -} - -func (this *Protocol) unmarshalField(out io.Writer, field Field) error { - typ, err := this.ResolveType(field.Type) - if err != nil { return err } - switch field.Type { - case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) - case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) - case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) - case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) - case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) - case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) - case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) - case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) - case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) - case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) - case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) - case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) - case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) - case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) - case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) - case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) - case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ) - case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ) - default: - return fmt.Errorf("unknown type: %s", field.Type) - } - return nil -} - -func comment(prefix, text string) string { - return prefix + " " + strings.ReplaceAll(strings.TrimSpace(text), "\n", "\n" + prefix + " ") -} diff --git a/generate/protocol.go b/generate/protocol.go deleted file mode 100644 index 02fd456..0000000 --- a/generate/protocol.go +++ /dev/null @@ -1,244 +0,0 @@ -package generate - -import "io" -import "fmt" -import "errors" -import "strconv" -import "strings" -import "github.com/gomarkdown/markdown" -import "github.com/gomarkdown/markdown/ast" -import "github.com/gomarkdown/markdown/parser" - -// Protocol describes a protocol. -type Protocol struct { - Messages []Message -} - -// Message describes a protocol message. -type Message struct { - Doc string - Method uint16 - Name string - Fields []Field -} - -// Field describes a named value within a message. -type Field struct { - Doc string - Tag uint16 - Name string - Optional bool - Type string -} - -// ParseReader parses a protocol definition from a reader. -func ParseReader(reader io.Reader) (*Protocol, error) { - data, err := io.ReadAll(reader) - if err != nil { return nil, err } - protocol := new(Protocol) - err = protocol.UnmarshalText(data) - if err != nil { return nil, err } - return protocol, nil -} - -// UnmarshalText unmarshals markdown-formatted text data into the protocol. -func (this *Protocol) UnmarshalText(text []byte) error { - var state int; const ( - stateIdle = iota - stateMessage - stateMessageDoc - stateMessageField - ) - - var message *Message - addMessage := func(method uint16, name string) { - this.Messages = append(this.Messages, Message { - Method: method, - Name: name, - }) - message = &this.Messages[len(this.Messages) - 1] - } - - root := markdown.Parse(text, parser.New()) - for _, node := range root.GetChildren() { - if node, ok := node.(*ast.Heading); ok { - if node.Level == 2 { - if removeBreaks(flatten(node)) == "Messages" { - state = stateMessage - continue - } - } - - if node.Level > 3 { - state = stateIdle - continue - } - - if state != stateIdle && node.Level == 3 { - heading := removeBreaks(flatten(node)) - method, name, err := splitMessageHeading(heading) - if err != nil { return err } - addMessage(method, name) - state = stateMessageDoc - } - } - - if state == stateIdle { continue } - if message == nil { continue } - - // TODO when we are adding text content to the doc comment, it - // might be wise to do stuff like indent lists and quotes so - // that go doc renders them correctly - switch node := node.(type) { - case *ast.Paragraph: - if message.Doc != "" { message.Doc += "\n\n" } - message.Doc += removeBreaks(flatten(node)) - case *ast.BlockQuote: - if message.Doc != "" { message.Doc += "\n\n> " } - message.Doc += removeBreaks(flatten(node)) - case *ast.List: - // FIXME format the list - if message.Doc != "" { message.Doc += "\n\n" } - message.Doc += removeBreaks(flatten(node)) - case *ast.Table: - fields, err := processFieldTable(node) - if err != nil { return err} - message.Fields = append(message.Fields, fields...) - } - } - - return nil -} - -func processFieldTable(node *ast.Table) ([]Field, error) { - fields := []Field { } - children := node.GetChildren() - if len(children) != 2 { - return nil, errors.New("malformed field table") - } - - // get columns - columns := []string { } - if header, ok := children[0].(*ast.TableHeader); ok { - children := header.GetChildren() - if len(children) != 1 { - return nil, errors.New("malformed field table header") - } - if row, ok := header.Children[0].(*ast.TableRow); ok { - for _, cell := range row.GetChildren() { - if cell, ok := cell.(*ast.TableCell); ok { - columns = append(columns, flatten(cell)) - } - } - } else { - return nil, errors.New("malformed field table header") - } - for index, column := range columns { - columns[index] = strings.ToLower(column) - } - } else { - return nil, errors.New("malformed field table: no header") - } - - // get data - if body, ok := children[1].(*ast.TableBody); ok { - for _, node := range body.GetChildren() { - if row, ok := node.(*ast.TableRow); ok { - children := row.GetChildren() - if len(children) != len(columns) { - return nil, errors.New ( - "malformed field table row: wrong " + - "number of columns") - } - - field := Field { } - - for index, node := range children { - if cell, ok := node.(*ast.TableCell); ok { - text := flatten(cell) - switch columns[index] { - case "tag": - tag, err := parseTag(text) - if err != nil { return nil, err } - field.Tag = tag - case "name": - field.Name = text - case "required": - field.Optional = !parseBool(text) - case "optional": - field.Optional = parseBool(text) - case "type": - field.Type = text - case "comment", "purpose", "documentation": - field.Doc = text - } - }} - - fields = append(fields, field) - }} - } else { - return nil, errors.New("malformed field table: no body") - } - return fields, nil -} - -type nodeFlattener struct { - text string -} -func (this *nodeFlattener) String() string { return this.text } -func (this *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus { - if entering { - if node := node.AsLeaf(); node != nil { - this.text += string(node.Literal) - } - } - return ast.GoToNext -} -func flatten(node ast.Node) string { - flattener := new(nodeFlattener) - ast.Walk(node, flattener) - return flattener.text -} - - -func removeBreaks(text string) string { - text = strings.ReplaceAll(text, "\n", " ") - text = strings.ReplaceAll(text, "\r", "") - return text -} - -func parseBool(text string) bool { - switch(strings.ToLower(text)) { - case "yes": return true - case "no": return false - case "true": return true - case "false": return false - } - return false -} - -func parseTag(text string) (uint16, error) { - tag, err := strconv.ParseUint(text, 10, 16) - if err != nil { - return 0, fmt.Errorf("malformed tag '%s': %w", text, err) - } - return uint16(tag), nil -} - -func splitMessageHeading(text string) (uint16, string, error) { - text = strings.TrimSpace(text) - methodText, name, ok := strings.Cut(text, " ") - if !ok { - return 0, "", fmt.Errorf( - "malformed message heading '%s': no message name", - text) - } - method, err := strconv.ParseUint(methodText, 16, 16) - if err != nil { - return 0, "", fmt.Errorf( - "malformed method number '%s': %w", - methodText, err) - } - name = strings.TrimSpace(name) - return uint16(method), name, nil -}