hopp/generate/generate.go

1269 lines
44 KiB
Go

package generate
import "io"
import "fmt"
import "maps"
import "math"
import "slices"
import "strings"
import "encoding/hex"
import "git.tebibyte.media/sashakoshka/hopp/tape"
// imports is the import block that Generate writes after the preamble of
// every generated file.
const imports = "\n" +
	"import \"git.tebibyte.media/sashakoshka/hopp\"\n" +
	"import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n"
// preamble is the header comment that Generate writes directly after the
// package clause of every generated file. It marks the file as generated code.
const preamble = "\n" +
	"// Code generated by the Holanet PDL compiler. DO NOT EDIT.\n" +
	"// The source file is located at <path>\n" +
	"// Please edit that file instead, and re-compile it to this location.\n" +
	"// HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz\n"
// static is Go source text that Generate writes verbatim into every generated
// file, after the import block. It declares the Table alias, the Message
// interface, and the canAssign and boolInt helpers that the rest of the
// generated code calls. The final blank-identifier declaration references
// package hopp (presumably so the hopp import is always used, even when the
// protocol has no optional fields).
const static = `
// Table is a KTV table with an undefined schema.
type Table = map[uint16] any
// Message is any message that can be sent along this protocol.
type Message interface {
tape.Encodable
tape.Decodable
// Method returns the method code of the message.
Method() uint16
}
// canAssign determines if data from the given source tag can be assigned to
// a Go type represented by destination. It is designed to receive destination
// values from [generate.Generator.generateCanAssign]. The eventual Go type and
// the destination tag must come from the same (or hash-equivalent) PDL type.
func canAssign(destination, source tape.Tag) bool {
if destination.Is(source) { return true }
if (destination.Is(tape.SBA) || destination.Is(tape.LBA)) &&
(source.Is(tape.SBA) || source.Is(tape.LBA)) {
return true
}
return false
}
// boolInt converts a bool to an integer.
func boolInt(input bool) int {
if input {
return 1
} else {
return 0
}
}
var _ hopp.Option[int]
`
// Generator converts protocols into Go code. Set Output (and optionally
// PackageName), then call Generate.
type Generator struct {
// Output is where the generated code will be sent.
Output io.Writer
// PackageName is the package name that will be used in the file. If
// left empty, the default is "protocol".
PackageName string
// nestingLevel is reset to zero at the start of Generate. It is presumably
// the indentation depth that push/pop adjust and iprintf applies (those
// helpers are not visible here).
nestingLevel int
// temporaryVar is presumably the counter behind newTemporaryVar, used to
// give generated local variables unique names.
temporaryVar int
// protocol is the protocol being generated. It is only set for the
// duration of a Generate call.
protocol *Protocol
// decodeBranchRequestQueue holds aggregate decoder functions that have been
// called by generated code but whose definitions have not been written yet.
decodeBranchRequestQueue []decodeBranchRequest
}
// decodeBranchRequest is a queued request to generate the definition of an
// aggregate decoder function.
type decodeBranchRequest struct {
// hash is the hash of typ, as produced by HashType.
hash [16]byte
// typ is the type that the decoder function will decode.
typ Type
// name is the Go type name to use for the function's destination argument,
// or an empty string to use the type generated from typ.
name string
}
// Generate writes Go source code for the given protocol to Output, and returns
// the number of bytes written. The file consists of the package clause, the
// preamble, the imports, the static section, one definition per named type,
// one definition per message, and finally every aggregate decoder function
// that was requested while generating the others. Types and messages are
// written in sorted order so that the output is deterministic.
func (this *Generator) Generate(protocol *Protocol) (n int, err error) {
	this.nestingLevel = 0
	// drop requests that an earlier, failed run may have left behind
	this.decodeBranchRequestQueue = nil
	this.protocol = protocol
	defer func() { this.protocol = nil }()

	// preamble and static section
	packageName := "protocol"
	if this.PackageName != "" {
		packageName = this.PackageName
	}
	nn, err := this.iprintf("package %s\n", packageName)
	n += nn; if err != nil { return n, err }
	nn, err = this.print(preamble)
	n += nn; if err != nil { return n, err }
	nn, err = this.print(imports)
	n += nn; if err != nil { return n, err }
	nn, err = this.print(static)
	n += nn; if err != nil { return n, err }

	// type definitions
	for _, name := range slices.Sorted(maps.Keys(protocol.Types)) {
		nn, err := this.generateTypedef(name, protocol.Types[name])
		n += nn; if err != nil { return n, err }
	}

	// messages
	for _, method := range slices.Sorted(maps.Keys(protocol.Messages)) {
		nn, err := this.generateMessage(method, protocol.Messages[method])
		n += nn; if err != nil { return n, err }
	}

	// request queue
	// Generating a branch can request further branches, including ones that
	// have already been pulled from the queue and written. The queue only
	// de-duplicates against requests that are still pending, so we remember
	// every function name we have emitted and skip repeats. Otherwise the
	// same function would be defined twice in the output.
	generated := make(map[string]struct{})
	for {
		hash, typ, name, ok := this.pullDecodeBranchRequest()
		if !ok { break }
		branch := this.decodeBranchName(hash, name)
		if _, done := generated[branch]; done { continue }
		generated[branch] = struct{}{}
		nn, err := this.generateDecodeBranch(hash, typ, name)
		n += nn; if err != nil { return n, err }
	}

	return n, nil
}
// generateTypedef generates the Go type definition for a named protocol type,
// along with its EncodeValue and DecodeValue methods. If the typedef has no
// documentation of its own, a default doc comment is generated. It returns the
// number of bytes written.
func (this *Generator) generateTypedef(name string, typedef Typedef) (n int, err error) {
	typ := typedef.Type

	// type definition
	if typedef.Doc == "" {
		nn, err := this.iprintf(
			"\n// %s represents the protocol data type %s.\n",
			name, name)
		n += nn; if err != nil { return n, err }
	} else {
		nn, err := this.iprintf("\n%s\n", this.formatComment(typedef.Doc))
		n += nn; if err != nil { return n, err }
	}
	nn, err := this.iprintf("type %s ", name)
	n += nn; if err != nil { return n, err }
	nn, err = this.generateType(typ)
	n += nn; if err != nil { return n, err }
	nn, err = this.println()
	n += nn; if err != nil { return n, err }

	// 'Tag' method
	// to be honest we probably don't need this method at all
	// nn, err = this.iprintf("\n// Tag returns the preferred TAPE tag.\n")
	// n += nn; if err != nil { return n, err }
	// nn, err = this.iprintf("func (this *%s) Tag() tape.Tag {\n", name)
	// n += nn; if err != nil { return n, err }
	// this.push()
	// nn, err = this.iprintf("return ")
	// n += nn; if err != nil { return n, err }
	// nn, err = this.generateTag(typ, "(*this)")
	// n += nn; if err != nil { return n, err }
	// nn, err = this.println()
	// n += nn; if err != nil { return n, err }
	// this.pop()
	// nn, err = this.iprintf("}\n")
	// n += nn; if err != nil { return n, err }

	// EncodeValue method
	nn, err = this.iprintf(
		"\n// EncodeValue encodes the value of this type without the " +
		"tag. The value is\n// encoded according to the parameters " +
		"specified by the tag, if possible.\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf(
		"func (this *%s) EncodeValue(encoder *tape.Encoder, tag tape.Tag) (n int, err error) {\n",
		name)
	n += nn; if err != nil { return n, err }
	this.push()
		nn, err = this.iprintf("var nn int\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateEncodeValue(typ, "(*this)", "tag")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("return n, nil\n")
		n += nn; if err != nil { return n, err }
	this.pop()
	nn, err = this.iprintf("}\n")
	n += nn; if err != nil { return n, err }

	// DecodeValue method
	// the comment marker must directly follow the newline, as it does for
	// EncodeValue, so that the generated doc comment is not indented
	nn, err = this.iprintf(
		"\n// DecodeValue decodes the value of this type without " +
		"the tag. The value is\n// decoded according to the " +
		"parameters specified by the tag, if possible.\n")
	n += nn; if err != nil { return n, err }
	nn, err = this.iprintf(
		"func (this *%s) DecodeValue(decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n",
		name)
	n += nn; if err != nil { return n, err }
	this.push()
		nn, err = this.iprintf("var nn int\n")
		n += nn; if err != nil { return n, err }

		// if the incoming tag cannot be assigned to this type, the
		// generated method skims past the value and returns
		nn, err = this.iprintf("if !(")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateCanAssign(typ, "tag")
		n += nn; if err != nil { return n, err }
		nn, err = this.printf(") {\n")
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("nn, err = tape.Skim(decoder, tag)\n")
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("return n, nil\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }

		nn, err = this.generateDecodeValue(typ, name, "this", "tag")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("return n, nil\n")
		n += nn; if err != nil { return n, err }
	this.pop()
	nn, err = this.iprintf("}\n")
	n += nn; if err != nil { return n, err }

	return n, nil
}
// generateMessage generates the structure, as well as the encoding and
// decoding functions, for the given message. The generated type gets a Method
// method that returns the method code, an Encode method that writes the tag
// followed by the value, and a Decode method that reads the tag followed by
// the value. It returns the number of bytes written.
func (this *Generator) generateMessage(method uint16, message Message) (n int, err error) {
// doc comment: use the one from the message if it has one, otherwise
// generate a default
if message.Doc == "" {
nn, err := this.iprintf(
"\n// %s represents the protocol message M%04X %s.\n",
message.Name, method, message.Name)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("\n%s\n", this.formatComment(message.Doc))
n += nn; if err != nil { return n, err }
}
// type definition
nn, err := this.iprintf("type %s ", this.resolveMessageName(message.Name))
n += nn; if err != nil { return n, err }
nn, err = this.generateType(message.Type)
n += nn; if err != nil { return n, err }
nn, err = this.println()
n += nn; if err != nil { return n, err }
// Method method
nn, err = this.iprintf("\n// Method returns the message's method number.\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"func(this *%s) Method() uint16 { return 0x%04X }\n",
this.resolveMessageName(message.Name),
method)
n += nn; if err != nil { return n, err }
// Encode method
nn, err = this.iprintf("\n// Encode encodes this message's tag and value.\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"func(this *%s) Encode(encoder *tape.Encoder) (n int, err error) {\n",
this.resolveMessageName(message.Name))
n += nn; if err != nil { return n, err }
this.push()
// generateTag emits code that computes the tag into a variable, whose
// name we then use to write the tag and to encode the value
tagVar, nn, err := this.generateTag(message.Type, "(*this)")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err := encoder.WriteTag(%s)\n", tagVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.generateEncodeValue(message.Type, "(*this)", tagVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
// Decode method
nn, err = this.iprintf("\n// Decode decodes this message's tag and value.\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"func(this *%s) Decode(decoder *tape.Decoder) (n int, err error) {\n",
this.resolveMessageName(message.Name))
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("tag, nn, err := decoder.ReadTag()\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
// if the tag that was read cannot be assigned to the message type, the
// generated method skims past the value and returns without an error
nn, err = this.iprintf("if !(")
n += nn; if err != nil { return n, err }
nn, err = this.generateCanAssign(message.Type, "tag")
n += nn; if err != nil { return n, err }
nn, err = this.printf(") {\n")
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("nn, err = tape.Skim(decoder, tag)\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
// decode the value into the message itself
nn, err = this.generateDecodeValue(message.Type, this.resolveMessageName(message.Name), "this", "tag")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
return n, nil
}
// generateEncodeValue generates code to encode a value of a specified type. It
// pulls from the variable (or parenthetical statement) specified by
// valueSource, and the value will be encoded according to the tag stored in
// the variable (or parenthetical statement) specified by tagSource.
// the code generated is a BLOCK and expects these variables to be defined:
//
//   - encoder *tape.Encoder
//   - n int
//   - err error
//   - nn int
//
// It returns the number of bytes of generated code that were written. It
// panics if typ is not one of the known types.
func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource string) (n int, err error) {
	switch typ := typ.(type) {
	case TypeBool:
		// SI: (none)
		// SI stores the value in the tag, so we write nothing here
	case TypeInt:
		// SI: (none)
		// LI/LSI: <value: IntN>
		if typ.Bits <= 5 {
			// SI stores the value in the tag, so we write nothing here
			break
		}
		prefix := "WriteUint"
		if typ.Signed {
			prefix = "WriteInt"
		}
		nn, err := this.iprintf("nn, err = encoder.%s%d(", prefix, typ.Bits)
		n += nn; if err != nil { return n, err }
		// TODO: cast like this for every type
		nn, err = this.generateType(typ)
		n += nn; if err != nil { return n, err }
		nn, err = this.printf("(%s))\n", valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeFloat:
		// FP: <value: FloatN>
		nn, err := this.iprintf("nn, err = encoder.WriteFloat%d(", typ.Bits)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateType(typ)
		n += nn; if err != nil { return n, err }
		nn, err = this.printf("(%s))\n", valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeString:
		// see TypeBuffer
		nn, err := this.generateEncodeValue(TypeBuffer { }, valueSource, tagSource)
		n += nn; if err != nil { return n, err }
	case TypeBuffer:
		// SBA: <data: U8>*
		// LBA: <length: UN> <data: U8>*
		nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf(
				"nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n",
				valueSource, tagSource)
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeArray:
		// OTA: <length: UN> <elementTag: tape.Tag> <values>*
		nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf(
			"nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN() + 1)\n",
			valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("{\n")
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("itemTag := ")
			n += nn; if err != nil { return n, err }
			nn, err = this.generateTN(typ.Element)
			n += nn; if err != nil { return n, err }
			nn, err = this.println()
			n += nn; if err != nil { return n, err }

			// first generated loop: find the item tag with the
			// largest CN, which is then shared by every item
			// TODO: we don't have to do this for loop for some
			// types such as integers because the CN will be the
			// same
			nn, err = this.iprintf("for _, item := range %s {\n", valueSource)
			n += nn; if err != nil { return n, err }
			this.push()
				nn, err = this.iprintf("_ = item\n")
				n += nn; if err != nil { return n, err }
				tagVar, nn, err := this.generateTag(typ.Element, "item")
				n += nn; if err != nil { return n, err }
				nn, err = this.iprintf("if %s.Is(tape.SBA) { continue }\n", tagVar)
				n += nn; if err != nil { return n, err }
				nn, err = this.iprintf("if %s.CN() > itemTag.CN() { itemTag = %s }\n", tagVar, tagVar)
				n += nn; if err != nil { return n, err }
			this.pop()
			nn, err = this.iprintf("}\n")
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("if itemTag.Is(tape.SBA) { itemTag += 1 << 5 }\n")
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("nn, err = encoder.WriteTag(itemTag)\n")
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }

			// second generated loop: encode every item using the
			// shared tag
			nn, err = this.iprintf("for _, item := range %s {\n", valueSource)
			n += nn; if err != nil { return n, err }
			this.push()
				nn, err = this.generateEncodeValue(typ.Element, "item", "itemTag")
				n += nn; if err != nil { return n, err }
			this.pop()
			nn, err = this.iprintf("}\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	case TypeTable:
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf(
			"nn, err = tape.EncodeAny(encoder, %s, %s)\n",
			valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeTableDefined:
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields))
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		// NOTE(review): the pair count written here is always the total
		// number of fields, but optional fields that hold no value are
		// not written below. It looks like the count can therefore
		// exceed the number of pairs that follow it. Confirm against
		// the decoder.
		nn, err = this.iprintf(
			"nn, err = encoder.WriteUintN(%d, %s.CN() + 1)\n",
			len(typ.Fields), tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("{\n")
		n += nn; if err != nil { return n, err }
		this.push()
			// iterate in key order, because map iteration order is
			// random and the generated code must be the same on every
			// run
			for _, key := range slices.Sorted(maps.Keys(typ.Fields)) {
				field := typ.Fields[key]
				fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name)
				if field.Option {
					// only encode optional fields that have a
					// value
					nn, err = this.iprintf("if value, ok := %s.Value(); ok {\n", fieldSource)
					n += nn; if err != nil { return n, err }
					fieldSource = "value"
					this.push()
				}
				nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key)
				n += nn; if err != nil { return n, err }
				nn, err = this.generateErrorCheck()
				n += nn; if err != nil { return n, err }
				tagVar, nn, err := this.generateTag(field.Type, fieldSource)
				n += nn; if err != nil { return n, err }
				nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(%s))\n", tagVar)
				n += nn; if err != nil { return n, err }
				nn, err = this.generateErrorCheck()
				n += nn; if err != nil { return n, err }
				nn, err = this.generateEncodeValue(field.Type, fieldSource, tagVar)
				n += nn; if err != nil { return n, err }
				if field.Option {
					this.pop()
					nn, err = this.iprintf("}\n")
					n += nn; if err != nil { return n, err }
				}
			}
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	case TypeNamed:
		// WHATEVER: [WHATEVER]
		nn, err := this.iprintf("nn, err = %s.EncodeValue(encoder, %s)\n", valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeAny:
		// WHATEVER: [WHATEVER]
		nn, err := this.iprintf("nn, err = tape.EncodeAny(encoder, %s, %s)\n", valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	default:
		panic(fmt.Errorf("unknown type: %T", typ))
	}
	return n, nil
}
// generateDecodeValue generates code to decode a value of a specified type. It
// overwrites memory pointed to by the variable (or parenthetical statement)
// specified by valueSource, and the value will be decoded according to the tag
// stored in the variable (or parenthetical statement) specified by tagSource.
// the code generated is a BLOCK and expects these variables to be defined:
//
//   - decoder *tape.Decoder
//   - n int
//   - err error
//   - nn int
//
// The typeName parameter is handled in the way described in the documentation
// for [Generator.generateDecodeBranch].
//
// It returns the number of bytes of generated code that were written. It
// panics if typ is not one of the known types.
func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagSource string) (n int, err error) {
	switch typ := typ.(type) {
	case TypeBool:
		// SI: (none)
		// SI stores the value in the tag
		nn, err := this.iprintf("*%s = %s.CN() > 0\n", valueSource, tagSource)
		n += nn; if err != nil { return n, err }
	case TypeInt:
		// SI: (none)
		// LI/LSI: <value: IntN>
		if typ.Bits <= 5 {
			// SI stores the value in the tag
			if typeName == "" {
				nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource)
				n += nn; if err != nil { return n, err }
			} else {
				nn, err := this.iprintf("*%s = %s(%s.CN())\n", valueSource, typeName, tagSource)
				n += nn; if err != nil { return n, err }
			}
			break
		}
		prefix := "ReadUint"
		if typ.Signed {
			prefix = "ReadInt"
		}
		destinationVar := this.newTemporaryVar("destination")
		nn, err := this.iprintf("%s, nn, err := decoder.%s%d()\n", destinationVar, prefix, typ.Bits)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		if typeName == "" {
			nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
			n += nn; if err != nil { return n, err }
		} else {
			nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
			n += nn; if err != nil { return n, err }
		}
	case TypeFloat:
		// FP: <value: FloatN>
		destinationVar := this.newTemporaryVar("destination")
		nn, err := this.iprintf("%s, nn, err := decoder.ReadFloat%d()\n", destinationVar, typ.Bits)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		if typeName == "" {
			nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
			n += nn; if err != nil { return n, err }
		} else {
			nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
			n += nn; if err != nil { return n, err }
		}
	case TypeString, TypeBuffer:
		// SBA: <data: U8>*
		// LBA: <length: UN> <data: U8>*
		// the length comes from a length field for LBA, and from the
		// tag's CN otherwise
		lengthVar := this.newTemporaryVar("length")
		nn, err := this.iprintf("var %s uint64\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf(
				"%s, nn, err = decoder.ReadUintN(int(%s.CN()))\n",
				lengthVar, tagSource)
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("} else {\n")
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf(
				"%s = uint64(%s.CN())\n",
				lengthVar, tagSource)
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }

		// refuse to allocate a buffer that is too large
		nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("buffer := make([]byte, %s)\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		if typeName == "" {
			if _, ok := typ.(TypeString); ok {
				nn, err = this.iprintf("*%s = string(buffer)\n", valueSource)
				n += nn; if err != nil { return n, err }
			} else {
				nn, err = this.iprintf("*%s = buffer\n", valueSource)
				n += nn; if err != nil { return n, err }
			}
		} else {
			nn, err = this.iprintf("*%s = %s(buffer)\n", valueSource, typeName)
			n += nn; if err != nil { return n, err }
		}
	case TypeArray:
		// OTA: <length: UN> <elementTag: tape.Tag> <values>*
		nn, err := this.generateDecodeBranchCall(typ, typeName, valueSource, tagSource)
		n += nn; if err != nil { return n, err }
	case TypeTable:
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		nn, err := this.iprintf(
			"nn, err = tape.DecodeAnyInto(decoder, %s, %s)\n",
			valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeTableDefined:
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		nn, err := this.generateDecodeBranchCall(typ, typeName, valueSource, tagSource)
		n += nn; if err != nil { return n, err }
	case TypeNamed:
		// WHATEVER: [WHATEVER]
		nn, err := this.iprintf("nn, err = %s.DecodeValue(decoder, %s)\n", valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	case TypeAny:
		// WHATEVER: [WHATEVER]
		nn, err := this.iprintf("*%s, nn, err = tape.DecodeAny(decoder, %s)\n", valueSource, tagSource)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
	default:
		panic(fmt.Errorf("unknown type: %T", typ))
	}
	return n, nil
}
// generateDecodeBranchCall generates a call to the aggregate decoder function
// for the specified type, and queues a request for that function's definition
// so that it is written later, and only once. The generated call overwrites
// the memory pointed to by the variable (or parenthetical statement) named by
// valueSource, decoding according to the tag held in the variable (or
// parenthetical statement) named by tagSource. The generated code is a BLOCK
// and expects these variables to be defined:
//
//   - decoder *tape.Decoder
//   - n int
//   - err error
//   - nn int
//
// The typeName parameter is handled in the way described in the documentation
// for [Generator.generateDecodeBranch].
func (this *Generator) generateDecodeBranchCall(typ Type, typeName, valueSource, tagSource string) (n int, err error) {
	hash := HashType(typ)
	functionName := this.decodeBranchName(hash, typeName)

	written, err := this.iprintf(
		"nn, err = %s(%s, decoder, %s)\n",
		functionName, valueSource, tagSource)
	n += written
	if err != nil { return n, err }

	written, err = this.generateErrorCheck()
	n += written
	if err != nil { return n, err }

	// the definition is only requested once the call has been written out
	this.pushDecodeBranchRequest(hash, typ, typeName)
	return n, nil
}
// generateDecodeBranch generates an aggregate decoder function definition for a
// specified type. It assumes that hash == HashType(typ). If typeName is not
// empty, it will be used as the type in the argument list instead of the result
// of [Generator.generateType].
//
// Only TypeArray and TypeTableDefined are supported, and an error is returned
// for any other type. The generated function takes a pointer to the
// destination, the decoder, and the tag of the value, and returns the number
// of bytes read. Values whose tags cannot be assigned to the destination are
// skimmed past rather than decoded.
func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName string) (n int, err error) {
	nn, err := this.iprintf("\nfunc %s(this *", this.decodeBranchName(hash, typeName))
	n += nn; if err != nil { return n, err }
	if typeName == "" {
		nn, err = this.generateType(typ)
		n += nn; if err != nil { return n, err }
	} else {
		nn, err = this.print(typeName)
		n += nn; if err != nil { return n, err }
	}
	nn, err = this.printf(", decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n")
	n += nn; if err != nil { return n, err }
	this.push()
	nn, err = this.iprintf("var nn int\n")
	n += nn; if err != nil { return n, err }

	switch typ := typ.(type) {
	case TypeArray:
		// OTA: <length: UN> <elementTag: tape.Tag> <values>*
		// read header
		lengthVar := this.newTemporaryVar("length")
		nn, err := this.iprintf("var %s uint64\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()) + 1)\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }
		// the length must be validated after it has been read, otherwise
		// the check only ever sees zero
		nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
		elementTagVar := this.newTemporaryVar("elementTag")
		nn, err = this.iprintf("var %s tape.Tag\n", elementTagVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("%s, nn, err = decoder.ReadTag()\n", elementTagVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }

		// abort macro
		abort := func() (n int, err error) {
			// skim entire array
			nn, err := this.iprintf("for _ = range %s {\n", lengthVar)
			n += nn; if err != nil { return n, err }
			this.push()
				nn, err = this.iprintf("nn, err = tape.Skim(decoder, %s)\n", elementTagVar)
				n += nn; if err != nil { return n, err }
				nn, err = this.generateErrorCheck()
				n += nn; if err != nil { return n, err }
			this.pop()
			nn, err = this.iprintf("}\n")
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("return n, nil\n")
			n += nn; if err != nil { return n, err }
			return n, nil
		}

		// validate header
		nn, err = this.iprintf("if !(")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateCanAssign(typ.Element, elementTagVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.printf(") {\n")
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = abort()
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }

		// decode payloads
		nn, err = this.iprintf("*this = make(")
		n += nn; if err != nil { return n, err }
		nn, err = this.generateType(typ)
		n += nn; if err != nil { return n, err }
		nn, err = this.printf(", %s)\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("for index := range int(%s) {\n", lengthVar)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.generateDecodeValue(typ.Element, "", "(&(*this)[index])", elementTagVar)
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	case TypeTableDefined:
		// TODO: options
		// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
		// read header
		lengthVar := this.newTemporaryVar("length")
		nn, err := this.iprintf("var %s uint64\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()) + 1)\n", lengthVar)
		n += nn; if err != nil { return n, err }
		nn, err = this.generateErrorCheck()
		n += nn; if err != nil { return n, err }

		// validate header
		// this must come after the length has been read. if the length
		// is too large we stop ALL decoding, because skimming a huge
		// amount of data could cause problems
		nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
		n += nn; if err != nil { return n, err }
		this.push()
			nn, err = this.iprintf("return n, tape.ErrTooLong\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }

		// read fields
		nn, err = this.iprintf("for _ = range %s {\n", lengthVar)
		n += nn; if err != nil { return n, err }
		this.push()
			// read field header
			fieldKeyVar := this.newTemporaryVar("fieldKey")
			nn, err = this.iprintf("var %s uint16\n", fieldKeyVar)
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("%s, nn, err = decoder.ReadUint16()\n", fieldKeyVar)
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }
			fieldTagVar := this.newTemporaryVar("fieldTag")
			nn, err = this.iprintf("var %s tape.Tag\n", fieldTagVar)
			n += nn; if err != nil { return n, err }
			nn, err = this.iprintf("%s, nn, err = decoder.ReadTag()\n", fieldTagVar)
			n += nn; if err != nil { return n, err }
			nn, err = this.generateErrorCheck()
			n += nn; if err != nil { return n, err }

			// abort field macro
			// the generated code skims past the field, keeping the
			// byte count and the error from doing so, and then moves
			// on to the next field
			abortField := func() (n int, err error) {
				nn, err := this.iprintf("nn, err = tape.Skim(decoder, %s)\n", fieldTagVar)
				n += nn; if err != nil { return n, err }
				nn, err = this.generateErrorCheck()
				n += nn; if err != nil { return n, err }
				nn, err = this.iprintf("continue\n")
				n += nn; if err != nil { return n, err }
				return n, nil
			}

			// switch on tag
			nn, err = this.iprintf("switch %s {\n", fieldKeyVar)
			n += nn; if err != nil { return n, err }
			for _, key := range slices.Sorted(maps.Keys(typ.Fields)) {
				field := typ.Fields[key]
				nn, err = this.iprintf("case 0x%04X:\n", key)
				n += nn; if err != nil { return n, err }
				this.push()
					// validate field header
					nn, err = this.iprintf("if !(")
					n += nn; if err != nil { return n, err }
					nn, err = this.generateCanAssign(field.Type, fieldTagVar)
					n += nn; if err != nil { return n, err }
					nn, err = this.printf(") {\n")
					n += nn; if err != nil { return n, err }
					this.push()
						nn, err = abortField()
						n += nn; if err != nil { return n, err }
					this.pop()
					nn, err = this.iprintf("}\n")
					n += nn; if err != nil { return n, err }

					// decode payload
					if field.Option {
						// optional fields are decoded into
						// a temporary and then wrapped
						destination := this.newTemporaryVar("destination")
						nn, err = this.iprintf("var %s ", destination)
						n += nn; if err != nil { return n, err }
						nn, err = this.generateType(field.Type)
						n += nn; if err != nil { return n, err }
						nn, err = this.printf("\n")
						n += nn; if err != nil { return n, err }
						nn, err = this.generateDecodeValue(
							field.Type, "", fmt.Sprintf("(&%s)", destination), fieldTagVar)
						n += nn; if err != nil { return n, err }
						nn, err = this.iprintf("this.%s = hopp.O(%s)\n", field.Name, destination)
						n += nn; if err != nil { return n, err }
					} else {
						nn, err = this.generateDecodeValue(
							field.Type, "",
							fmt.Sprintf("(&(this.%s))", field.Name), fieldTagVar)
						n += nn; if err != nil { return n, err }
					}
				this.pop()
			}
			// unknown keys are skimmed past
			nn, err = this.iprintf("default:\n")
			n += nn; if err != nil { return n, err }
			this.push()
				nn, err = abortField()
				n += nn; if err != nil { return n, err }
			this.pop()
			nn, err = this.iprintf("}\n")
			n += nn; if err != nil { return n, err }
		this.pop()
		nn, err = this.iprintf("}\n")
		n += nn; if err != nil { return n, err }
	default:
		// restore the nesting level before bailing out
		this.pop()
		return n, fmt.Errorf("unexpected type: %T", typ)
	}

	nn, err = this.iprintf("return n, nil\n")
	n += nn; if err != nil { return n, err }
	this.pop()
	nn, err = this.iprintf("}\n")
	n += nn; if err != nil { return n, err }
	return n, nil
}
// decodeBranchName returns the name of the decode branch function that
// corresponds to the given hash. The name is "decodeBranch_" followed by the
// hex-encoded hash. If name is not empty, it is appended after another
// underscore.
func (this *Generator) decodeBranchName(hash [16]byte, name string) string {
	encoded := hex.EncodeToString(hash[:])
	if name != "" {
		return fmt.Sprintf("decodeBranch_%s_%s", encoded, name)
	}
	return fmt.Sprintf("decodeBranch_%s", encoded)
}
// pushDecodeBranchRequest appends a request for a branch decode function to
// the back of the queue. If a request with the same hash and the same name is
// already waiting in the queue, nothing is added.
func (this *Generator) pushDecodeBranchRequest(hash [16]byte, typ Type, name string) {
	for index := range this.decodeBranchRequestQueue {
		queued := &this.decodeBranchRequestQueue[index]
		if queued.hash == hash && queued.name == name {
			// an equivalent request is already waiting
			return
		}
	}
	request := decodeBranchRequest{hash: hash, typ: typ, name: name}
	this.decodeBranchRequestQueue = append(this.decodeBranchRequestQueue, request)
}
// pullDecodeBranchRequest removes the branch decode function request at the
// front of the queue and returns its contents. If the queue is empty, ok is
// false and the other results are zero values.
func (this *Generator) pullDecodeBranchRequest() (hash [16]byte, typ Type, name string, ok bool) {
	queue := this.decodeBranchRequestQueue
	if len(queue) == 0 {
		return hash, nil, "", false
	}
	front := queue[0]
	this.decodeBranchRequestQueue = queue[1:]
	return front.hash, front.typ, front.name, true
}
// generateErrorCheck generates a line, at the current indentation level, which
// adds nn to n and then returns n and err from the generated function if err
// is not nil.
func (this *Generator) generateErrorCheck() (n int, err error) {
	const check = "n += nn; if err != nil { return n, err }\n"
	n, err = this.iprintf(check)
	return n, err
}
// generateBareErrorCheck generates a line, at the current indentation level,
// which returns n and err from the generated function if err is not nil.
// Unlike generateErrorCheck, the generated line does not add nn to n first.
func (this *Generator) generateBareErrorCheck() (n int, err error) {
	const check = "if err != nil { return n, err }\n"
	n, err = this.iprintf(check)
	return n, err
}
// generateTag generates code that stores the preferred tag (TN and CN) for a
// value of the given type in a temporary variable, and returns the name of
// that variable. source is inserted into the generated code wherever the value
// itself is needed to work out the tag. Named types are resolved, and take on
// the tag (and the tag variable) of the type they refer to. The generated code
// is a BLOCK. Unknown types cause a panic.
func (this *Generator) generateTag(typ Type, source string) (tagVar string, n int, err error) {
	tagVar = this.newTemporaryVar("tag")

	// Most types boil down to a single expression that is assigned to the
	// tag variable. The remaining ones return from within the switch.
	var expression string
	switch typ := typ.(type) {
	case TypeBool:
		expression = fmt.Sprintf("tape.SI.WithCN(boolInt(bool(%s)))", source)
	case TypeInt:
		switch {
		case typ.Bits <= 5:
			// small integers carry their value in the CN
			expression = fmt.Sprintf("tape.SI.WithCN(int(%s))", source)
		case typ.Signed:
			expression = fmt.Sprintf("tape.LSI.WithCN(%d)", bitsToCN(typ.Bits))
		default:
			expression = fmt.Sprintf("tape.LI.WithCN(%d)", bitsToCN(typ.Bits))
		}
	case TypeFloat:
		expression = fmt.Sprintf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits))
	case TypeString:
		expression = fmt.Sprintf("tape.StringTag(string(%s))", source)
	case TypeBuffer:
		expression = fmt.Sprintf("tape.BufferTag([]byte(%s))", source)
	case TypeArray:
		expression = fmt.Sprintf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))) - 1)", source)
	case TypeTable:
		expression = fmt.Sprintf("tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))) - 1)", source)
	case TypeTableDefined:
		// the field count is known now, so the CN is computed here
		// instead of in the generated code
		expression = fmt.Sprintf("tape.KTV.WithCN(%d)", tape.IntBytes(uint64(len(typ.Fields))) - 1)
	case TypeNamed:
		resolved, err := this.resolveTypeName(typ.Name)
		if err != nil {
			return tagVar, n, err
		}
		subTagVar, nn, err := this.generateTag(resolved, source)
		n += nn
		if err != nil {
			return tagVar, n, err
		}
		return subTagVar, n, nil
	case TypeAny:
		// the tag is only known at run time, and working it out can fail
		nn, err := this.iprintf("%s, err := tape.TagAny(%s)\n", tagVar, source)
		n += nn
		if err != nil {
			return tagVar, n, err
		}
		nn, err = this.generateBareErrorCheck()
		n += nn
		return tagVar, n, err
	default:
		panic(fmt.Errorf("unknown type: %T", typ))
	}

	n, err = this.iprintf("%s := %s\n", tagVar, expression)
	return tagVar, n, err
}
// generateTN generates the appropriate TN for the given type. The generated
// code is INLINE, and the generated tag has a CN of zero. For types whose TN
// depends on their length, the TN capable of supporting more information is
// chosen. Named types are resolved to the type they refer to. Types that are
// not handled here cause a panic.
func (this *Generator) generateTN(typ Type) (n int, err error) {
	var tn string
	switch typ := typ.(type) {
	case TypeBool:
		tn = "tape.SI"
	case TypeInt:
		switch {
		case typ.Bits <= 5:
			tn = "tape.SI"
		case typ.Signed:
			tn = "tape.LSI"
		default:
			tn = "tape.LI"
		}
	case TypeFloat:
		tn = "tape.FP"
	case TypeString, TypeBuffer:
		// strings share the TN of buffers
		tn = "tape.LBA"
	case TypeArray:
		tn = "tape.OTA"
	case TypeTable, TypeTableDefined:
		tn = "tape.KTV"
	case TypeNamed:
		resolved, err := this.resolveTypeName(typ.Name)
		if err != nil {
			return n, err
		}
		return this.generateTN(resolved)
	default:
		panic(fmt.Errorf("unknown type: %T", typ))
	}
	return this.print(tn)
}
// generateType generates the Go type that is used to represent the given
// type. The generated code is INLINE. Integers of 5 bits or fewer become
// uint8, 16-bit floats become float32, arrays become slices of their element
// type, and named types are referred to by name. An error is returned for
// integer and float sizes that are not supported. Unknown types cause a panic.
func (this *Generator) generateType(typ Type) (n int, err error) {
	// Most types are a single piece of text, which is printed at the end.
	// Composite types print themselves and return from within the switch.
	var goType string
	switch typ := typ.(type) {
	case TypeBool:
		goType = "bool"
	case TypeInt:
		if err := this.validateIntBitSize(typ.Bits); err != nil {
			return n, err
		}
		switch {
		case typ.Bits <= 5:
			goType = "uint8"
		case typ.Signed:
			goType = fmt.Sprintf("int%d", typ.Bits)
		default:
			goType = fmt.Sprintf("uint%d", typ.Bits)
		}
	case TypeFloat:
		switch typ.Bits {
		case 16:
			// Go has no 16 bit float type, so the next size up is used
			goType = "float32"
		case 32, 64:
			goType = fmt.Sprintf("float%d", typ.Bits)
		default:
			return n, fmt.Errorf("floats of size %d are unsupported on this platform", typ.Bits)
		}
	case TypeString:
		goType = "string"
	case TypeBuffer:
		goType = "[]byte"
	case TypeArray:
		nn, err := this.print("[]")
		n += nn
		if err != nil {
			return n, err
		}
		nn, err = this.generateType(typ.Element)
		n += nn
		return n, err
	case TypeTable:
		goType = "Table"
	case TypeTableDefined:
		return this.generateTypeTableDefined(typ)
	case TypeNamed:
		goType = typ.Name
	case TypeAny:
		goType = "any"
	default:
		panic(fmt.Errorf("unknown type: %T", typ))
	}
	return this.print(goType)
}
// generateTypeTableDefined generates an anonymous Go struct type for the given
// table. Fields are generated in ascending key order. Each one is preceded by
// its doc comment, and optional fields have their type wrapped in hopp.Option.
// The opening of the struct is generated INLINE, and no line break is
// generated after the closing brace.
func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, err error) {
	// emit adds the result of a print call to the running total, and
	// reports whether generation may carry on.
	emit := func(nn int, printErr error) bool {
		n += nn
		err = printErr
		return err == nil
	}

	if !emit(this.print("struct {\n")) {
		return n, err
	}
	this.push()
	keys := slices.Sorted(maps.Keys(typ.Fields))
	for _, key := range keys {
		field := typ.Fields[key]
		if !emit(this.iprintf("%s\n", this.formatComment(field.Doc))) {
			return n, err
		}
		if !emit(this.iprintf("%s ", field.Name)) {
			return n, err
		}
		if field.Option {
			if !emit(this.print("hopp.Option[")) {
				return n, err
			}
		}
		if !emit(this.generateType(field.Type)) {
			return n, err
		}
		if field.Option {
			if !emit(this.print("]")) {
				return n, err
			}
		}
		if !emit(this.print("\n")) {
			return n, err
		}
	}
	this.pop()
	if !emit(this.iprint("}")) {
		return n, err
	}
	return n, nil
}
// generateCanAssign generates an expression which checks if the tag specified
// by tagSource can be assigned to a Go destination generated from typ. The
// generated code is INLINE.
//
// Named types are resolved before anything is generated, and an error is
// returned if the name cannot be resolved. If the type is (or resolves to)
// any, the generated expression is simply true, because every tag can be
// assigned to it. Otherwise, the expression is a call to canAssign with the
// TN of the type.
func (this *Generator) generateCanAssign(typ Type, tagSource string) (n int, err error) {
	// Look through named types first. generateTN has no TN for any, so a
	// named type that refers to any must be caught here rather than being
	// passed along to it.
	if named, ok := typ.(TypeNamed); ok {
		resolved, err := this.resolveTypeName(named.Name)
		if err != nil {
			return n, err
		}
		typ = resolved
	}

	if _, ok := typ.(TypeAny); ok {
		return this.printf("true")
	}

	nn, err := this.printf("canAssign(")
	n += nn
	if err != nil {
		return n, err
	}
	nn, err = this.generateTN(typ)
	n += nn
	if err != nil {
		return n, err
	}
	nn, err = this.printf(", %s)", tagSource)
	n += nn
	if err != nil {
		return n, err
	}
	return n, nil
}
// validateIntBitSize returns an error if integers with the given size in bits
// cannot be generated. The supported sizes are 5, 8, 16, 32, and 64.
func (this *Generator) validateIntBitSize(size int) error {
	if size == 5 || size == 8 || size == 16 || size == 32 || size == 64 {
		return nil
	}
	return fmt.Errorf("integers of size %d are unsupported on this platform", size)
}
// validateFloatBitSize returns an error if floats with the given size in bits
// cannot be generated. The supported sizes are 16, 32, and 64.
func (this *Generator) validateFloatBitSize(size int) error {
	supported := size == 16 || size == 32 || size == 64
	if !supported {
		return fmt.Errorf("floats of size %d are unsupported on this platform", size)
	}
	return nil
}
// push increases the nesting level by one, so that indented output is printed
// one level deeper.
func (this *Generator) push() {
	this.nestingLevel += 1
}
// pop decreases the nesting level by one. It panics if the nesting level is
// less than 1, since that means there is nothing left to pop.
func (this *Generator) pop() {
	if this.nestingLevel >= 1 {
		this.nestingLevel -= 1
		return
	}
	panic("cannot pop when nesting level is less than 1")
}
// indent returns the indentation for the current nesting level, which is one
// tab character per level.
func (this *Generator) indent() string {
	const tab = "\t"
	return strings.Repeat(tab, this.nestingLevel)
}
// print writes its arguments to the output in the manner of fmt.Fprint,
// without any indentation. It returns the number of bytes written.
func (this *Generator) print(args ...any) (n int, err error) {
	n, err = fmt.Fprint(this.Output, args...)
	return n, err
}
// println writes its arguments to the output in the manner of fmt.Fprintln,
// without any indentation. It returns the number of bytes written.
func (this *Generator) println(args ...any) (n int, err error) {
	n, err = fmt.Fprintln(this.Output, args...)
	return n, err
}
// printf writes to the output according to a format specifier, in the manner
// of fmt.Fprintf, without any indentation. It returns the number of bytes
// written.
func (this *Generator) printf(format string, args ...any) (n int, err error) {
	n, err = fmt.Fprintf(this.Output, format, args...)
	return n, err
}
// iprint is like print, but the output is preceded by the indentation for the
// current nesting level. The arguments are formatted in the manner of
// fmt.Sprint. It returns the number of bytes written.
func (this *Generator) iprint(args ...any) (n int, err error) {
	text := this.indent() + fmt.Sprint(args...)
	return fmt.Fprint(this.Output, text)
}
// iprintln is like iprint, but a line break is written after the text. Note
// that the arguments are still formatted in the manner of fmt.Sprint, so
// spaces are only added between operands when neither is a string. It returns
// the number of bytes written.
func (this *Generator) iprintln(args ...any) (n int, err error) {
	text := this.indent() + fmt.Sprint(args...)
	return fmt.Fprintln(this.Output, text)
}
// iprintf is like printf, but the output is preceded by the indentation for
// the current nesting level. It returns the number of bytes written.
func (this *Generator) iprintf(format string, args ...any) (n int, err error) {
	indented := this.indent() + format
	return fmt.Fprintf(this.Output, indented, args...)
}
// formatComment turns the given text into a line comment. The first line is
// preceded by "// " only, and every following line is preceded by the current
// indentation and then "// ". The caller is expected to have already indented
// the first line.
func (this *Generator) formatComment(comment string) string {
	const marker = "// "
	continuation := "\n" + this.indent() + marker
	return marker + strings.ReplaceAll(comment, "\n", continuation)
}
// resolveMessageName returns the given message name with "Message" placed in
// front of it.
func (this *Generator) resolveMessageName(message string) string {
	const prefix = "Message"
	return prefix + message
}
// resolveTypeName looks up the type with the given name in the protocol. If
// that type is itself just another named type, the chain of names is followed
// until a type that is not a named type is found, and that type is returned.
//
// An error is returned if any name along the way does not exist in the
// protocol, or if the names refer to each other in a cycle (in which case
// there is no underlying type to return).
func (this *Generator) resolveTypeName(name string) (Type, error) {
	// names that have been visited so far, used to detect cycles instead
	// of following them forever
	visited := make(map[string]struct{})
	for {
		if _, seen := visited[name]; seen {
			return nil, fmt.Errorf("type %s is defined in terms of itself", name)
		}
		visited[name] = struct{}{}

		typedef, ok := this.protocol.Types[name]
		if !ok {
			return nil, fmt.Errorf("no type exists called %s", name)
		}
		named, ok := typedef.Type.(TypeNamed)
		if !ok {
			return typedef.Type, nil
		}
		name = named.Name
	}
}
// newTemporaryVar returns a new variable name made up of base, an underscore,
// and a counter. The counter is incremented on every call, so each call yields
// a different name.
func (this *Generator) newTemporaryVar(base string) string {
	this.temporaryVar++
	index := this.temporaryVar
	return fmt.Sprintf("%s_%d", base, index)
}
// bitsToBytes returns the number of bytes needed to hold the given number of
// bits, which is the bit count divided by 8 and rounded up.
func bitsToBytes(bits int) int {
	bytes := math.Ceil(float64(bits) / 8.0)
	return int(bytes)
}
// bitsToCN returns the CN used for a value that is the given number of bits
// in size, which is the number of bytes needed to hold it minus one.
func bitsToCN(bits int) int {
	byteCount := bitsToBytes(bits)
	return byteCount - 1
}