Compare commits

...

5 Commits

2 changed files with 45 additions and 10 deletions

View File

@ -101,6 +101,7 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
// DecodeAnyInto decodes data and places it into destination, which must be a // DecodeAnyInto decodes data and places it into destination, which must be a
// pointer to a supported type. See [EncodeAny] for a list of supported types. // pointer to a supported type. See [EncodeAny] for a list of supported types.
// The head of the decoder must be at the start of the payload.
func DecodeAnyInto(decoder *Decoder, destination any, tag Tag) (n int, err error) { func DecodeAnyInto(decoder *Decoder, destination any, tag Tag) (n int, err error) {
reflectDestination := reflect.ValueOf(destination) reflectDestination := reflect.ValueOf(destination)
if reflectDestination.Kind() != reflect.Pointer { if reflectDestination.Kind() != reflect.Pointer {
@ -110,9 +111,10 @@ func DecodeAnyInto(decoder *Decoder, destination any, tag Tag) (n int, err error
} }
// DecodeAny is like [DecodeAnyInto], but it automatically creates the // DecodeAny is like [DecodeAnyInto], but it automatically creates the
// destination from the tag and data. // destination from the tag and data. The head of the decoder must be at the
// start of the payload.
func DecodeAny(decoder *Decoder, tag Tag) (value any, n int, err error) { func DecodeAny(decoder *Decoder, tag Tag) (value any, n int, err error) {
destination, err := skeletonValue(decoder, tag) destination, err := skeletonPointer(decoder, tag)
if err != nil { return nil, n, err } if err != nil { return nil, n, err }
nn, err := DecodeAnyInto(decoder, destination, tag) nn, err := DecodeAnyInto(decoder, destination, tag)
n += nn; if err != nil { return nil, n, err } n += nn; if err != nil { return nil, n, err }
@ -128,7 +130,8 @@ var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { })
// decodeAny is internal to [DecodeAny]. It takes in an addressable // decodeAny is internal to [DecodeAny]. It takes in an addressable
// [reflect.Value] as the destination. If the decoded value cannot fit in the // [reflect.Value] as the destination. If the decoded value cannot fit in the
// destination, it skims over the payload, leaves the destination empty, and // destination, it skims over the payload, leaves the destination empty, and
// returns without an error. // returns without an error. The head of the decoder must be at the start of the
// payload.
func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
n, err = decodeAnyOrError(decoder, destination, tag) n, err = decodeAnyOrError(decoder, destination, tag)
if _, ok := err.(errCantAssign); ok { if _, ok := err.(errCantAssign); ok {
@ -145,7 +148,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
// destination, it decodes nothing and returns an error of type errCantAssign, // destination, it decodes nothing and returns an error of type errCantAssign,
// except for the case of a mismatched OTA element tag, wherein it will skim // except for the case of a mismatched OTA element tag, wherein it will skim
// over the rest of the payload, leave the destination empty, and return without // over the rest of the payload, leave the destination empty, and return without
// an error. // an error. The head of the decoder must be at the start of the payload.
func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
err = canSet(destination.Type(), tag) err = canSet(destination.Type(), tag)
if err != nil { return n, err } if err != nil { return n, err }
@ -189,6 +192,13 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
setString(destination, string(buffer)) setString(destination, string(buffer))
case OTA: case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>* // OTA: <length: UN> <elementTag: tape.Tag> <values>*
oldDestination := destination
if isTypeAny(destination.Type()) {
// need a skeleton value if we are assigning to any.
value, err := skeletonValue(decoder, tag)
if err != nil { return n, err }
destination = value
}
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) { if length > uint64(MaxStructureLength) {
@ -222,6 +232,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
return n, err return n, err
} }
} }
oldDestination.Set(destination)
case KTV: case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)* // KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
@ -229,13 +240,21 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
if length > uint64(MaxStructureLength) { if length > uint64(MaxStructureLength) {
return 0, ErrTooLong return 0, ErrTooLong
} }
lengthCast, err := Uint64ToIntSafe(length)
if err != nil { return n, err }
if isTypeAny(destination.Type()) {
// need a skeleton value if we are assigning to any.
value := reflect.MakeMapWithSize(reflect.TypeOf(dummyMap), lengthCast)
destination.Set(value)
destination = value
}
destination.Clear() destination.Clear()
for _ = range length { for _ = range lengthCast {
key, nn, err := decoder.ReadUint16() key, nn, err := decoder.ReadUint16()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
itemTag, nn, err := decoder.ReadTag() itemTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
value, err := skeletonValue(decoder, itemTag) value, err := skeletonPointer(decoder, itemTag)
if err != nil { return n, err } if err != nil { return n, err }
nn, err = decodeAny(decoder, value.Elem(), itemTag) nn, err = decodeAny(decoder, value.Elem(), itemTag)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
@ -356,6 +375,7 @@ func canSet(destination reflect.Type, tag Tag) error {
return errCantAssignf("cannot assign float to %v", destination) return errCantAssignf("cannot assign float to %v", destination)
} }
case SBA, LBA: case SBA, LBA:
if destination.Kind() == reflect.String { return nil }
if destination.Kind() != reflect.Slice { if destination.Kind() != reflect.Slice {
return errCantAssignf("cannot assign byte array to %v", destination) return errCantAssignf("cannot assign byte array to %v", destination)
} }
@ -424,6 +444,11 @@ func setByteArray(destination reflect.Value, value []byte) {
destination.Set(reflect.ValueOf(value)) destination.Set(reflect.ValueOf(value))
} }
// setString exepctes a settable destination
func setString(destination reflect.Value, value string) {
destination.Set(reflect.ValueOf(value))
}
// decodeAndSetInt expects a settable destination. // decodeAndSetInt expects a settable destination.
func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadIntN(bytes) value, nn, err := decoder.ReadIntN(bytes)
@ -457,16 +482,26 @@ func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (
return n, errCantAssignf("unsupported bit width float%d", bytes * 8) return n, errCantAssignf("unsupported bit width float%d", bytes * 8)
} }
// skeletonValue returns a pointer value. In order for it to be set, it must be // skeletonValue returns an addressable value. It can be set directly. The head
// dereferenced using Elem(). // of the decoder must be at the start of the payload when calling.
func skeletonValue(decoder *Decoder, tag Tag) (reflect.Value, error) { func skeletonValue(decoder *Decoder, tag Tag) (reflect.Value, error) {
ptr, err := skeletonPointer(decoder, tag)
if err != nil { return reflect.Value { }, err }
return ptr.Elem(), nil
}
// skeletonPointer returns a pointer value. In order for it to be set, it must
// be dereferenced using Elem(). The head of the decoder must be at the start of
// the payload when calling.
func skeletonPointer(decoder *Decoder, tag Tag) (reflect.Value, error) {
typ, err := typeOf(decoder, tag) typ, err := typeOf(decoder, tag)
if err != nil { return reflect.Value { }, err } if err != nil { return reflect.Value { }, err }
return reflect.New(typ), nil return reflect.New(typ), nil
} }
// typeOf returns the type of the current tag being decoded. It does not use up // typeOf returns the type of the current tag being decoded. It does not use up
// the decoder, it only peeks. // the decoder, it only peeks. The head of the decoder must be at the start of
// the payload when calling.
func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) { func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) {
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI: case SI:

View File

@ -208,7 +208,7 @@ func TestEncodeDecodeAnyTable(test *testing.T) {
0x0102: uint32(2), 0x0102: uint32(2),
0x0103: int64(23432), 0x0103: int64(23432),
0x0104: int64(-88777), 0x0104: int64(-88777),
0x0000: []byte("hi!"), 0x0000: "hi!",
0xFFFF: []uint16 { 0xBEE5, 0x7777 }, 0xFFFF: []uint16 { 0xBEE5, 0x7777 },
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} }, 0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
0x1F1F: float32(67.26), 0x1F1F: float32(67.26),