message-size-increase #3
@ -1,302 +0,0 @@
|
|||||||
package generate
|
|
||||||
|
|
||||||
import "io"
|
|
||||||
import "fmt"
|
|
||||||
import "bufio"
|
|
||||||
import "strings"
|
|
||||||
|
|
||||||
const send =
|
|
||||||
`// Send sends one message along a transaction.
|
|
||||||
func Send(trans hopp.Trans, message hopp.Message) error {
|
|
||||||
buffer, err := message.MarshalBinary()
|
|
||||||
if err != nil { return err }
|
|
||||||
return trans.Send(message.Method(), buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
`
|
|
||||||
|
|
||||||
// ResolveType resolves a HOPP type name to a Go type. For now, it supports all
|
|
||||||
// data types defined in TAPE.
|
|
||||||
func (this *Protocol) ResolveType(hopp string) (string, error) {
|
|
||||||
switch hopp {
|
|
||||||
case "I8": return "int8", nil
|
|
||||||
case "I16": return "int16", nil
|
|
||||||
case "I32": return "int32", nil
|
|
||||||
case "I64": return "int64", nil
|
|
||||||
case "U8": return "uint8", nil
|
|
||||||
case "U16": return "uint16", nil
|
|
||||||
case "U32": return "uint32", nil
|
|
||||||
case "U64": return "uint64", nil
|
|
||||||
case "I8Array": return "[]int8", nil
|
|
||||||
case "I16Array": return "[]int16", nil
|
|
||||||
case "I32Array": return "[]int32", nil
|
|
||||||
case "I64Array": return "[]int64", nil
|
|
||||||
case "U8Array": return "[]uint8", nil
|
|
||||||
case "U16Array": return "[]uint16", nil
|
|
||||||
case "U32Array": return "[]uint32", nil
|
|
||||||
case "U64Array": return "[]uint64", nil
|
|
||||||
case "String": return "string", nil
|
|
||||||
case "StringArray": return "[]string", nil
|
|
||||||
default: return "", fmt.Errorf("unknown type: %s", hopp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate turns this protocol into code. The package name for the generated
|
|
||||||
// code must be specified.
|
|
||||||
func (this *Protocol) Generate(writer io.Writer, packag string) error {
|
|
||||||
out := bufio.NewWriter(writer)
|
|
||||||
defer out.Flush()
|
|
||||||
|
|
||||||
fmt.Fprintf(out, "package %s\n\n", packag)
|
|
||||||
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n")
|
|
||||||
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n")
|
|
||||||
|
|
||||||
fmt.Fprintf(out, send)
|
|
||||||
this.receive(out)
|
|
||||||
|
|
||||||
for _, message := range this.Messages {
|
|
||||||
err := this.defineMessage(out, message)
|
|
||||||
if err != nil { return err }
|
|
||||||
err = this.marshalMessage(out, message)
|
|
||||||
if err != nil { return err }
|
|
||||||
err = this.unmarshalMessage(out, message)
|
|
||||||
if err != nil { return err }
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) receive(out io.Writer) error {
|
|
||||||
fmt.Fprintf(out, "// Receive receives one message from a transaction.\n")
|
|
||||||
fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n")
|
|
||||||
fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n")
|
|
||||||
fmt.Fprintf(out, "\tif err != nil { return nil, err }\n")
|
|
||||||
fmt.Fprintf(out, "\tswitch method {\n")
|
|
||||||
for _, message := range this.Messages {
|
|
||||||
fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method)
|
|
||||||
fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name)
|
|
||||||
fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n")
|
|
||||||
fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n")
|
|
||||||
fmt.Fprintf(out, "\t\treturn message, nil\n")
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n")
|
|
||||||
fmt.Fprintf(out, "\t}\n")
|
|
||||||
fmt.Fprintf(out, "}\n\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) defineMessage(out io.Writer, message Message) error {
|
|
||||||
fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc)))
|
|
||||||
fmt.Fprintf(out, "type Message%s struct {\n", message.Name)
|
|
||||||
for _, field := range message.Fields {
|
|
||||||
typ, err := this.ResolveType(field.Type)
|
|
||||||
if err != nil { return err }
|
|
||||||
if field.Doc != "" {
|
|
||||||
fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc))
|
|
||||||
}
|
|
||||||
if field.Optional {
|
|
||||||
typ = fmt.Sprintf("hopp.Option[%s]", typ)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(
|
|
||||||
out, "\t/* %d */ %s %s\n",
|
|
||||||
field.Tag, field.Name, typ)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "}\n\n")
|
|
||||||
|
|
||||||
fmt.Fprintf(out, "// Method returns the method number of the message.\n")
|
|
||||||
fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name)
|
|
||||||
fmt.Fprintf(out, "\treturn %d\n", message.Method)
|
|
||||||
fmt.Fprintf(out, "}\n\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) marshalMessage(out io.Writer, message Message) error {
|
|
||||||
fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n")
|
|
||||||
fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name)
|
|
||||||
requiredCount := 0
|
|
||||||
for _, field := range message.Fields {
|
|
||||||
if !field.Optional { requiredCount ++ }
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\tsize := 0\n")
|
|
||||||
fmt.Fprintf(out, "\tcount := %d\n", requiredCount)
|
|
||||||
for _, field := range message.Fields {
|
|
||||||
fmt.Fprintf(out, "\toffset%s := size\n", field.Name)
|
|
||||||
if field.Optional {
|
|
||||||
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name)
|
|
||||||
fmt.Fprintf(out, "\t\tcount ++\n")
|
|
||||||
fmt.Fprintf(out, "\t\t")
|
|
||||||
err := this.marshalSizeOf(out, field)
|
|
||||||
if err != nil { return err }
|
|
||||||
fmt.Fprintf(out, " }\n")
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(out, "\t{")
|
|
||||||
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name)
|
|
||||||
fmt.Fprintf(out, "\t\t")
|
|
||||||
err := this.marshalSizeOf(out, field)
|
|
||||||
if err != nil { return err }
|
|
||||||
fmt.Fprintf(out, " }\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
|
|
||||||
fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
|
|
||||||
fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n")
|
|
||||||
fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n")
|
|
||||||
for _, field := range message.Fields {
|
|
||||||
if field.Optional {
|
|
||||||
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name)
|
|
||||||
fmt.Fprintf(out, "\t\t")
|
|
||||||
err := this.marshalField(out, field)
|
|
||||||
if err != nil { return err }
|
|
||||||
fmt.Fprintf(out, "}\n")
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(out, "\t{")
|
|
||||||
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name)
|
|
||||||
fmt.Fprintf(out, "\t\t")
|
|
||||||
err := this.marshalField(out, field)
|
|
||||||
if err != nil { return err }
|
|
||||||
fmt.Fprintf(out, "}\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\treturn buffer, nil\n")
|
|
||||||
fmt.Fprintf(out, "}\n\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error {
|
|
||||||
switch field.Type {
|
|
||||||
case "I8": fmt.Fprintf(out, "size += 1; _ = value")
|
|
||||||
case "I16": fmt.Fprintf(out, "size += 2; _ = value")
|
|
||||||
case "I32": fmt.Fprintf(out, "size += 4; _ = value")
|
|
||||||
case "I64": fmt.Fprintf(out, "size += 8; _ = value")
|
|
||||||
case "U8": fmt.Fprintf(out, "size += 1; _ = value")
|
|
||||||
case "U16": fmt.Fprintf(out, "size += 2; _ = value")
|
|
||||||
case "U32": fmt.Fprintf(out, "size += 4; _ = value")
|
|
||||||
case "U64": fmt.Fprintf(out, "size += 8; _ = value")
|
|
||||||
case "I8Array": fmt.Fprintf(out, "size += len(value)")
|
|
||||||
case "I16Array": fmt.Fprintf(out, "size += len(value) * 2")
|
|
||||||
case "I32Array": fmt.Fprintf(out, "size += len(value) * 4")
|
|
||||||
case "I64Array": fmt.Fprintf(out, "size += len(value) * 8")
|
|
||||||
case "U8Array": fmt.Fprintf(out, "size += len(value)")
|
|
||||||
case "U16Array": fmt.Fprintf(out, "size += len(value) * 2")
|
|
||||||
case "U32Array": fmt.Fprintf(out, "size += len(value) * 4")
|
|
||||||
case "U64Array": fmt.Fprintf(out, "size += len(value) * 8")
|
|
||||||
case "String": fmt.Fprintf(out, "size += len(value)")
|
|
||||||
case "StringArray":
|
|
||||||
fmt.Fprintf(
|
|
||||||
out,
|
|
||||||
"for _, el := range value { size += 2 + len(el) }")
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown type: %s", field.Type)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) marshalField(out io.Writer, field Field) error {
|
|
||||||
switch field.Type {
|
|
||||||
case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name)
|
|
||||||
case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown type: %s", field.Type)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error {
|
|
||||||
fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n")
|
|
||||||
fmt.Fprintf(out,
|
|
||||||
"func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n",
|
|
||||||
message.Name)
|
|
||||||
if len(message.Fields) < 1 {
|
|
||||||
fmt.Fprintf(out, "\t// no fields\n")
|
|
||||||
fmt.Fprintf(out, "\treturn nil\n")
|
|
||||||
fmt.Fprintf(out, "}\n\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n")
|
|
||||||
fmt.Fprintf(out, "\tif err != nil { return err }\n")
|
|
||||||
requiredTotal := 0
|
|
||||||
for _, field := range message.Fields {
|
|
||||||
if field.Optional {
|
|
||||||
requiredTotal ++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if requiredTotal > 0 {
|
|
||||||
fmt.Fprintf(out, "\tfoundRequired := 0\n")
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\tfor tag, data := range pairs {\n")
|
|
||||||
fmt.Fprintf(out, "\t\tswitch tag {\n")
|
|
||||||
for _, field := range message.Fields {
|
|
||||||
fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag)
|
|
||||||
fmt.Fprintf(out, "\t\t\t")
|
|
||||||
err := this.unmarshalField(out, field)
|
|
||||||
if err != nil { return err }
|
|
||||||
fmt.Fprintf(out, "\n")
|
|
||||||
fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n")
|
|
||||||
if field.Optional {
|
|
||||||
fmt.Fprintf(out, "\t\t\tmsg.%s = hopp.O(value)\n", field.Name)
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name)
|
|
||||||
if requiredTotal > 0 {
|
|
||||||
fmt.Fprintf(out, "\t\t\tfoundRequired ++\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\t\t}\n")
|
|
||||||
fmt.Fprintf(out, "\t}\n")
|
|
||||||
if requiredTotal > 0 {
|
|
||||||
fmt.Fprintf(out,
|
|
||||||
"\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n",
|
|
||||||
requiredTotal)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(out, "\treturn nil\n")
|
|
||||||
fmt.Fprintf(out, "}\n\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *Protocol) unmarshalField(out io.Writer, field Field) error {
|
|
||||||
typ, err := this.ResolveType(field.Type)
|
|
||||||
if err != nil { return err }
|
|
||||||
switch field.Type {
|
|
||||||
case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
|
|
||||||
case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
|
|
||||||
case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
|
|
||||||
case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
|
|
||||||
case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
|
|
||||||
case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
|
|
||||||
case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
|
|
||||||
case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
|
|
||||||
case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
|
|
||||||
case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
|
|
||||||
case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
|
|
||||||
case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
|
|
||||||
case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
|
|
||||||
case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
|
|
||||||
case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
|
|
||||||
case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
|
|
||||||
case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ)
|
|
||||||
case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown type: %s", field.Type)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func comment(prefix, text string) string {
|
|
||||||
return prefix + " " + strings.ReplaceAll(strings.TrimSpace(text), "\n", "\n" + prefix + " ")
|
|
||||||
}
|
|
@ -1,244 +0,0 @@
|
|||||||
package generate
|
|
||||||
|
|
||||||
import "io"
|
|
||||||
import "fmt"
|
|
||||||
import "errors"
|
|
||||||
import "strconv"
|
|
||||||
import "strings"
|
|
||||||
import "github.com/gomarkdown/markdown"
|
|
||||||
import "github.com/gomarkdown/markdown/ast"
|
|
||||||
import "github.com/gomarkdown/markdown/parser"
|
|
||||||
|
|
||||||
// Protocol describes a protocol.
|
|
||||||
type Protocol struct {
|
|
||||||
Messages []Message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Message describes a protocol message.
|
|
||||||
type Message struct {
|
|
||||||
Doc string
|
|
||||||
Method uint16
|
|
||||||
Name string
|
|
||||||
Fields []Field
|
|
||||||
}
|
|
||||||
|
|
||||||
// Field describes a named value within a message.
|
|
||||||
type Field struct {
|
|
||||||
Doc string
|
|
||||||
Tag uint16
|
|
||||||
Name string
|
|
||||||
Optional bool
|
|
||||||
Type string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseReader parses a protocol definition from a reader.
|
|
||||||
func ParseReader(reader io.Reader) (*Protocol, error) {
|
|
||||||
data, err := io.ReadAll(reader)
|
|
||||||
if err != nil { return nil, err }
|
|
||||||
protocol := new(Protocol)
|
|
||||||
err = protocol.UnmarshalText(data)
|
|
||||||
if err != nil { return nil, err }
|
|
||||||
return protocol, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnmarshalText unmarshals markdown-formatted text data into the protocol.
|
|
||||||
func (this *Protocol) UnmarshalText(text []byte) error {
|
|
||||||
var state int; const (
|
|
||||||
stateIdle = iota
|
|
||||||
stateMessage
|
|
||||||
stateMessageDoc
|
|
||||||
stateMessageField
|
|
||||||
)
|
|
||||||
|
|
||||||
var message *Message
|
|
||||||
addMessage := func(method uint16, name string) {
|
|
||||||
this.Messages = append(this.Messages, Message {
|
|
||||||
Method: method,
|
|
||||||
Name: name,
|
|
||||||
})
|
|
||||||
message = &this.Messages[len(this.Messages) - 1]
|
|
||||||
}
|
|
||||||
|
|
||||||
root := markdown.Parse(text, parser.New())
|
|
||||||
for _, node := range root.GetChildren() {
|
|
||||||
if node, ok := node.(*ast.Heading); ok {
|
|
||||||
if node.Level == 2 {
|
|
||||||
if removeBreaks(flatten(node)) == "Messages" {
|
|
||||||
state = stateMessage
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if node.Level > 3 {
|
|
||||||
state = stateIdle
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if state != stateIdle && node.Level == 3 {
|
|
||||||
heading := removeBreaks(flatten(node))
|
|
||||||
method, name, err := splitMessageHeading(heading)
|
|
||||||
if err != nil { return err }
|
|
||||||
addMessage(method, name)
|
|
||||||
state = stateMessageDoc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if state == stateIdle { continue }
|
|
||||||
if message == nil { continue }
|
|
||||||
|
|
||||||
// TODO when we are adding text content to the doc comment, it
|
|
||||||
// might be wise to do stuff like indent lists and quotes so
|
|
||||||
// that go doc renders them correctly
|
|
||||||
switch node := node.(type) {
|
|
||||||
case *ast.Paragraph:
|
|
||||||
if message.Doc != "" { message.Doc += "\n\n" }
|
|
||||||
message.Doc += removeBreaks(flatten(node))
|
|
||||||
case *ast.BlockQuote:
|
|
||||||
if message.Doc != "" { message.Doc += "\n\n> " }
|
|
||||||
message.Doc += removeBreaks(flatten(node))
|
|
||||||
case *ast.List:
|
|
||||||
// FIXME format the list
|
|
||||||
if message.Doc != "" { message.Doc += "\n\n" }
|
|
||||||
message.Doc += removeBreaks(flatten(node))
|
|
||||||
case *ast.Table:
|
|
||||||
fields, err := processFieldTable(node)
|
|
||||||
if err != nil { return err}
|
|
||||||
message.Fields = append(message.Fields, fields...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func processFieldTable(node *ast.Table) ([]Field, error) {
|
|
||||||
fields := []Field { }
|
|
||||||
children := node.GetChildren()
|
|
||||||
if len(children) != 2 {
|
|
||||||
return nil, errors.New("malformed field table")
|
|
||||||
}
|
|
||||||
|
|
||||||
// get columns
|
|
||||||
columns := []string { }
|
|
||||||
if header, ok := children[0].(*ast.TableHeader); ok {
|
|
||||||
children := header.GetChildren()
|
|
||||||
if len(children) != 1 {
|
|
||||||
return nil, errors.New("malformed field table header")
|
|
||||||
}
|
|
||||||
if row, ok := header.Children[0].(*ast.TableRow); ok {
|
|
||||||
for _, cell := range row.GetChildren() {
|
|
||||||
if cell, ok := cell.(*ast.TableCell); ok {
|
|
||||||
columns = append(columns, flatten(cell))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, errors.New("malformed field table header")
|
|
||||||
}
|
|
||||||
for index, column := range columns {
|
|
||||||
columns[index] = strings.ToLower(column)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nil, errors.New("malformed field table: no header")
|
|
||||||
}
|
|
||||||
|
|
||||||
// get data
|
|
||||||
if body, ok := children[1].(*ast.TableBody); ok {
|
|
||||||
for _, node := range body.GetChildren() {
|
|
||||||
if row, ok := node.(*ast.TableRow); ok {
|
|
||||||
children := row.GetChildren()
|
|
||||||
if len(children) != len(columns) {
|
|
||||||
return nil, errors.New (
|
|
||||||
"malformed field table row: wrong " +
|
|
||||||
"number of columns")
|
|
||||||
}
|
|
||||||
|
|
||||||
field := Field { }
|
|
||||||
|
|
||||||
for index, node := range children {
|
|
||||||
if cell, ok := node.(*ast.TableCell); ok {
|
|
||||||
text := flatten(cell)
|
|
||||||
switch columns[index] {
|
|
||||||
case "tag":
|
|
||||||
tag, err := parseTag(text)
|
|
||||||
if err != nil { return nil, err }
|
|
||||||
field.Tag = tag
|
|
||||||
case "name":
|
|
||||||
field.Name = text
|
|
||||||
case "required":
|
|
||||||
field.Optional = !parseBool(text)
|
|
||||||
case "optional":
|
|
||||||
field.Optional = parseBool(text)
|
|
||||||
case "type":
|
|
||||||
field.Type = text
|
|
||||||
case "comment", "purpose", "documentation":
|
|
||||||
field.Doc = text
|
|
||||||
}
|
|
||||||
}}
|
|
||||||
|
|
||||||
fields = append(fields, field)
|
|
||||||
}}
|
|
||||||
} else {
|
|
||||||
return nil, errors.New("malformed field table: no body")
|
|
||||||
}
|
|
||||||
return fields, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type nodeFlattener struct {
|
|
||||||
text string
|
|
||||||
}
|
|
||||||
func (this *nodeFlattener) String() string { return this.text }
|
|
||||||
func (this *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus {
|
|
||||||
if entering {
|
|
||||||
if node := node.AsLeaf(); node != nil {
|
|
||||||
this.text += string(node.Literal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ast.GoToNext
|
|
||||||
}
|
|
||||||
func flatten(node ast.Node) string {
|
|
||||||
flattener := new(nodeFlattener)
|
|
||||||
ast.Walk(node, flattener)
|
|
||||||
return flattener.text
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
func removeBreaks(text string) string {
|
|
||||||
text = strings.ReplaceAll(text, "\n", " ")
|
|
||||||
text = strings.ReplaceAll(text, "\r", "")
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseBool(text string) bool {
|
|
||||||
switch(strings.ToLower(text)) {
|
|
||||||
case "yes": return true
|
|
||||||
case "no": return false
|
|
||||||
case "true": return true
|
|
||||||
case "false": return false
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTag(text string) (uint16, error) {
|
|
||||||
tag, err := strconv.ParseUint(text, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("malformed tag '%s': %w", text, err)
|
|
||||||
}
|
|
||||||
return uint16(tag), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitMessageHeading(text string) (uint16, string, error) {
|
|
||||||
text = strings.TrimSpace(text)
|
|
||||||
methodText, name, ok := strings.Cut(text, " ")
|
|
||||||
if !ok {
|
|
||||||
return 0, "", fmt.Errorf(
|
|
||||||
"malformed message heading '%s': no message name",
|
|
||||||
text)
|
|
||||||
}
|
|
||||||
method, err := strconv.ParseUint(methodText, 16, 16)
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", fmt.Errorf(
|
|
||||||
"malformed method number '%s': %w",
|
|
||||||
methodText, err)
|
|
||||||
}
|
|
||||||
name = strings.TrimSpace(name)
|
|
||||||
return uint16(method), name, nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user