tape: Send reflect values where possible instead of pointers

This commit is contained in:
Sasha Koshka 2025-06-27 19:05:17 -04:00
parent 2a4e88d949
commit 3eb826735b

View File

@ -63,53 +63,58 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
// DecodeAny decodes data and places it into destination, which must be a // DecodeAny decodes data and places it into destination, which must be a
// pointer to a supported type. See [EncodeAny] for a list of supported types. // pointer to a supported type. See [EncodeAny] for a list of supported types.
func DecodeAny(decoder *Decoder, destination any, tag Tag) (n int, err error) { func DecodeAny(decoder *Decoder, destination any, tag Tag) (n int, err error) {
return decodeAny(decoder, reflect.ValueOf(destination), tag) reflectDestination := reflect.ValueOf(destination)
if reflectDestination.Kind() != reflect.Pointer {
return n, fmt.Errorf("expected pointer destination, not %v", destination)
}
return decodeAny(decoder, reflectDestination.Elem(), tag)
} }
// unknownSlicePlaceholder is inserted by skeletonValue and informs the program
// that the destination for the slice needs to be generated based on the item
// tag in the OTA.
type unknownSlicePlaceholder struct { }
var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { })
// decodeAny is internal to [DecodeAny]. It takes in an addressable
// [reflect.Value] as the destination.
func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
errWrongDestinationType := func(expected string) error { errWrongDestinationType := func(expected string) error {
// panic(fmt.Errorf( panic(fmt.Errorf(
return fmt.Errorf( // return fmt.Errorf(
"expected %s destination, not %v", "expected %s destination, not %v",
expected, destination) expected, destination))
//)
}
if destination.Kind() != reflect.Pointer {
return n, errWrongDestinationType("pointer")
} }
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI: case SI:
// SI: (none) // SI: (none)
err = setIntPtr(destination, uint64(tag.CN())) err = setInt(destination, uint64(tag.CN()))
if err != nil { return n, err } if err != nil { return n, err }
case LI: case LI:
// LI: <value: IntN> // LI: <value: IntN>
nn, err := decodeAndSetIntPtr(decoder, destination, tag.CN() + 1) nn, err := decodeAndSetInt(decoder, destination, tag.CN() + 1)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case FP: case FP:
// FP: <value: FloatN> // FP: <value: FloatN>
nn, err := decodeAndSetFloatPtr(decoder, destination, tag.CN() + 1) nn, err := decodeAndSetFloat(decoder, destination, tag.CN() + 1)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case SBA: case SBA:
// SBA: <data: U8>* // SBA: <data: U8>*
destination, err := asByteArrayPtr(destination)
if err != nil { return n, err }
buffer := make([]byte, tag.CN()) buffer := make([]byte, tag.CN())
nn, err := decoder.Read(buffer) nn, err := decoder.Read(buffer)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
*destination = buffer err = setByteArray(destination, buffer)
if err != nil { return n, err }
case LBA: case LBA:
// LBA: <length: UN> <data: U8>* // LBA: <length: UN> <data: U8>*
destination, err := asByteArrayPtr(destination)
if err != nil { return n, err }
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
buffer := make([]byte, length) buffer := make([]byte, length)
nn, err = decoder.Read(buffer) nn, err = decoder.Read(buffer)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
*destination = buffer err = setByteArray(destination, buffer)
if err != nil { return n, err }
case OTA: case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>* // OTA: <length: UN> <elementTag: tape.Tag> <values>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
@ -118,13 +123,13 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
var slice reflect.Value var slice reflect.Value
needSet := false needSet := false
elem := destination.Elem() if destination.Kind() == reflect.Struct && destination.Type() == unknownSlicePlaceholderType {
if elem.Kind() == reflect.Struct && elem.Type().Name() == "unknownSlicePlaceholder" {
needSet = true needSet = true
slice, err = skeletonValueSlice(oneTag, int(length)) slice, err = skeletonValueSlice(oneTag, int(length))
if err != nil { return n, err } if err != nil { return n, err }
slice = slice.Elem()
} else { } else {
slice = elem slice = destination
if slice.Kind() != reflect.Slice { if slice.Kind() != reflect.Slice {
return n, errWrongDestinationType("slice") return n, errWrongDestinationType("slice")
} }
@ -135,11 +140,11 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
} }
if needSet { if needSet {
destination.Elem().Set(slice) destination.Set(slice)
} }
case KTV: case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)* // KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
table := destination.Elem() table := destination
var dummyMap map[uint16] any var dummyMap map[uint16] any
if table.Type() != reflect.TypeOf(dummyMap) { if table.Type() != reflect.TypeOf(dummyMap) {
return n, errWrongDestinationType("map[uint16] any") return n, errWrongDestinationType("map[uint16] any")
@ -154,7 +159,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
value, err := skeletonValue(itemTag) value, err := skeletonValue(itemTag)
if err != nil { return n, err } if err != nil { return n, err }
nn, err = decodeAny(decoder, value, itemTag) nn, err = decodeAny(decoder, value.Elem(), itemTag)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
table.SetMapIndex(reflect.ValueOf(key), value) table.SetMapIndex(reflect.ValueOf(key), value)
} }
@ -240,58 +245,65 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
return n, nil return n, nil
} }
func setIntPtr(destination reflect.Value, value uint64) error { // setInt expects a settable destination.
elem := destination.Elem() func setInt(destination reflect.Value, value uint64) error {
if !elem.CanInt() { switch {
return fmt.Errorf("cannot assign integer to %T", elem.Interface()) case destination.CanInt():
destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type()))
case destination.CanUint():
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
default:
return fmt.Errorf("cannot assign integer to %T", destination.Interface())
} }
elem.Set(reflect.ValueOf(value).Convert(elem.Type()))
return nil return nil
} }
func setFloatPtr(destination reflect.Value, value float64) error { // setInt expects a settable destination.
elem := destination.Elem() func setFloat(destination reflect.Value, value float64) error {
if !elem.CanFloat() { if !destination.CanFloat() {
return fmt.Errorf("cannot assign float to %T", elem.Interface()) return fmt.Errorf("cannot assign float to %T", destination.Interface())
} }
elem.Set(reflect.ValueOf(value).Convert(elem.Type())) destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
return nil return nil
} }
func decodeAndSetIntPtr(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { // setByteArrayexpects a settable destination.
func setByteArray(destination reflect.Value, value []byte) error {
typ := destination.Type()
if typ.Kind() != reflect.Slice {
return fmt.Errorf("cannot assign %T to ", value)
}
if typ.Elem() != reflect.TypeOf(byte(0)) {
return fmt.Errorf("cannot convert %T to *[]byte", value)
}
destination.Set(reflect.ValueOf(value))
return nil
}
// decodeAndSetInt expects a settable destination.
func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadUintN(bytes) value, nn, err := decoder.ReadUintN(bytes)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
return n, setIntPtr(destination, value) return n, setInt(destination, value)
} }
func decodeAndSetFloatPtr(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { // decodeAndSetInt expects a settable destination.
func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
switch bytes { switch bytes {
case 8: case 8:
value, nn, err := decoder.ReadFloat64() value, nn, err := decoder.ReadFloat64()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
return n, setFloatPtr(destination, float64(value)) return n, setFloat(destination, float64(value))
case 4: case 4:
value, nn, err := decoder.ReadFloat32() value, nn, err := decoder.ReadFloat32()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
return n, setFloatPtr(destination, float64(value)) return n, setFloat(destination, float64(value))
} }
return n, fmt.Errorf("cannot decode float%d", bytes * 8) return n, fmt.Errorf("cannot decode float%d", bytes * 8)
} }
func asByteArrayPtr(value reflect.Value) (*[]byte, error) { // skeletonValue returns a pointer value. In order for it to be set, it must be
typ := value.Type() // dereferenced using Elem().
if typ.Kind() != reflect.Pointer {
return nil, fmt.Errorf("cannot convert %T to pointer", value)
}
if typ.Elem().Kind() != reflect.Slice {
return nil, fmt.Errorf("cannot convert %T to slice pointer", value)
}
if typ.Elem().Elem() != reflect.TypeOf(byte(0)) {
return nil, fmt.Errorf("cannot convert %T to *[]byte", value)
}
return value.Convert(reflect.PtrTo(reflect.SliceOf(reflect.TypeOf(byte(0))))).Interface().(*[]byte), nil
}
func skeletonValue(tag Tag) (reflect.Value, error) { func skeletonValue(tag Tag) (reflect.Value, error) {
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI: case SI:
@ -319,6 +331,8 @@ func skeletonValue(tag Tag) (reflect.Value, error) {
return reflect.Value { }, fmt.Errorf("unknown TN %d", tag.TN()) return reflect.Value { }, fmt.Errorf("unknown TN %d", tag.TN())
} }
// skeletonValueSlice returns a pointer value. In order for it to be set, it
// must be dereferenced using Elem().
func skeletonValueSlice(tag Tag, length int) (reflect.Value, error) { func skeletonValueSlice(tag Tag, length int) (reflect.Value, error) {
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI: case SI:
@ -345,8 +359,3 @@ func skeletonValueSlice(tag Tag, length int) (reflect.Value, error) {
} }
return reflect.Value { }, fmt.Errorf("unknown TN %d", tag.TN()) return reflect.Value { }, fmt.Errorf("unknown TN %d", tag.TN())
} }
// unknownSlicePlaceholder is inserted by skeletonValue and informs the program
// that the destination for the slice needs to be generated based on the item
// tag in the OTA.
type unknownSlicePlaceholder struct { }