generate: Alter protocol parsing to support HOPP features

This commit is contained in:
Sasha Koshka 2025-01-20 13:56:03 -05:00
parent 382cc113c5
commit 6a883cebbf
2 changed files with 63 additions and 75 deletions

View File

@ -1,4 +1,4 @@
package generate package protocol
// Generate turns this protocol into code. // Generate turns this protocol into code.
func (this *Protocol) Generate(writer io.Writer) error { func (this *Protocol) Generate(writer io.Writer) error {

View File

@ -2,8 +2,8 @@ package protocol
import "io" import "io"
import "fmt" import "fmt"
import "bytes"
import "errors" import "errors"
import "strconv"
import "strings" import "strings"
import "github.com/gomarkdown/markdown" import "github.com/gomarkdown/markdown"
import "github.com/gomarkdown/markdown/ast" import "github.com/gomarkdown/markdown/ast"
@ -16,21 +16,23 @@ type Protocol struct {
// Message describes a protocol message. // Message describes a protocol message.
type Message struct { type Message struct {
Doc string Doc string
Name string Method uint16
Name string
Fields []Field Fields []Field
} }
// Field describes a named value within a message. // Field describes a named value within a message.
type Field struct { type Field struct {
Doc string Doc string
Tag uint16
Name string Name string
Optional bool Optional bool
Type string Type string
} }
// ParseReader parses a protocol definition from a reader. // ParseReader parses a protocol definition from a reader.
func ParseReader (reader io.Reader) (*Protocol, error) { func ParseReader(reader io.Reader) (*Protocol, error) {
data, err := io.ReadAll(reader) data, err := io.ReadAll(reader)
if err != nil { return nil, err } if err != nil { return nil, err }
protocol := new(Protocol) protocol := new(Protocol)
@ -40,7 +42,7 @@ func ParseReader (reader io.Reader) (*Protocol, error) {
} }
// UnmarshalText unmarshals markdown-formatted text data into the protocol. // UnmarshalText unmarshals markdown-formatted text data into the protocol.
func (this *Protocol) UnmarshalText (text []byte) error { func (this *Protocol) UnmarshalText(text []byte) error {
var state int; const ( var state int; const (
stateIdle = iota stateIdle = iota
stateMessage stateMessage
@ -49,9 +51,10 @@ func (this *Protocol) UnmarshalText (text []byte) error {
) )
var message *Message var message *Message
addMessage := func (name string) { addMessage := func(method uint16, name string) {
this.Messages = append(this.Messages, Message { this.Messages = append(this.Messages, Message {
Name: name, Method: method,
Name: name,
}) })
message = &this.Messages[len(this.Messages) - 1] message = &this.Messages[len(this.Messages) - 1]
} }
@ -72,7 +75,10 @@ func (this *Protocol) UnmarshalText (text []byte) error {
} }
if state != stateIdle && node.Level == 3 { if state != stateIdle && node.Level == 3 {
addMessage(removeBreaks(flatten(node))) heading := removeBreaks(flatten(node))
method, name, err := splitMessageHeading(heading)
if err != nil { return err }
addMessage(method, name)
state = stateMessageDoc state = stateMessageDoc
} }
} }
@ -80,6 +86,9 @@ func (this *Protocol) UnmarshalText (text []byte) error {
if state == stateIdle { continue } if state == stateIdle { continue }
if message == nil { continue } if message == nil { continue }
// TODO when we are adding text content to the doc comment, it
// might be wise to do stuff like indent lists and quotes so
// that go doc renders them correctly
switch node := node.(type) { switch node := node.(type) {
case *ast.Paragraph: case *ast.Paragraph:
if message.Doc != "" { message.Doc += "\n\n" } if message.Doc != "" { message.Doc += "\n\n" }
@ -88,9 +97,9 @@ func (this *Protocol) UnmarshalText (text []byte) error {
if message.Doc != "" { message.Doc += "\n\n> " } if message.Doc != "" { message.Doc += "\n\n> " }
message.Doc += removeBreaks(flatten(node)) message.Doc += removeBreaks(flatten(node))
case *ast.List: case *ast.List:
fields, err := processFieldList(node) // FIXME format the list
if err != nil { return err} if message.Doc != "" { message.Doc += "\n\n" }
message.Fields = append(message.Fields, fields...) message.Doc += removeBreaks(flatten(node))
case *ast.Table: case *ast.Table:
fields, err := processFieldTable(node) fields, err := processFieldTable(node)
if err != nil { return err} if err != nil { return err}
@ -101,49 +110,7 @@ func (this *Protocol) UnmarshalText (text []byte) error {
return nil return nil
} }
func (this *Protocol) MarshalText () ([]byte, error) { func processFieldTable(node *ast.Table) ([]Field, error) {
buffer := bytes.NewBuffer(nil)
fmt.Fprint(buffer, "# Messages\n")
for index, message := range this.Messages {
if index > 0 {
fmt.Fprint(buffer, "\n")
}
fmt.Fprintf(buffer, "## %s\n", message.Name)
if message.Doc != "" {
fmt.Fprintf(buffer, "%s\n", message.Doc)
}
if len(message.Fields) > 0 {
fmt.Fprint(buffer, "\n")
}
for _, field := range message.Fields {
fmt.Fprintf(buffer, "- %s", field.Name)
if field.Type != "" {
fmt.Fprintf(buffer, " %s", field.Type)
}
if field.Doc != "" {
fmt.Fprintf(buffer, ": %s", field.Doc)
}
fmt.Fprint(buffer, "\n")
}
}
return buffer.Bytes(), nil
}
func processFieldList (node *ast.List) ([]Field, error) {
fields := []Field { }
for _, item := range node.GetChildren() {
if item, ok := item.(*ast.ListItem); ok {
field, err := processFieldListItem(removeBreaks(flatten(item)))
if err != nil { return nil, err }
fields = append(fields, field)
}
}
return fields, nil
}
func processFieldTable (node *ast.Table) ([]Field, error) {
fields := []Field { } fields := []Field { }
children := node.GetChildren() children := node.GetChildren()
if len(children) != 2 { if len(children) != 2 {
@ -190,6 +157,10 @@ func processFieldTable (node *ast.Table) ([]Field, error) {
if cell, ok := node.(*ast.TableCell); ok { if cell, ok := node.(*ast.TableCell); ok {
text := flatten(cell) text := flatten(cell)
switch columns[index] { switch columns[index] {
case "tag":
tag, err := parseTag(text)
if err != nil { return nil, err }
field.Tag = tag
case "name": case "name":
field.Name = text field.Name = text
case "required": case "required":
@ -214,8 +185,8 @@ func processFieldTable (node *ast.Table) ([]Field, error) {
type nodeFlattener struct { type nodeFlattener struct {
text string text string
} }
func (this *nodeFlattener) String () string { return this.text } func (this *nodeFlattener) String() string { return this.text }
func (this *nodeFlattener) Visit (node ast.Node, entering bool) ast.WalkStatus { func (this *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus {
if entering { if entering {
if node := node.AsLeaf(); node != nil { if node := node.AsLeaf(); node != nil {
this.text += string(node.Literal) this.text += string(node.Literal)
@ -223,35 +194,20 @@ func (this *nodeFlattener) Visit (node ast.Node, entering bool) ast.WalkStatus {
} }
return ast.GoToNext return ast.GoToNext
} }
func flatten (node ast.Node) string { func flatten(node ast.Node) string {
flattener := new(nodeFlattener) flattener := new(nodeFlattener)
ast.Walk(node, flattener) ast.Walk(node, flattener)
return flattener.text return flattener.text
} }
func processFieldListItem (bullet string) (Field, error) {
field := Field { }
nameAndType, doc, hasDoc := strings.Cut(bullet, ":")
if hasDoc {
field.Doc = strings.TrimSpace(doc)
}
name, ty, hasType := strings.Cut(strings.TrimSpace(nameAndType), " ") func removeBreaks(text string) string {
if !hasType {
return Field { }, errors.New(fmt.Sprintf("unspecified type near '%s'", bullet))
}
field.Name = strings.TrimSpace(name)
field.Type = strings.TrimSpace(ty)
return field, nil
}
func removeBreaks (text string) string {
text = strings.ReplaceAll(text, "\n", " ") text = strings.ReplaceAll(text, "\n", " ")
text = strings.ReplaceAll(text, "\r", "") text = strings.ReplaceAll(text, "\r", "")
return text return text
} }
func parseBool (text string) bool { func parseBool(text string) bool {
switch(strings.ToLower(text)) { switch(strings.ToLower(text)) {
case "yes": return true case "yes": return true
case "no": return false case "no": return false
@ -260,3 +216,35 @@ func parseBool (text string) bool {
} }
return false return false
} }
func parseTag(text string) (uint16, error) {
tag, err := strconv.ParseUint(text, 10, 16)
if err != nil {
return 0, fmt.Errorf("malformed tag '%s': %w", text, err)
}
return uint16(tag), nil
}
func splitMessageHeading(text string) (uint16, string, error) {
text = strings.TrimSpace(text)
if !strings.HasPrefix(text, "(") {
return 0, "", fmt.Errorf(
"malformed message heading '%s': no method number",
text)
}
text = strings.TrimPrefix(text, "(")
methodText, name, ok := strings.Cut(text, ")")
if !ok {
return 0, "", fmt.Errorf(
"malformed message heading '%s': no message name",
text)
}
method, err := strconv.ParseUint(text, 10, 16)
if err != nil {
return 0, "", fmt.Errorf(
"malformed method number '%s': %w",
methodText, err)
}
name = strings.TrimSpace(text)
return uint16(method), name, nil
}