package generate import "io" import "fmt" import "maps" import "math" import "slices" import "strings" import "encoding/hex" import "git.tebibyte.media/sashakoshka/hopp/tape" const imports = ` import "git.tebibyte.media/sashakoshka/hopp/tape" ` const preamble = ` // Code generated by the Holanet PDL compiler. DO NOT EDIT. // The source file is located at // Please edit that file instead, and re-compile it to this location. // HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz ` const static = ` // Table is a KTV table with an undefined schema. type Table = map[uint16] any // Message is any message that can be sent along this protocol. type Message interface { tape.Encodable tape.Decodable // Method returns the method code of the message. Method() uint16 } // canAssign determines if data from the given source tag can be assigned to // a Go type represented by destination. It is designed to receive destination // values from [generate.Generator.generateCanAssign]. The eventual Go type and // the destination tag must come from the same (or hash-equivalent) PDL type. func canAssign(destination, source tape.Tag) bool { if destination.Is(source) { return true } if (destination.Is(tape.SBA) || destination.Is(tape.LBA)) && (source.Is(tape.SBA) || source.Is(tape.LBA)) { return true } return false } // boolInt converts a bool to an integer. func boolInt(input bool) int { if input { return 1 } else { return 0 } } ` // Generator converts protocols into Go code. type Generator struct { // Output is where the generated code will be sent. Output io.Writer // PackageName is the package name that will be used in the file. If // left empty, the default is "protocol". PackageName string nestingLevel int temporaryVar int protocol *Protocol decodeBranchRequestQueue []decodeBranchRequest } type decodeBranchRequest struct { hash [16]byte typ Type name string } func (this *Generator) Generate(protocol *Protocol) (n int, err error) { this.nestingLevel = 0 this.protocol = protocol defer func() { this.protocol = nil }() // preamble and static section packageName := "protocol" if this.PackageName != "" { packageName = this.PackageName } nn, err := this.iprintf("package %s\n", packageName) n += nn; if err != nil { return n, err } nn, err = this.print(preamble) n += nn; if err != nil { return n, err } nn, err = this.print(imports) n += nn; if err != nil { return n, err } nn, err = this.print(static) n += nn; if err != nil { return n, err } // type definitions for _, name := range slices.Sorted(maps.Keys(protocol.Types)) { nn, err := this.generateTypedef(name, protocol.Types[name]) n += nn; if err != nil { return n, err } } // messages for _, method := range slices.Sorted(maps.Keys(protocol.Messages)) { nn, err := this.generateMessage(method, protocol.Messages[method]) n += nn; if err != nil { return n, err } } // request queue for { hash, typ, name, ok := this.pullDecodeBranchRequest() if !ok { break } nn, err := this.generateDecodeBranch(hash, typ, name) n += nn; if err != nil { return n, err } } return n, nil } func (this *Generator) generateTypedef(name string, typedef Typedef) (n int, err error) { typ := typedef.Type // type definition if typedef.Doc == "" { nn, err := this.iprintf( "\n// %s represents the protocol data type %s.\n", name, name) n += nn; if err != nil { return n, err } } else { nn, err := this.iprintf("\n%s\n", this.formatComment(typedef.Doc)) n += nn; if err != nil { return n, err } } nn, err := this.iprintf("type %s ", name) n += nn; if err != nil { return n, err } nn, err = this.generateType(typ) n += nn; if err != nil { return n, err } nn, err = this.println() n += nn; if err != nil { return n, err } // 'Tag' method // to be honest we probably don't need this method at all // nn, err = this.iprintf("\n// Tag returns the preferred TAPE tag.\n") // n += nn; if err != nil { return n, err } // nn, err = this.iprintf("func (this *%s) Tag() tape.Tag {\n", name) // n += nn; if err != nil { return n, err } // this.push() // nn, err = this.iprintf("return ") // n += nn; if err != nil { return n, err } // nn, err = this.generateTag(typ, "(*this)") // n += nn; if err != nil { return n, err } // nn, err = this.println() // n += nn; if err != nil { return n, err } // this.pop() // nn, err = this.iprintf("}\n") // n += nn; if err != nil { return n, err } // EncodeValue method nn, err = this.iprintf( "\n// EncodeValue encodes the value of this type without the " + "tag. The value is\n// encoded according to the parameters " + "specified by the tag, if possible.\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "func (this *%s) EncodeValue(encoder *tape.Encoder, tag tape.Tag) (n int, err error) {\n", name) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("var nn int\n") n += nn; if err != nil { return n, err } nn, err = this.generateEncodeValue(typ, "(*this)", "tag") n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } // DecodeValue method nn, err = this.iprintf( "\n // DecodeValue decodes the value of this type without " + "the tag. The value is\n// decoded according to the " + "parameters specified by the tag, if possible.\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "func (this *%s) DecodeValue(decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n", name) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("var nn int\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("if !(") n += nn; if err != nil { return n, err } nn, err = this.generateCanAssign(typ, "tag") n += nn; if err != nil { return n, err } nn, err = this.printf(") {\n") n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("nn, err = tape.Skim(decoder, tag)\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.generateDecodeValue(typ, name, "this", "tag") n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } return n, nil } // generateMessage generates the structure, as well as encoding decoding // functions for the given message. func (this *Generator) generateMessage(method uint16, message Message) (n int, err error) { if message.Doc == "" { nn, err := this.iprintf( "\n// %s represents the protocol message M%04X %s.\n", message.Name, method, message.Name) n += nn; if err != nil { return n, err } } else { nn, err := this.iprintf("\n%s\n", this.formatComment(message.Doc)) n += nn; if err != nil { return n, err } } nn, err := this.iprintf("type %s ", this.resolveMessageName(message.Name)) n += nn; if err != nil { return n, err } nn, err = this.generateType(message.Type) n += nn; if err != nil { return n, err } nn, err = this.println() n += nn; if err != nil { return n, err } // Method method nn, err = this.iprintf("\n// Method returns the message's method number.\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "func(this *%s) Method() uint16 { return 0x%04X }\n", this.resolveMessageName(message.Name), method) n += nn; if err != nil { return n, err } // Encode method nn, err = this.iprintf("\n// Encode encodes this message's tag and value.\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "func(this *%s) Encode(encoder *tape.Encoder) (n int, err error) {\n", this.resolveMessageName(message.Name)) n += nn; if err != nil { return n, err } this.push() tagVar, nn, err := this.generateTag(message.Type, "(*this)") n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err := encoder.WriteTag(%s)\n", tagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.generateEncodeValue(message.Type, "(*this)", tagVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } // Decode method nn, err = this.iprintf("\n// Decode decodes this message's tag and value.\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "func(this *%s) Decode(decoder *tape.Decoder) (n int, err error) {\n", this.resolveMessageName(message.Name)) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("tag, nn, err := decoder.ReadTag()\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.iprintf("if !(") n += nn; if err != nil { return n, err } nn, err = this.generateCanAssign(message.Type, "tag") n += nn; if err != nil { return n, err } nn, err = this.printf(") {\n") n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("nn, err = tape.Skim(decoder, tag)\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.generateDecodeValue(message.Type, this.resolveMessageName(message.Name), "this", "tag") n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } return n, nil } // generateEncodeValue generates code to encode a value of a specified type. It // pulls from the variable (or parenthetical statement) specified by // valueSource, and the value will be encoded according to the tag stored in // the variable (or parenthetical statement) specified by tagSource. // the code generated is a BLOCK and expects these variables to be defined: // // - encoder *tape.Encoder // - n int // - err error // - nn int func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource string) (n int, err error) { switch typ := typ.(type) { case TypeBool: // SI: (none) // SI stores the value in the tag, so we write nothing here case TypeInt: // SI: (none) // LI/LSI: if typ.Bits <= 5 { // SI stores the value in the tag, so we write nothing here break } prefix := "WriteUint" if typ.Signed { prefix = "WriteInt" } nn, err := this.iprintf("nn, err = encoder.%s%d(", prefix, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateType(typ) // TODO: cast like this for // every type n += nn; if err != nil { return n, err } nn, err = this.printf("(%s))\n", valueSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeFloat: // FP: nn, err := this.iprintf("nn, err = encoder.WriteFloat%d(", typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateType(typ) n += nn; if err != nil { return n, err } nn, err = this.printf("(%s))\n", valueSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeString: // see TypeBuffer nn, err := this.generateEncodeValue(TypeBuffer { }, valueSource, tagSource) n += nn; if err != nil { return n, err } case TypeBuffer: // SBA: * // LBA: * nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf( "nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeArray: // OTA: * nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN() + 1)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.iprintf("{\n") n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("itemTag := ") n += nn; if err != nil { return n, err } nn, err = this.generateTN(typ.Element) n += nn; if err != nil { return n, err } nn, err = this.println() n += nn; if err != nil { return n, err } // TODO: we don't have to do this for loop for some // types such as integers because the CN will be the // same nn, err = this.iprintf("for _, item := range %s {\n", valueSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("_ = item\n") n += nn; if err != nil { return n, err } tagVar, nn, err := this.generateTag(typ.Element, "item") n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s.Is(tape.SBA) { continue }\n", tagVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s.CN() > itemTag.CN() { itemTag = %s }\n", tagVar, tagVar) n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("if itemTag.Is(tape.SBA) { itemTag += 1 << 5 }\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err = encoder.WriteTag(itemTag)\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.iprintf("for _, item := range %s {\n", valueSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.generateEncodeValue(typ.Element, "item", "itemTag") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } case TypeTable: // KTV: ( )* nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "nn, err = tape.EncodeAny(encoder, %s, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeTableDefined: // KTV: ( )* nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields)) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( "nn, err = encoder.WriteUintN(%d, %s.CN() + 1)\n", len(typ.Fields), tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.iprintf("{\n") n += nn; if err != nil { return n, err } this.push() for key, field := range typ.Fields { nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name) tagVar, nn, err := this.generateTag(field.Type, fieldSource) n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(%s))\n", tagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } nn, err = this.generateEncodeValue(field.Type, fieldSource, tagVar) n += nn; if err != nil { return n, err } } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } case TypeNamed: // WHATEVER: [WHATEVER] nn, err := this.iprintf("nn, err = %s.EncodeValue(encoder, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeAny: // WHATEVER: [WHATEVER] nn, err := this.iprintf("nn, err = tape.EncodeAny(encoder, %s, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } return n, nil } // generateDencodeValue generates code to decode a value of a specified type. It // overwrites memory pointed to by the variable (or parenthetical statement) // specified by valueSource, and the value will be encoded according to the tag // stored in the variable (or parenthetical statement) specified by tagSource. // the code generated is a BLOCK and expects these variables to be defined: // // - decoder *tape.Decoder // - n int // - err error // - nn int // // The typeName paramterer is handled in the way described in the documentation // for [Generator.generateDecodeBranch]. func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagSource string) (n int, err error) { switch typ := typ.(type) { case TypeBool: // SI: (none) // SI stores the value in the tag nn, err := this.iprintf("*%s = %s.CN() > 0\n", valueSource, tagSource) n += nn; if err != nil { return n, err } case TypeInt: // SI: (none) // LI/LSI: if typ.Bits <= 5 { // SI stores the value in the tag if typeName == "" { nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource) n += nn; if err != nil { return n, err } } else { nn, err := this.iprintf("*%s = %s(%s.CN())\n", valueSource, typeName, tagSource) n += nn; if err != nil { return n, err } } break } prefix := "ReadUint" if typ.Signed { prefix = "ReadInt" } destinationVar := this.newTemporaryVar("destination") nn, err := this.iprintf("var %s ", destinationVar) n += nn; if err != nil { return n, err } nn, err = this.generateType(typ) n += nn; if err != nil { return n, err } nn, err = this.print("\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } if typeName == "" { nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar) n += nn; if err != nil { return n, err } } else { nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar) n += nn; if err != nil { return n, err } } case TypeFloat: // FP: destinationVar := this.newTemporaryVar("destination") nn, err := this.iprintf("var %s ", destinationVar) n += nn; if err != nil { return n, err } nn, err = this.generateType(typ) n += nn; if err != nil { return n, err } nn, err = this.print("\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadFloat%d()\n", destinationVar, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } if typeName == "" { nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar) n += nn; if err != nil { return n, err } } else { nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar) n += nn; if err != nil { return n, err } } case TypeString, TypeBuffer: // SBA: * // LBA: * lengthVar := this.newTemporaryVar("length") nn, err := this.iprintf("var %s uint64\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf( "%s, nn, err = decoder.ReadUintN(int(%s.CN()))\n", lengthVar, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("} else {\n") n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf( "%s = uint64(%s.CN())\n", lengthVar, tagSource) n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("buffer := make([]byte, %s)\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } if typeName == "" { if _, ok := typ.(TypeString); ok { nn, err = this.iprintf("*%s = string(buffer)\n", valueSource) n += nn; if err != nil { return n, err } } else { nn, err = this.iprintf("*%s = buffer\n", valueSource) n += nn; if err != nil { return n, err } } } else { nn, err = this.iprintf("*%s = %s(buffer)\n", valueSource, typeName) n += nn; if err != nil { return n, err } } case TypeArray: // OTA: * nn, err := this.generateDecodeBranchCall(typ, typeName, valueSource, tagSource) n += nn; if err != nil { return n, err } case TypeTable: // KTV: ( )* nn, err := this.iprintf( "nn, err = tape.DecodeAnyInto(decoder, %s, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeTableDefined: // KTV: ( )* nn, err := this.generateDecodeBranchCall(typ, typeName, valueSource, tagSource) n += nn; if err != nil { return n, err } case TypeNamed: // WHATEVER: [WHATEVER] nn, err := this.iprintf("nn, err = %s.DecodeValue(decoder, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeAny: // WHATEVER: [WHATEVER] nn, err := this.iprintf("*%s, nn, err = tape.DecodeAny(decoder, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } return n, nil } // generateDecodeBranchCall generates code to call an aggregate decoder function, // for a specified type. The definition of the function is deferred so no // duplicates are created. The function overwrites memory pointed to by the // variable (or parenthetical statement) specified by valueSource, and the value // will be encoded according to the tag stored in the variable (or parenthetical // statement) specified by tagSource. the code generated is a BLOCK and expects // these variables to be defined: // // - decoder *tape.Decoder // - n int // - err error // - nn int // // The typeName paramterer is handled in the way described in the documentation // for [Generator.generateDecodeBranch]. func (this *Generator) generateDecodeBranchCall(typ Type, typeName, valueSource, tagSource string) (n int, err error) { hash := HashType(typ) nn, err := this.iprintf( "nn, err = %s(%s, decoder, %s)\n", this.decodeBranchName(hash, typeName), valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } this.pushDecodeBranchRequest(hash, typ, typeName) return n, nil } // generateDecodeBranch generates an aggregate decoder function definition for a // specified type. It assumes that hash == HashType(typ). If typeName is not // empty, it will be used as the type in the argument list instead of the result // of [Generator.generateType]. func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName string) (n int, err error) { nn, err := this.iprintf("\nfunc %s(this *", this.decodeBranchName(hash, typeName)) n += nn; if err != nil { return n, err } if typeName == "" { nn, err = this.generateType(typ) n += nn; if err != nil { return n, err } } else { nn, err = this.print(typeName) n += nn; if err != nil { return n, err } } nn, err = this.printf(", decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n") n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("var nn int\n") n += nn; if err != nil { return n, err } switch typ := typ.(type) { case TypeArray: // OTA: * // read header lengthVar := this.newTemporaryVar("length") nn, err := this.iprintf("var %s uint64\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()) + 1)\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } elementTagVar := this.newTemporaryVar("elementTag") nn, err = this.iprintf("var %s tape.Tag\n", elementTagVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadTag()\n", elementTagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } // abort macro abort := func() (n int, err error) { // skim entire array nn, err = this.iprintf("for _ = range %s {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("nn, err = tape.Skim(decoder, %s)\n", elementTagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } return n, nil } // validate header // TODO: here, validate that length is less than the // max, whatever that is configured to be. the reason we // want to read it here is that we would have to skip // the tag anyway so why not. nn, err = this.iprintf("if !(") n += nn; if err != nil { return n, err } nn, err = this.generateCanAssign(typ.Element, elementTagVar) n += nn; if err != nil { return n, err } nn, err = this.printf(") {\n") n += nn; if err != nil { return n, err } this.push() nn, err = abort() n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } // decode payloads nn, err = this.iprintf("*this = make(") n += nn; if err != nil { return n, err } nn, err = this.generateType(typ) n += nn; if err != nil { return n, err } nn, err = this.printf(", %s)\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("for index := range int(%s) {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() nn, err = this.generateDecodeValue(typ.Element, "", "(&(*this)[index])", elementTagVar) n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } case TypeTableDefined: // KTV: ( )* // read header lengthVar := this.newTemporaryVar("length") nn, err := this.iprintf("var %s uint64\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf("return n, tape.ErrTooLong\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()) + 1)\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } // validate header // TODO: here, validate that length is less than the // max, whatever that is configured to be. if not, stop // ALL decoding. skimming huge big ass data could cause // problems // read fields nn, err = this.iprintf("for _ = range %s {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() // read field header fieldKeyVar := this.newTemporaryVar("fieldKey") nn, err = this.iprintf("var %s uint16\n", fieldKeyVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadUint16()\n", fieldKeyVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } fieldTagVar := this.newTemporaryVar("fieldTag") nn, err = this.iprintf("var %s tape.Tag\n", fieldTagVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadTag()\n", fieldTagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } // abort field macro abortField := func() (n int, err error) { nn, err = this.iprintf("tape.Skim(decoder, %s)\n", fieldTagVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("continue\n") n += nn; if err != nil { return n, err } return n, nil } // switch on tag nn, err = this.iprintf("switch %s {\n", fieldKeyVar) n += nn; if err != nil { return n, err } for _, key := range slices.Sorted(maps.Keys(typ.Fields)) { field := typ.Fields[key] nn, err = this.iprintf("case 0x%04X:\n", key) n += nn; if err != nil { return n, err } this.push() // validate field header nn, err = this.iprintf("if !(") n += nn; if err != nil { return n, err } nn, err = this.generateCanAssign(field.Type, fieldTagVar) n += nn; if err != nil { return n, err } nn, err = this.printf(") {\n") n += nn; if err != nil { return n, err } this.push() nn, err = abortField() n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } // decode payload nn, err = this.generateDecodeValue( field.Type, "", fmt.Sprintf("(&(this.%s))", field.Name), fieldTagVar) n += nn; if err != nil { return n, err } this.pop() } nn, err = this.iprintf("default:\n") n += nn; if err != nil { return n, err } this.push() abortField() this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } // TODO once options are implemented, have a set of // bools for each non-optional field, and check here // that they are all true. a counter will not work // because if someone specifies a non-optional field // twice, they can neglect to specify another // non-optional field and we won't even know because the // count will still be even. we shouldn't use a map // either because its an allocation and its way more // memory than just, like 5 bools (on the stack no less) default: return n, fmt.Errorf("unexpected type: %T", typ) } nn, err = this.iprintf("return n, nil\n") this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } return n, nil } func (this *Generator) decodeBranchName(hash [16]byte, name string) string { if name == "" { return fmt.Sprintf("decodeBranch_%s", hex.EncodeToString(hash[:])) } else { return fmt.Sprintf("decodeBranch_%s_%s", hex.EncodeToString(hash[:]), name) } } // pushDecodeBranchRequest pushes a new branch decode function request to the // back of the queue, if it is not already in the queue. func (this *Generator) pushDecodeBranchRequest(hash [16]byte, typ Type, name string) { for _, item := range this.decodeBranchRequestQueue { if item.hash == hash && item.name == name { return } } this.decodeBranchRequestQueue = append(this.decodeBranchRequestQueue, decodeBranchRequest { hash: hash, typ: typ, name: name, }) } // pullDecodeBranchRequest pulls a branch decode function request from the front // of the queue. func (this *Generator) pullDecodeBranchRequest() (hash [16]byte, typ Type, name string, ok bool) { if len(this.decodeBranchRequestQueue) < 1 { return [16]byte { }, nil, "", false } request := this.decodeBranchRequestQueue[0] this.decodeBranchRequestQueue = this.decodeBranchRequestQueue[1:] return request.hash, request.typ, request.name, true } func (this *Generator) generateErrorCheck() (n int, err error) { return this.iprintf("n += nn; if err != nil { return n, err }\n") } func (this *Generator) generateBareErrorCheck() (n int, err error) { return this.iprintf("if err != nil { return n, err }\n") } // generateTag generates the preferred TN and CN for the given type and value. // The generated code is a BLOCK. func (this *Generator) generateTag(typ Type, source string) (tagVar string, n int, err error) { tagVar = this.newTemporaryVar("tag") switch typ := typ.(type) { case TypeBool: nn, err := this.iprintf("%s := tape.SI.WithCN(boolInt(bool(%s)))\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } case TypeInt: if typ.Bits <= 5 { nn, err := this.iprintf("%s := tape.SI.WithCN(int(%s))\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } } else if typ.Signed { nn, err := this.iprintf("%s := tape.LSI.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits)) n += nn; if err != nil { return tagVar, n, err } } else { nn, err := this.iprintf("%s := tape.LI.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits)) n += nn; if err != nil { return tagVar, n, err } } case TypeFloat: nn, err := this.iprintf("%s := tape.FP.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits)) n += nn; if err != nil { return tagVar, n, err } case TypeString: nn, err := this.iprintf("%s := tape.StringTag(string(%s))\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } case TypeBuffer: nn, err := this.iprintf("%s := tape.BufferTag([]byte(%s))\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } case TypeArray: nn, err := this.iprintf("%s := tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))) - 1)\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } case TypeTable: nn, err := this.iprintf("%s := tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))) - 1)\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } case TypeTableDefined: nn, err := this.iprintf("%s := tape.KTV.WithCN(%d)\n", tagVar, tape.IntBytes(uint64(len(typ.Fields))) - 1) n += nn; if err != nil { return tagVar, n, err } case TypeNamed: resolved, err := this.resolveTypeName(typ.Name) if err != nil { return tagVar, n, err } subTagVar, nn, err := this.generateTag(resolved, source) n += nn; if err != nil { return tagVar, n, err } tagVar = subTagVar case TypeAny: nn, err := this.iprintf("%s, err := tape.TagAny(%s)\n", tagVar, source) n += nn; if err != nil { return tagVar, n, err } nn, err = this.generateBareErrorCheck() n += nn; if err != nil { return tagVar, n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } return tagVar, n, nil } // generateTN generates the appropriate TN for the given type. The generated // code is INLINE. The generated tag will have a CN as zero. For types that // change TN based on their length, the TN capable of supporting more // information is chosen. func (this *Generator) generateTN(typ Type) (n int, err error) { switch typ := typ.(type) { case TypeBool: nn, err := this.printf("tape.SI") n += nn; if err != nil { return n, err } case TypeInt: if typ.Bits <= 5 { nn, err := this.printf("tape.SI") n += nn; if err != nil { return n, err } } else if typ.Signed { nn, err := this.printf("tape.LSI") n += nn; if err != nil { return n, err } } else { nn, err := this.printf("tape.LI") n += nn; if err != nil { return n, err } } case TypeFloat: nn, err := this.printf("tape.FP",) n += nn; if err != nil { return n, err } case TypeString: nn, err := this.generateTN(TypeBuffer { }) n += nn; if err != nil { return n, err } case TypeBuffer: nn, err := this.printf("tape.LBA") n += nn; if err != nil { return n, err } case TypeArray: nn, err := this.printf("tape.OTA") n += nn; if err != nil { return n, err } case TypeTable: nn, err := this.printf("tape.KTV") n += nn; if err != nil { return n, err } case TypeTableDefined: nn, err := this.printf("tape.KTV") n += nn; if err != nil { return n, err } case TypeNamed: resolved, err := this.resolveTypeName(typ.Name) if err != nil { return n, err } nn, err := this.generateTN(resolved) n += nn; if err != nil { return n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } return n, nil } func (this *Generator) generateType(typ Type) (n int, err error) { switch typ := typ.(type) { case TypeBool: nn, err := this.printf("bool") n += nn; if err != nil { return n, err } case TypeInt: if err := this.validateIntBitSize(typ.Bits); err != nil { return n, err } if typ.Bits <= 5 { nn, err := this.printf("uint8") n += nn; if err != nil { return n, err } break } if typ.Signed { nn, err := this.printf("int%d", typ.Bits) n += nn; if err != nil { return n, err } } else { nn, err := this.printf("uint%d", typ.Bits) n += nn; if err != nil { return n, err } } case TypeFloat: switch typ.Bits { case 16: nn, err := this.print("float32") n += nn; if err != nil { return n, err } case 32, 64: nn, err := this.printf("float%d", typ.Bits) n += nn; if err != nil { return n, err } default: return n, fmt.Errorf("floats of size %d are unsupported on this platform", typ.Bits) } case TypeString: nn, err := this.print("string") n += nn; if err != nil { return n, err } case TypeBuffer: nn, err := this.print("[]byte") n += nn; if err != nil { return n, err } case TypeArray: nn, err := this.print("[]") n += nn; if err != nil { return n, err } nn, err = this.generateType(typ.Element) n += nn; if err != nil { return n, err } case TypeTable: nn, err := this.print("Table") n += nn; if err != nil { return n, err } case TypeTableDefined: nn, err := this.generateTypeTableDefined(typ) n += nn; if err != nil { return n, err } case TypeNamed: nn, err := this.print(typ.Name) n += nn; if err != nil { return n, err } case TypeAny: nn, err := this.print("any") n += nn; if err != nil { return n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } return n, nil } func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, err error) { nn, err := this.print("struct {\n") n += nn; if err != nil { return n, err } this.push() for _, key := range slices.Sorted(maps.Keys(typ.Fields)) { field := typ.Fields[key] nn, err := this.iprintf("%s\n", this.formatComment(field.Doc)) n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s ", field.Name) n += nn; if err != nil { return n, err } nn, err = this.generateType(field.Type) n += nn; if err != nil { return n, err } nn, err = this.print("\n") n += nn; if err != nil { return n, err } } this.pop() nn, err = this.iprint("}") n += nn; if err != nil { return n, err } return n, nil } // generateCanAssign generates an expression which checks if the tag specified // by tagSource can be assigned to a Go destination generated from typ. The // generated code is INLINE. func (this *Generator) generateCanAssign(typ Type, tagSource string) (n int, err error) { if _, ok := typ.(TypeAny); ok { nn, err := this.printf("true") n += nn; if err != nil { return n, err } } else { nn, err := this.printf("canAssign(") n += nn; if err != nil { return n, err } nn, err = this.generateTN(typ) n += nn; if err != nil { return n, err } nn, err = this.printf(", %s)", tagSource) n += nn; if err != nil { return n, err } } return n, nil } func (this *Generator) validateIntBitSize(size int) error { switch size { case 5, 8, 16, 32, 64: return nil default: return fmt.Errorf("integers of size %d are unsupported on this platform", size) } } func (this *Generator) validateFloatBitSize(size int) error { switch size { case 16, 32, 64: return nil default: return fmt.Errorf("floats of size %d are unsupported on this platform", size) } } func (this *Generator) push() { this.nestingLevel ++ } func (this *Generator) pop() { if this.nestingLevel < 1 { panic("cannot pop when nesting level is less than 1") } this.nestingLevel -- } func (this *Generator) indent() string { return strings.Repeat("\t", this.nestingLevel) } func (this *Generator) print(args ...any) (n int, err error) { return fmt.Fprint(this.Output, args...) } func (this *Generator) println(args ...any) (n int, err error) { return fmt.Fprintln(this.Output, args...) } func (this *Generator) printf(format string, args ...any) (n int, err error) { return fmt.Fprintf(this.Output, format, args...) } func (this *Generator) iprint(args ...any) (n int, err error) { return fmt.Fprint(this.Output, this.indent() + fmt.Sprint(args...)) } func (this *Generator) iprintln(args ...any) (n int, err error) { return fmt.Fprintln(this.Output, this.indent() + fmt.Sprint(args...)) } func (this *Generator) iprintf(format string, args ...any) (n int, err error) { return fmt.Fprintf(this.Output, this.indent() + format, args...) } func (this *Generator) formatComment(comment string) string { return "// " + strings.ReplaceAll(comment, "\n", "\n" + this.indent() + "// ") } func (this *Generator) resolveMessageName(message string) string { return "Message" + message } func (this *Generator) resolveTypeName(name string) (Type, error) { if typedef, ok := this.protocol.Types[name]; ok { if typ, ok := typedef.Type.(TypeNamed); ok { return this.resolveTypeName(typ.Name) } return typedef.Type, nil } return nil, fmt.Errorf("no type exists called %s", name) } func (this *Generator) newTemporaryVar(base string) string { this.temporaryVar += 1 return fmt.Sprintf("%s_%d", base, this.temporaryVar) } func bitsToBytes(bits int) int { return int(math.Ceil(float64(bits) / 8.0)) } func bitsToCN(bits int) int { return bitsToBytes(bits) - 1 }