From 81d95dcd90bb9f247bacf35d5297211d393b51b1 Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Wed, 15 Oct 2025 21:00:15 -0400 Subject: [PATCH] generate: Add support for options in static system --- generate/generate.go | 77 ++++++++++++++++++++++++--------------- generate/generate_test.go | 28 ++++++++++++++ 2 files changed, 76 insertions(+), 29 deletions(-) diff --git a/generate/generate.go b/generate/generate.go index 2c64b33..96a0722 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -11,6 +11,7 @@ import "git.tebibyte.media/sashakoshka/hopp/tape" const imports = ` +import "git.tebibyte.media/sashakoshka/hopp" import "git.tebibyte.media/sashakoshka/hopp/tape" ` @@ -56,6 +57,8 @@ func boolInt(input bool) int { return 0 } } + +var _ hopp.Option[int] ` // Generator converts protocols into Go code. @@ -479,19 +482,33 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } this.push() for key, field := range typ.Fields { + fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name) + if field.Option { + nn, err = this.iprintf("if value, ok := %s.Value(); ok {\n", fieldSource) + n += nn; if err != nil { return n, err } + fieldSource = "value" + this.push() + } + nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } - fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name) tagVar, nn, err := this.generateTag(field.Type, fieldSource) n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(%s))\n", tagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } + nn, err = this.generateEncodeValue(field.Type, fieldSource, tagVar) n += nn; if err != nil { return n, err } + + if field.Option { + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + } } this.pop() nn, err = this.iprintf("}\n") @@ -554,13 +571,7 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS prefix = "ReadInt" } destinationVar := this.newTemporaryVar("destination") - nn, err := this.iprintf("var %s ", destinationVar) - n += nn; if err != nil { return n, err } - nn, err = this.generateType(typ) - n += nn; if err != nil { return n, err } - nn, err = this.print("\n") - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits) + nn, err := this.iprintf("%s, nn, err := decoder.%s%d()\n", destinationVar, prefix, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } @@ -574,13 +585,7 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS case TypeFloat: // FP: destinationVar := this.newTemporaryVar("destination") - nn, err := this.iprintf("var %s ", destinationVar) - n += nn; if err != nil { return n, err } - nn, err = this.generateType(typ) - n += nn; if err != nil { return n, err } - nn, err = this.print("\n") - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("%s, nn, err = decoder.ReadFloat%d()\n", destinationVar, typ.Bits) + nn, err := this.iprintf("%s, nn, err := decoder.ReadFloat%d()\n", destinationVar, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } @@ -803,6 +808,7 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } case TypeTableDefined: + // TODO: options // KTV: ( )* // read header lengthVar := this.newTemporaryVar("length") @@ -879,10 +885,25 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st n += nn; if err != nil { return n, err } // decode payload - nn, err = this.generateDecodeValue( - field.Type, "", - fmt.Sprintf("(&(this.%s))", field.Name), fieldTagVar) - n += nn; if err != nil { return n, err } + if field.Option { + destination := this.newTemporaryVar("destination") + nn, err = this.iprintf("var %s ", destination) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(field.Type) + n += nn; if err != nil { return n, err } + nn, err = this.printf("\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateDecodeValue( + field.Type, "", fmt.Sprintf("(&%s)", destination), fieldTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("this.%s = hopp.O(%s)\n", field.Name, destination) + n += nn; if err != nil { return n, err } + } else { + nn, err = this.generateDecodeValue( + field.Type, "", + fmt.Sprintf("(&(this.%s))", field.Name), fieldTagVar) + n += nn; if err != nil { return n, err } + } this.pop() } nn, err = this.iprintf("default:\n") @@ -895,16 +916,6 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } - - // TODO once options are implemented, have a set of - // bools for each non-optional field, and check here - // that they are all true. a counter will not work - // because if someone specifies a non-optional field - // twice, they can neglect to specify another - // non-optional field and we won't even know because the - // count will still be even. we shouldn't use a map - // either because its an allocation and its way more - // memory than just, like 5 bools (on the stack no less) default: return n, fmt.Errorf("unexpected type: %T", typ) } @@ -1133,8 +1144,16 @@ func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, er n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s ", field.Name) n += nn; if err != nil { return n, err } + if field.Option { + nn, err = this.print("hopp.Option[") + n += nn; if err != nil { return n, err } + } nn, err = this.generateType(field.Type) n += nn; if err != nil { return n, err } + if field.Option { + nn, err = this.print("]") + n += nn; if err != nil { return n, err } + } nn, err = this.print("\n") n += nn; if err != nil { return n, err } } diff --git a/generate/generate_test.go b/generate/generate_test.go index 25e3384..97b9628 100644 --- a/generate/generate_test.go +++ b/generate/generate_test.go @@ -84,6 +84,34 @@ func init() { }, }, } + exampleProtocol.Messages[0x0006] = Message { + Name: "Option", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "OU5", Option: true, Type: TypeInt { Bits: 5 } }, + 0x0001: Field { Name: "OU8", Option: true, Type: TypeInt { Bits: 8 } }, + 0x0002: Field { Name: "OU16", Option: true, Type: TypeInt { Bits: 16 } }, + 0x0003: Field { Name: "OU32", Option: true, Type: TypeInt { Bits: 32 } }, + 0x0004: Field { Name: "OU64", Option: true, Type: TypeInt { Bits: 64 } }, + 0x0005: Field { Name: "OI8", Option: true, Type: TypeInt { Bits: 8, Signed: true } }, + 0x0006: Field { Name: "OI16", Option: true, Type: TypeInt { Bits: 16, Signed: true } }, + 0x0007: Field { Name: "OI32", Option: true, Type: TypeInt { Bits: 32, Signed: true } }, + 0x0008: Field { Name: "OI64", Option: true, Type: TypeInt { Bits: 64, Signed: true } }, + 0x0009: Field { Name: "OF32", Option: true, Type: TypeFloat { Bits: 32 } }, + 0x000A: Field { Name: "OF64", Option: true, Type: TypeFloat { Bits: 64 } }, + 0x000B: Field { Name: "OBool", Option: true, Type: TypeBool { } }, + 0x000C: Field { Name: "OString", Option: true, Type: TypeString { } }, + 0x000D: Field { Name: "OArray", Option: true, Type: TypeArray { + Element: TypeInt { Bits: 16, Signed: true } } }, + 0x000E: Field { Name: "OTable", Option: true, Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0001: Field { Name: "A", Type: TypeInt { Bits: 8, Signed: true }, }, + 0x0002: Field { Name: "B", Type: TypeFloat { Bits: 64 }, }, + } } }, + 0x000F: Field { Name: "T0", Option: true, Type: TypeTable { } }, + }, + }, + } exampleProtocol.Types["User"] = Typedef { Type: TypeTableDefined { Fields: map[uint16] Field {