package generate import "io" import "fmt" import "errors" import "strconv" import "strings" import "github.com/gomarkdown/markdown" import "github.com/gomarkdown/markdown/ast" import "github.com/gomarkdown/markdown/parser" // Protocol describes a protocol. type Protocol struct { Messages []Message } // Message describes a protocol message. type Message struct { Doc string Method uint16 Name string Fields []Field } // Field describes a named value within a message. type Field struct { Doc string Tag uint16 Name string Optional bool Type string } // ParseReader parses a protocol definition from a reader. func ParseReader(reader io.Reader) (*Protocol, error) { data, err := io.ReadAll(reader) if err != nil { return nil, err } protocol := new(Protocol) err = protocol.UnmarshalText(data) if err != nil { return nil, err } return protocol, nil } // UnmarshalText unmarshals markdown-formatted text data into the protocol. func (this *Protocol) UnmarshalText(text []byte) error { var state int; const ( stateIdle = iota stateMessage stateMessageDoc stateMessageField ) var message *Message addMessage := func(method uint16, name string) { this.Messages = append(this.Messages, Message { Method: method, Name: name, }) message = &this.Messages[len(this.Messages) - 1] } root := markdown.Parse(text, parser.New()) for _, node := range root.GetChildren() { if node, ok := node.(*ast.Heading); ok { if node.Level == 2 { if removeBreaks(flatten(node)) == "Messages" { state = stateMessage continue } } if node.Level > 3 { state = stateIdle continue } if state != stateIdle && node.Level == 3 { heading := removeBreaks(flatten(node)) method, name, err := splitMessageHeading(heading) if err != nil { return err } addMessage(method, name) state = stateMessageDoc } } if state == stateIdle { continue } if message == nil { continue } // TODO when we are adding text content to the doc comment, it // might be wise to do stuff like indent lists and quotes so // that go doc renders them correctly switch node := node.(type) { case *ast.Paragraph: if message.Doc != "" { message.Doc += "\n\n" } message.Doc += removeBreaks(flatten(node)) case *ast.BlockQuote: if message.Doc != "" { message.Doc += "\n\n> " } message.Doc += removeBreaks(flatten(node)) case *ast.List: // FIXME format the list if message.Doc != "" { message.Doc += "\n\n" } message.Doc += removeBreaks(flatten(node)) case *ast.Table: fields, err := processFieldTable(node) if err != nil { return err} message.Fields = append(message.Fields, fields...) } } return nil } func processFieldTable(node *ast.Table) ([]Field, error) { fields := []Field { } children := node.GetChildren() if len(children) != 2 { return nil, errors.New("malformed field table") } // get columns columns := []string { } if header, ok := children[0].(*ast.TableHeader); ok { children := header.GetChildren() if len(children) != 1 { return nil, errors.New("malformed field table header") } if row, ok := header.Children[0].(*ast.TableRow); ok { for _, cell := range row.GetChildren() { if cell, ok := cell.(*ast.TableCell); ok { columns = append(columns, flatten(cell)) } } } else { return nil, errors.New("malformed field table header") } for index, column := range columns { columns[index] = strings.ToLower(column) } } else { return nil, errors.New("malformed field table: no header") } // get data if body, ok := children[1].(*ast.TableBody); ok { for _, node := range body.GetChildren() { if row, ok := node.(*ast.TableRow); ok { children := row.GetChildren() if len(children) != len(columns) { return nil, errors.New ( "malformed field table row: wrong " + "number of columns") } field := Field { } for index, node := range children { if cell, ok := node.(*ast.TableCell); ok { text := flatten(cell) switch columns[index] { case "tag": tag, err := parseTag(text) if err != nil { return nil, err } field.Tag = tag case "name": field.Name = text case "required": field.Optional = !parseBool(text) case "optional": field.Optional = parseBool(text) case "type": field.Type = text case "comment", "purpose", "documentation": field.Doc = text } }} fields = append(fields, field) }} } else { return nil, errors.New("malformed field table: no body") } return fields, nil } type nodeFlattener struct { text string } func (this *nodeFlattener) String() string { return this.text } func (this *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus { if entering { if node := node.AsLeaf(); node != nil { this.text += string(node.Literal) } } return ast.GoToNext } func flatten(node ast.Node) string { flattener := new(nodeFlattener) ast.Walk(node, flattener) return flattener.text } func removeBreaks(text string) string { text = strings.ReplaceAll(text, "\n", " ") text = strings.ReplaceAll(text, "\r", "") return text } func parseBool(text string) bool { switch(strings.ToLower(text)) { case "yes": return true case "no": return false case "true": return true case "false": return false } return false } func parseTag(text string) (uint16, error) { tag, err := strconv.ParseUint(text, 10, 16) if err != nil { return 0, fmt.Errorf("malformed tag '%s': %w", text, err) } return uint16(tag), nil } func splitMessageHeading(text string) (uint16, string, error) { text = strings.TrimSpace(text) methodText, name, ok := strings.Cut(text, " ") if !ok { return 0, "", fmt.Errorf( "malformed message heading '%s': no message name", text) } method, err := strconv.ParseUint(methodText, 16, 16) if err != nil { return 0, "", fmt.Errorf( "malformed method number '%s': %w", methodText, err) } name = strings.TrimSpace(name) return uint16(method), name, nil }