From 92040a1bc44ace59e37190cafd74828ca3baa65c Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Wed, 10 Sep 2025 09:45:25 -0400 Subject: [PATCH] generate: Implement encoding and decoding of Any type --- generate/generate.go | 127 +++++++++++++++++++++++++------------------ 1 file changed, 74 insertions(+), 53 deletions(-) diff --git a/generate/generate.go b/generate/generate.go index 08eb8c0..520171f 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -235,17 +235,13 @@ func (this *Generator) generateMessage(method uint16, message Message) (n int, e this.resolveMessageName(message.Name)) n += nn; if err != nil { return n, err } this.push() - nn, err = this.iprintf("tag := ") + tagVar, nn, err := this.generateTag(message.Type, "(*this)") n += nn; if err != nil { return n, err } - nn, err = this.generateTag(message.Type, "(*this)") - n += nn; if err != nil { return n, err } - nn, err = this.println() - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("nn, err := encoder.WriteTag(tag)\n") + nn, err = this.iprintf("nn, err := encoder.WriteTag(%s)\n", tagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } - nn, err = this.generateEncodeValue(message.Type, "(*this)", "tag") + nn, err = this.generateEncodeValue(message.Type, "(*this)", tagVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") n += nn; if err != nil { return n, err } @@ -398,15 +394,11 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri this.push() nn, err = this.iprintf("_ = item\n") n += nn; if err != nil { return n, err } - nn, err = this.iprintf("tag := ") + tagVar, nn, err := this.generateTag(typ.Element, "item") n += nn; if err != nil { return n, err } - nn, err = this.generateTag(typ.Element, "item") + nn, err = this.iprintf("if %s.Is(tape.SBA) { continue }\n", tagVar) n += nn; if err != nil { return n, err } - nn, err = this.println() - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("if tag.Is(tape.SBA) { continue }\n") - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("if tag.CN() > itemTag.CN() { itemTag = tag }\n") + nn, err = this.iprintf("if %s.CN() > itemTag.CN() { itemTag = %s }\n", tagVar, tagVar) n += nn; if err != nil { return n, err } this.pop() nn, err = this.iprintf("}\n") @@ -461,25 +453,19 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri nn, err = this.iprintf("{\n") n += nn; if err != nil { return n, err } this.push() - nn, err = this.iprintf("var tag tape.Tag\n") - n += nn; if err != nil { return n, err } for key, field := range typ.Fields { nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } - nn, err = this.iprintf("tag = ") - n += nn; if err != nil { return n, err } fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name) - nn, err = this.generateTag(field.Type, fieldSource) + tagVar, nn, err := this.generateTag(field.Type, fieldSource) n += nn; if err != nil { return n, err } - nn, err = this.println() - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(tag))\n") + nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(%s))\n", tagVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } - nn, err = this.generateEncodeValue(field.Type, fieldSource, "tag") + nn, err = this.generateEncodeValue(field.Type, fieldSource, tagVar) n += nn; if err != nil { return n, err } } this.pop() @@ -491,6 +477,12 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } + case TypeAny: + // WHATEVER: [WHATEVER] + nn, err := this.iprintf("nn, err = tape.EncodeAny(encoder, %s, %s)\n", valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } @@ -642,6 +634,12 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } + case TypeAny: + // WHATEVER: [WHATEVER] + nn, err := this.iprintf("*%s, nn, err = tape.DecodeAny(decoder, %s)\n", valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } @@ -924,49 +922,60 @@ func (this *Generator) generateErrorCheck() (n int, err error) { return this.iprintf("n += nn; if err != nil { return n, err }\n") } +func (this *Generator) generateBareErrorCheck() (n int, err error) { + return this.iprintf("if err != nil { return n, err }\n") +} + // generateTag generates the preferred TN and CN for the given type and value. -// The generated code is INLINE. -func (this *Generator) generateTag(typ Type, source string) (n int, err error) { +// The generated code is a BLOCK. +func (this *Generator) generateTag(typ Type, source string) (tagVar string, n int, err error) { + tagVar = this.newTemporaryVar("tag") switch typ := typ.(type) { case TypeInt: if typ.Bits <= 5 { - nn, err := this.printf("tape.SI.WithCN(int(%s))", source) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.SI.WithCN(int(%s))\n", tagVar, source) + n += nn; if err != nil { return tagVar, n, err } } else if typ.Signed { - nn, err := this.printf("tape.LSI.WithCN(%d)", bitsToCN(typ.Bits)) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.LSI.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits)) + n += nn; if err != nil { return tagVar, n, err } } else { - nn, err := this.printf("tape.LI.WithCN(%d)", bitsToCN(typ.Bits)) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.LI.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits)) + n += nn; if err != nil { return tagVar, n, err } } case TypeFloat: - nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits)) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.FP.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits)) + n += nn; if err != nil { return tagVar, n, err } case TypeString: - nn, err := this.printf("tape.StringTag(string(%s))", source) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.StringTag(string(%s))\n", tagVar, source) + n += nn; if err != nil { return tagVar, n, err } case TypeBuffer: - nn, err := this.printf("tape.BufferTag([]byte(%s))", source) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.BufferTag([]byte(%s))\n", tagVar, source) + n += nn; if err != nil { return tagVar, n, err } case TypeArray: - nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))\n", tagVar, source) + n += nn; if err != nil { return tagVar, n, err } case TypeTable: - nn, err := this.printf("tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))))", source) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))))\n", tagVar, source) + n += nn; if err != nil { return tagVar, n, err } case TypeTableDefined: - nn, err := this.printf("tape.KTV.WithCN(%d)", tape.IntBytes(uint64(len(typ.Fields)))) - n += nn; if err != nil { return n, err } + nn, err := this.iprintf("%s := tape.KTV.WithCN(%d)\n", tagVar, tape.IntBytes(uint64(len(typ.Fields)))) + n += nn; if err != nil { return tagVar, n, err } case TypeNamed: resolved, err := this.resolveTypeName(typ.Name) - if err != nil { return n, err } - nn, err := this.generateTag(resolved, source) - n += nn; if err != nil { return n, err } + if err != nil { return tagVar, n, err } + subTagVar, nn, err := this.generateTag(resolved, source) + n += nn; if err != nil { return tagVar, n, err } + tagVar = subTagVar + case TypeAny: + nn, err := this.iprintf("%s, err := tape.TagAny(%s)\n", tagVar, source) + n += nn; if err != nil { return tagVar, n, err } + nn, err = this.generateBareErrorCheck() + n += nn; if err != nil { return tagVar, n, err } default: panic(fmt.Errorf("unknown type: %T", typ)) } - return n, nil + return tagVar, n, nil } // generateTN generates the appropriate TN for the given type. The generated @@ -1009,6 +1018,8 @@ func (this *Generator) generateTN(typ Type) (n int, err error) { if err != nil { return n, err } nn, err := this.generateTN(resolved) n += nn; if err != nil { return n, err } + default: + panic(fmt.Errorf("unknown type: %T", typ)) } return n, nil @@ -1063,6 +1074,11 @@ func (this *Generator) generateType(typ Type) (n int, err error) { case TypeNamed: nn, err := this.print(typ.Name) n += nn; if err != nil { return n, err } + case TypeAny: + nn, err := this.print("any") + n += nn; if err != nil { return n, err } + default: + panic(fmt.Errorf("unknown type: %T", typ)) } return n, nil } @@ -1092,12 +1108,17 @@ func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, er // by tagSource can be assigned to a Go destination generated from typ. The // generated code is INLINE. func (this *Generator) generateCanAssign(typ Type, tagSource string) (n int, err error) { - nn, err := this.printf("canAssign(") - n += nn; if err != nil { return n, err } - nn, err = this.generateTN(typ) - n += nn; if err != nil { return n, err } - nn, err = this.printf(", %s)", tagSource) - n += nn; if err != nil { return n, err } + if _, ok := typ.(TypeAny); ok { + nn, err := this.printf("true") + n += nn; if err != nil { return n, err } + } else { + nn, err := this.printf("canAssign(") + n += nn; if err != nil { return n, err } + nn, err = this.generateTN(typ) + n += nn; if err != nil { return n, err } + nn, err = this.printf(", %s)", tagSource) + n += nn; if err != nil { return n, err } + } return n, nil }