generate: Cast strings and buffers when decoding

This commit is contained in:
Sasha Koshka 2025-08-27 14:55:10 -04:00
parent de6099fadc
commit 77bfc45fea

View File

@ -513,6 +513,8 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits) nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if typeName == "" { if typeName == "" {
nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar) nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
@ -520,14 +522,22 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar) nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
} }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
case TypeFloat: case TypeFloat:
// FP: <value: FloatN> // FP: <value: FloatN>
nn, err := this.iprintf("*%s, nn, err = decoder.ReadFloat%d()\n", valueSource, typ.Bits) destinationVar := this.newTemporaryVar("destination")
nn, err := this.iprintf("var %s ", destinationVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadFloat%d()\n", destinationVar, typ.Bits)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
if typeName == "" {
nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
n += nn; if err != nil { return n, err }
}
case TypeString, TypeBuffer: case TypeString, TypeBuffer:
// SBA: <data: U8>* // SBA: <data: U8>*
// LBA: <length: UN> <data: U8>* // LBA: <length: UN> <data: U8>*
@ -560,11 +570,16 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
if _, ok := typ.(TypeString); ok { if typeName == "" {
nn, err = this.iprintf("*%s = string(buffer)\n", valueSource) if _, ok := typ.(TypeString); ok {
n += nn; if err != nil { return n, err } nn, err = this.iprintf("*%s = string(buffer)\n", valueSource)
n += nn; if err != nil { return n, err }
} else {
nn, err = this.iprintf("*%s = buffer\n", valueSource)
n += nn; if err != nil { return n, err }
}
} else { } else {
nn, err = this.iprintf("*%s = buffer\n", valueSource) nn, err = this.iprintf("*%s = %s(buffer)\n", valueSource, typeName)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
} }
case TypeArray: case TypeArray:
@ -876,10 +891,10 @@ func (this *Generator) generateTag(typ Type, source string) (n int, err error) {
nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits)) nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case TypeString: case TypeString:
nn, err := this.printf("tape.StringTag(%s)", source) nn, err := this.printf("tape.StringTag(string(%s))", source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case TypeBuffer: case TypeBuffer:
nn, err := this.printf("tape.BufferTag(%s)", source) nn, err := this.printf("tape.BufferTag([]byte(%s))", source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case TypeArray: case TypeArray:
nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source) nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source)