package protocol

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/gomarkdown/markdown"
	"github.com/gomarkdown/markdown/ast"
	"github.com/gomarkdown/markdown/parser"
)

// Protocol describes a protocol.
type Protocol struct {
	// Messages lists the protocol's messages in document order.
	Messages []Message
}

// Message describes a protocol message.
type Message struct {
	// Doc is the message's prose documentation. Paragraphs are separated by
	// a blank line; block quotes are prefixed with "> ".
	Doc string
	// Name is the message's name, taken from its heading text.
	Name string
	// Fields lists the message's fields in document order.
	Fields []Field
}

// Field describes a named value within a message.
type Field struct {
	// Doc is the field's documentation text (may be empty).
	Doc string
	// Name is the field's name.
	Name string
	// Optional is only ever set from a field table's "required" or
	// "optional" column; fields parsed from lists leave it false.
	Optional bool
	// Type is the field's type as written in the document. It is always
	// non-empty for list fields but may be empty for table fields when the
	// table has no "type" column.
	Type string
}

// ParseReader reads a markdown protocol definition from reader in full and
// parses it with UnmarshalText.
func ParseReader(reader io.Reader) (*Protocol, error) {
	data, err := io.ReadAll(reader)
	if err != nil {
		return nil, err
	}
	protocol := new(Protocol)
	if err := protocol.UnmarshalText(data); err != nil {
		return nil, err
	}
	return protocol, nil
}

// UnmarshalText parses a markdown protocol definition and appends the
// messages it describes to p.Messages (messages already present are kept).
//
// The message section starts at a level-1 or level-2 heading titled
// "Messages". Within it, every heading exactly one level deeper names a new
// message, and the top-level blocks that follow it describe that message:
//
//   - paragraphs and block quotes are appended to Message.Doc;
//   - lists are parsed as "name type: doc" fields (see processFieldList);
//   - tables are parsed as field tables (see processFieldTable).
//
// A deeper heading opens a subsection whose content is ignored up to the next
// message heading; a heading at or above the section's own level ends the
// section. Content outside the section is ignored. Both the "# Messages" /
// "## Name" layout written by MarshalText and a "## Messages" / "### Name"
// layout nested under a document title are accepted.
//
// On error, p.Messages may already contain the messages parsed before the
// failing block.
func (p *Protocol) UnmarshalText(text []byte) error {
	const messagesHeading = "Messages"

	// sectionLevel is the level of the active "Messages" heading, or 0 when
	// outside the message section.
	sectionLevel := 0
	// message is the message currently collecting content, or nil while
	// content is being skipped. It always points at the last element of
	// p.Messages and is re-pointed right after every append, so a slice
	// reallocation never leaves it dangling.
	var message *Message

	// addDoc appends one documentation paragraph to the current message.
	addDoc := func(paragraph string) {
		if message.Doc != "" {
			message.Doc += "\n\n"
		}
		message.Doc += paragraph
	}

	root := markdown.Parse(text, parser.New())
	for _, node := range root.GetChildren() {
		if heading, ok := node.(*ast.Heading); ok {
			title := removeBreaks(flatten(heading))
			switch {
			case sectionLevel > 0 && heading.Level == sectionLevel+1:
				p.Messages = append(p.Messages, Message{Name: title})
				message = &p.Messages[len(p.Messages)-1]
			case heading.Level <= 2 && title == messagesHeading:
				sectionLevel = heading.Level
				message = nil
			case heading.Level <= sectionLevel:
				// A heading at or above the section's level closes it.
				sectionLevel = 0
				message = nil
			default:
				// A sub-heading inside a message (or any heading outside
				// the section): skip content until the next message heading.
				message = nil
			}
			continue
		}

		if message == nil {
			continue
		}

		switch node := node.(type) {
		case *ast.Paragraph:
			addDoc(removeBreaks(flatten(node)))
		case *ast.BlockQuote:
			// Keep the quote marker even when the quote is the first
			// paragraph, so the quote survives a MarshalText round trip.
			addDoc("> " + removeBreaks(flatten(node)))
		case *ast.List:
			fields, err := processFieldList(node)
			if err != nil {
				return fmt.Errorf("message %s: %w", message.Name, err)
			}
			message.Fields = append(message.Fields, fields...)
		case *ast.Table:
			fields, err := processFieldTable(node)
			if err != nil {
				return fmt.Errorf("message %s: %w", message.Name, err)
			}
			message.Fields = append(message.Fields, fields...)
		}
	}
	return nil
}

// MarshalText renders the protocol as markdown: a "# Messages" heading, then
// for each message a "## Name" heading, its Doc (if any) and a list of
// "- name type: doc" fields. UnmarshalText accepts this layout.
//
// Field.Optional is not encoded. A field with an empty Type is written
// without a type, which UnmarshalText rejects when reading it back.
func (p *Protocol) MarshalText() ([]byte, error) {
	var buffer bytes.Buffer
	buffer.WriteString("# Messages\n")
	for index, message := range p.Messages {
		if index > 0 {
			buffer.WriteString("\n")
		}
		fmt.Fprintf(&buffer, "## %s\n", message.Name)
		if message.Doc != "" {
			buffer.WriteString(message.Doc)
			buffer.WriteString("\n")
		}
		if len(message.Fields) > 0 {
			// Blank line so the field list is not merged into the doc text.
			buffer.WriteString("\n")
		}
		for _, field := range message.Fields {
			buffer.WriteString("- ")
			buffer.WriteString(field.Name)
			if field.Type != "" {
				buffer.WriteString(" " + field.Type)
			}
			if field.Doc != "" {
				buffer.WriteString(": " + field.Doc)
			}
			buffer.WriteString("\n")
		}
	}
	return buffer.Bytes(), nil
}

