From b826cbf83ea2a2058eb059bba0bca4791de07a36 Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Thu, 17 Jul 2025 10:26:31 -0400 Subject: [PATCH] generate: The test now compiles --- generate/generate.go | 91 +++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/generate/generate.go b/generate/generate.go index 3180957..2125ed5 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -44,6 +44,7 @@ type Generator struct { PackageName string nestingLevel int + temporaryVar int protocol *Protocol } @@ -143,6 +144,8 @@ func (this *Generator) generateTypedef(name string, typ Type) (n int, err error) name) n += nn; if err != nil { return n, err } this.push() + nn, err = this.iprintf("var nn int\n") + n += nn; if err != nil { return n, err } nn, err = this.generateDecodeValue(typ, "this", "tag") n += nn; if err != nil { return n, err } nn, err = this.iprintf("return n, nil\n") @@ -206,7 +209,7 @@ func (this *Generator) generateMessage(method uint16, message Message) (n int, e nn, err = this.iprintf("\n// Decode decodes this message's tag and value.\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( - "func(this *%s) Decode(encoder *tape.Decoder) (n int, err error) {\n", + "func(this *%s) Decode(decoder *tape.Decoder) (n int, err error) {\n", this.resolveMessageName(message.Name)) n += nn; if err != nil { return n, err } this.push() @@ -412,13 +415,13 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri if typ.Signed { prefix = "ReadInt" } - nn, err := this.iprintf("*%s, nn, err = encoder.%s%d()\n", valueSource, prefix, typ.Bits) + nn, err := this.iprintf("*%s, nn, err = decoder.%s%d()\n", valueSource, prefix, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeFloat: // FP: - nn, err := this.iprintf("%s, nn, err = encoder.ReadFloat%d(%s)\n", valueSource, typ.Bits) + nn, err := this.iprintf("%s, nn, err = decoder.ReadFloat%d(%s)\n", valueSource, typ.Bits) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } @@ -434,7 +437,7 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf( - "length, nn, err = encoder.ReadUintN(uint64(%s.CN()))\n", + "length, nn, err = decoder.ReadUintN(int(%s.CN()))\n", tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() @@ -450,7 +453,7 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } nn, err = this.iprintf("buffer := make([]byte, int(length))\n") n += nn; if err != nil { return n, err } - nn, err = this.iprintf("nn, err = encoder.Read(buffer)\n") + nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } @@ -472,7 +475,7 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri nn, err = this.iprintf("var length uint64\n") n += nn; if err != nil { return n, err } nn, err = this.iprintf( - "length, nn, err = encoder.ReadUintN(uint64(%s.CN()))\n", + "length, nn, err = decoder.ReadUintN(int(%s.CN()))\n", tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() @@ -485,7 +488,7 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } nn, err = this.iprintf("var itemTag tape.Tag\n") n += nn; if err != nil { return n, err } - nn, err = this.iprintf("itemTag, nn, err = encoder.ReadTag()\n") + nn, err = this.iprintf("itemTag, nn, err = decoder.ReadTag()\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } @@ -513,51 +516,46 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } case TypeTableDefined: // KTV: ( )* - nn, err := this.iprintf("{\n") + lengthVar := this.newTemporaryVar("length") + nn, err := this.iprintf("var %s uint64\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "%s, nn, err = decoder.ReadUintN(int(%s.CN()))\n", + lengthVar, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("for _ = range %s {\n", lengthVar) n += nn; if err != nil { return n, err } this.push() - nn, err = this.iprintf("var length uint64\n") + nn, err = this.iprintf("var key uint16\n") n += nn; if err != nil { return n, err } - nn, err = this.iprintf( - "length, nn, err = encoder.ReadUintN(uint64(%s.CN()))\n", - tagSource) + nn, err = this.iprintf("key, nn, err = decoder.ReadUint16()\n") n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } - nn, err = this.iprintf("for _ = range length {\n") + nn, err = this.iprintf("var itemTag tape.Tag\n") n += nn; if err != nil { return n, err } - this.push() - nn, err = this.iprintf("var key uint16\n") + nn, err = this.iprintf("itemTag, nn, err = decoder.ReadTag()\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("switch key {\n") + n += nn; if err != nil { return n, err } + keys := slices.Collect(maps.Keys(typ.Fields)) + slices.Sort(keys) + for _, key := range keys { + field := typ.Fields[key] + nn, err = this.iprintf("case 0x%04X:\n", key) n += nn; if err != nil { return n, err } - nn, err = this.iprintf("key, nn, err = decoder.ReadUint16()\n") - n += nn; if err != nil { return n, err } - nn, err = this.generateErrorCheck() - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("var itemTag tape.Tag\n") - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("itemTag, nn, err = decoder.ReadTag()\n") - n += nn; if err != nil { return n, err } - nn, err = this.generateErrorCheck() - n += nn; if err != nil { return n, err } - nn, err = this.iprintf("switch key {\n") - n += nn; if err != nil { return n, err } - keys := slices.Collect(maps.Keys(typ.Fields)) - slices.Sort(keys) - for _, key := range keys { - field := typ.Fields[key] - nn, err = this.iprintf("case 0x%04X:\n", key) + this.push() + nn, err = this.generateDecodeValue( + field.Type, + fmt.Sprintf("(&%s.%s)", valueSource, field.Name), + "itemTag") n += nn; if err != nil { return n, err } - this.push() - nn, err = this.generateDecodeValue( - field.Type, - fmt.Sprintf("(&%s.%s)", valueSource, field.Name), - "itemTag") - n += nn; if err != nil { return n, err } - this.pop() - } - nn, err = this.iprintf("}\n") - n += nn; if err != nil { return n, err } - this.pop() + this.pop() + } nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } this.pop() @@ -565,7 +563,7 @@ func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } case TypeNamed: // WHATEVER: [WHATEVER] - nn, err := this.iprintf("nn, err = %s.DecodeValue(encoder, %s)\n", valueSource, tagSource) + nn, err := this.iprintf("nn, err = %s.DecodeValue(decoder, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } @@ -801,6 +799,11 @@ func (this *Generator) resolveTypeName(name string) (Type, error) { return nil, fmt.Errorf("no type exists called %s", name) } +func (this *Generator) newTemporaryVar(base string) string { + this.temporaryVar += 1 + return fmt.Sprintf("%s_%d", base, this.temporaryVar) +} + func bitsToBytes(bits int) int { return int(math.Ceil(float64(bits) / 8.0)) }