// processFieldList parses every item of a markdown list as a field written
// "name type" optionally followed by ": doc" (see processFieldListItem).
// Children that are not list items are ignored.
func processFieldList(node *ast.List) ([]Field, error) {
	items := node.GetChildren()
	fields := make([]Field, 0, len(items))
	for _, child := range items {
		item, ok := child.(*ast.ListItem)
		if !ok {
			continue
		}
		field, err := processFieldListItem(removeBreaks(flatten(item)))
		if err != nil {
			return nil, err
		}
		fields = append(fields, field)
	}
	return fields, nil
}

// processFieldTable parses a markdown table whose rows each describe one
// field. Column titles are matched case-insensitively:
//
//   - "name": Field.Name
//   - "type": Field.Type
//   - "required" / "optional": Field.Optional (via parseBool)
//   - "comment", "purpose", "documentation", "description": Field.Doc
//
// Other columns are ignored. The table must have exactly one header row and
// a body; any other table parts (such as a footer) are ignored.
func processFieldTable(node *ast.Table) ([]Field, error) {
	var header *ast.TableHeader
	var body *ast.TableBody
	for _, child := range node.GetChildren() {
		switch child := child.(type) {
		case *ast.TableHeader:
			if header == nil {
				header = child
			}
		case *ast.TableBody:
			if body == nil {
				body = child
			}
		}
	}

	if header == nil {
		return nil, errors.New("malformed field table: no header")
	}
	columns, err := fieldTableColumns(header)
	if err != nil {
		return nil, err
	}
	if body == nil {
		return nil, errors.New("malformed field table: no body")
	}

	fields := []Field{}
	for _, child := range body.GetChildren() {
		row, ok := child.(*ast.TableRow)
		if !ok {
			continue
		}
		field, err := fieldTableRow(row, columns)
		if err != nil {
			return nil, err
		}
		fields = append(fields, field)
	}
	return fields, nil
}

// fieldTableColumns returns the trimmed, lower-cased column titles of a
// field table header, which must consist of exactly one row.
func fieldTableColumns(header *ast.TableHeader) ([]string, error) {
	rows := header.GetChildren()
	if len(rows) != 1 {
		return nil, errors.New("malformed field table header")
	}
	row, ok := rows[0].(*ast.TableRow)
	if !ok {
		return nil, errors.New("malformed field table header")
	}
	var columns []string
	for _, child := range row.GetChildren() {
		if cell, ok := child.(*ast.TableCell); ok {
			title := strings.ToLower(strings.TrimSpace(flatten(cell)))
			columns = append(columns, title)
		}
	}
	return columns, nil
}

// fieldTableRow builds one Field from a body row, mapping each cell to the
// field attribute named by the matching entry of columns.
func fieldTableRow(row *ast.TableRow, columns []string) (Field, error) {
	cells := row.GetChildren()
	if len(cells) != len(columns) {
		return Field{}, errors.New(
			"malformed field table row: wrong number of columns")
	}
	var field Field
	for index, child := range cells {
		cell, ok := child.(*ast.TableCell)
		if !ok {
			continue
		}
		text := strings.TrimSpace(flatten(cell))
		switch columns[index] {
		case "name":
			field.Name = text
		case "required":
			field.Optional = !parseBool(text)
		case "optional":
			field.Optional = parseBool(text)
		case "type":
			field.Type = text
		case "comment", "purpose", "documentation", "description":
			field.Doc = text
		}
	}
	return field, nil
}

// nodeFlattener is an ast.NodeVisitor that concatenates the literal text of
// every leaf node it visits.
type nodeFlattener struct {
	builder strings.Builder
}

// String returns the text collected so far.
func (f *nodeFlattener) String() string {
	return f.builder.String()
}

// Visit appends the literal of each leaf node when it is entered.
func (f *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus {
	if entering {
		if leaf := node.AsLeaf(); leaf != nil {
			f.builder.Write(leaf.Literal)
		}
	}
	return ast.GoToNext
}

// flatten returns the concatenated literal text of all leaves under node,
// discarding markup such as emphasis or links.
func flatten(node ast.Node) string {
	var flattener nodeFlattener
	ast.Walk(node, &flattener)
	return flattener.String()
}

// processFieldListItem parses a list-item text of the form
// "name type: doc" (": doc" optional). The name and type are split at the
// first space; an item without a type is an error.
func processFieldListItem(bullet string) (Field, error) {
	nameAndType, doc, _ := strings.Cut(bullet, ":")
	name, ty, hasType := strings.Cut(strings.TrimSpace(nameAndType), " ")
	if !hasType {
		return Field{}, fmt.Errorf("unspecified type near '%s'", bullet)
	}
	return Field{
		Doc:  strings.TrimSpace(doc),
		Name: strings.TrimSpace(name),
		Type: strings.TrimSpace(ty),
	}, nil
}

// lineBreakReplacer turns each "\n" (including the "\n" of "\r\n") into a
// space and drops every "\r", in a single pass.
var lineBreakReplacer = strings.NewReplacer("\r\n", " ", "\n", " ", "\r", "")

// removeBreaks joins multi-line text into one line by replacing line feeds
// with spaces and deleting carriage returns.
func removeBreaks(text string) string {
	return lineBreakReplacer.Replace(text)
}

// parseBool reports whether text is "yes" or "true", ignoring case and
// surrounding whitespace. Any other text, including "no", "false" and the
// empty string, yields false.
func parseBool(text string) bool {
	switch strings.ToLower(strings.TrimSpace(text)) {
	case "yes", "true":
		return true
	default:
		return false
	}
}