diff --git a/tape/dynamic.go b/tape/dynamic.go index 44c1a4b..5b90ff9 100644 --- a/tape/dynamic.go +++ b/tape/dynamic.go @@ -14,6 +14,14 @@ import "reflect" var dummyMap map[uint16] any var dummyBuffer []byte +type errCantAssign string +func (err errCantAssign) Error() string { + return string(err) +} +func errCantAssignf(format string, v ...any) errCantAssign { + return errCantAssign(fmt.Sprintf(format, v...)) +} + // EncodeAny encodes an "any" value. Returns an error if the underlying type is // unsupported. Supported types are: // @@ -83,20 +91,34 @@ type unknownSlicePlaceholder struct { } var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { }) // decodeAny is internal to [DecodeAny]. It takes in an addressable -// [reflect.Value] as the destination. +// [reflect.Value] as the destination. If the decoded value cannot fit in the +// destination, it skims over the payload, leaves the destination empty, and +// returns without an error. func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { - errWrongDestinationType := func(expected string) error { - panic(fmt.Errorf( - // return fmt.Errorf( - "expected %s destination, not %v", - expected, destination)) + n, err = decodeAnyOrError(decoder, destination, tag) + if _, ok := err.(errCantAssign); ok { + if n > 0 { panic(fmt.Sprintf("decodeAnyOrError decoded more than it should: %d", n)) } + nn, err := Skim(decoder, tag) + n += nn; if err != nil { return n, err } + return n, nil } + return n, err +} + +// decodeAnyOrError is internal to [decodeAny]. It takes in an addressable +// [reflect.Value] as the destination. If the decoded value cannot fit in the +// destination, it decodes nothing and returns an error of type errCantAssign, +// except for the case of a mismatched OTA element tag, wherein it will skim +// over the rest of the payload, leave the destination empty, and return without +// an error. +func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { + err = canSet(destination.Type(), tag) + if err != nil { return n, err } switch tag.WithoutCN() { case SI: // SI: (none) - err = setInt(destination, uint64(tag.CN())) - if err != nil { return n, err } + setInt(destination, uint64(tag.CN())) case LI: // LI: nn, err := decodeAndSetUint(decoder, destination, tag.CN() + 1) @@ -114,8 +136,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err buffer := make([]byte, tag.CN()) nn, err := decoder.Read(buffer) n += nn; if err != nil { return n, err } - err = setByteArray(destination, buffer) - if err != nil { return n, err } + setByteArray(destination, buffer) case LBA: // LBA: * length, nn, err := decoder.ReadUintN(tag.CN() + 1) @@ -123,34 +144,42 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err buffer := make([]byte, length) nn, err = decoder.Read(buffer) n += nn; if err != nil { return n, err } - err = setByteArray(destination, buffer) - if err != nil { return n, err } + setByteArray(destination, buffer) case OTA: // OTA: * length, nn, err := decoder.ReadUintN(tag.CN() + 1) n += nn; if err != nil { return n, err } oneTag, nn, err := decoder.ReadTag() n += nn; if err != nil { return n, err } - if destination.Kind() != reflect.Slice { - return n, errWrongDestinationType("slice") - } if destination.Cap() < int(length) { destination.Grow(int(length) - destination.Cap()) } + // skip the rest of the array if the one tag doesn't + // match up with the destination + err = canSet(destination.Type().Elem(), oneTag) + if _, ok := err.(errCantAssign); ok { + for _ = range length { + nn, err := Skim(decoder, oneTag) + n += nn; if err != nil { return n, err } + } + break + } + if err != nil { return n, err } destination.SetLen(int(length)) for index := range length { nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag) - n += nn; if err != nil { return n, err } + n += nn + if _, ok := err.(errCantAssign); ok { + continue + } else if err != nil { + return n, err + } } case KTV: // KTV: ( )* - table := destination - if table.Type() != reflect.TypeOf(dummyMap) { - return n, errWrongDestinationType("map[uint16] any") - } length, nn, err := decoder.ReadUintN(tag.CN() + 1) n += nn; if err != nil { return n, err } - table.Clear() + destination.Clear() for _ = range length { key, nn, err := decoder.ReadUint16() n += nn; if err != nil { return n, err } @@ -160,7 +189,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err if err != nil { return n, err } nn, err = decodeAny(decoder, value.Elem(), itemTag) n += nn; if err != nil { return n, err } - table.SetMapIndex(reflect.ValueOf(key), value.Elem()) + destination.SetMapIndex(reflect.ValueOf(key), value.Elem()) } default: return n, fmt.Errorf("unknown TN %d", tag.TN()) @@ -254,53 +283,79 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) { return n, nil } +func canSet(destination reflect.Type, tag Tag) error { + switch tag.WithoutCN() { + case SI, LI, LSI: + switch destination.Kind() { + case + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + default: + return errCantAssignf("cannot assign integer to %v", destination) + } + case FP: + switch destination.Kind() { + case reflect.Float32, reflect.Float64: + default: + return errCantAssignf("cannot assign float to %v", destination) + } + case SBA, LBA: + if destination.Kind() != reflect.Slice { + return errCantAssignf("cannot assign byte array to %v", destination) + } + if destination.Elem() != reflect.TypeOf(byte(0)) { + return errCantAssignf("cannot convert %v to *[]byte", destination) + } + case OTA: + if destination.Kind() != reflect.Slice { + return errCantAssignf("cannot assign array to %v", destination) + } + case KTV: + if destination != reflect.TypeOf(dummyMap) { + return errCantAssignf("cannot assign table to %v", destination) + } + default: + return fmt.Errorf("unknown TN %d", tag.TN()) + } + return nil +} + // setInt expects a settable destination. -func setInt[T int64 | uint64](destination reflect.Value, value T) error { +func setInt[T int64 | uint64](destination reflect.Value, value T) { switch { case destination.CanInt(): destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type())) case destination.CanUint(): destination.Set(reflect.ValueOf(value).Convert(destination.Type())) default: - return fmt.Errorf("cannot assign integer to %T", destination.Interface()) + panic("setInt called on an unsupported type") } - return nil } // setFloat expects a settable destination. -func setFloat(destination reflect.Value, value float64) error { - if !destination.CanFloat() { - return fmt.Errorf("cannot assign float to %T", destination.Interface()) - } +func setFloat(destination reflect.Value, value float64) { destination.Set(reflect.ValueOf(value).Convert(destination.Type())) - return nil } // setByteArrayexpects a settable destination. -func setByteArray(destination reflect.Value, value []byte) error { - typ := destination.Type() - if typ.Kind() != reflect.Slice { - return fmt.Errorf("cannot assign %T to ", value) - } - if typ.Elem() != reflect.TypeOf(byte(0)) { - return fmt.Errorf("cannot convert %T to *[]byte", value) - } +func setByteArray(destination reflect.Value, value []byte) { destination.Set(reflect.ValueOf(value)) - return nil } // decodeAndSetInt expects a settable destination. func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { value, nn, err := decoder.ReadIntN(bytes) n += nn; if err != nil { return n, err } - return n, setInt(destination, value) + setInt(destination, value) + return n, nil } // decodeAndSetUint expects a settable destination. func decodeAndSetUint(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { value, nn, err := decoder.ReadUintN(bytes) n += nn; if err != nil { return n, err } - return n, setInt(destination, value) + setInt(destination, value) + return n, nil } // decodeAndSetInt expects a settable destination. @@ -309,13 +364,15 @@ func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) ( case 8: value, nn, err := decoder.ReadFloat64() n += nn; if err != nil { return n, err } - return n, setFloat(destination, float64(value)) + setFloat(destination, float64(value)) + return n, nil case 4: value, nn, err := decoder.ReadFloat32() n += nn; if err != nil { return n, err } - return n, setFloat(destination, float64(value)) + setFloat(destination, float64(value)) + return n, nil } - return n, fmt.Errorf("cannot decode float%d", bytes * 8) + return n, errCantAssignf("unsupported bit width float%d", bytes * 8) } // skeletonValue returns a pointer value. In order for it to be set, it must be diff --git a/tape/dynamic_test.go b/tape/dynamic_test.go index b77655c..7e7bab7 100644 --- a/tape/dynamic_test.go +++ b/tape/dynamic_test.go @@ -74,6 +74,101 @@ func TestEncodeAnyTable(test *testing.T) { if err != nil { test.Fatal(err) } } +func TestDecodeWrongType(test *testing.T) { + datas := [][]byte { + /* int8 */ []byte { byte(LSI.WithCN(0)), 0x45 }, + /* int16 */ []byte { byte(LSI.WithCN(1)), 0x45, 0x67 }, + /* int32 */ []byte { byte(LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB }, + /* int64 */ []byte { byte(LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 }, + /* uint5 */ []byte { byte(SI.WithCN(12)) }, + /* uint8 */ []byte { byte(LI.WithCN(0)), 0x45 }, + /* uint16 */ []byte { byte(LI.WithCN(1)), 0x45, 0x67 }, + /* uint32 */ []byte { byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB }, + /* uint64 */ []byte { byte(LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 }, + /* string */ []byte { byte(SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' }, + /* []byte */ []byte { byte(SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' }, + /* []string */ []byte { + byte(OTA.WithCN(0)), 2, byte(LBA.WithCN(0)), + 0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, + 0x05, 0x11, 0x11, 0x11, 0x11, 0x11, + }, + /* map[uint16] any */ []byte { + byte(KTV.WithCN(0)), 2, + 0x02, 0x23, byte(LSI.WithCN(1)), 0x45, 0x67, + 0x02, 0x23, byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB, + }, + } + + for index, data := range datas { + test.Logf("data %2d %v [%s]", index, Tag(data[0]), tu.HexBytes(data[1:])) + // integers should only assign to other integers + if index > 8 { + cas := func(destination any) { + n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0])) + if err != nil { test.Fatalf("error: %v | n: %d", err, n) } + reflectValue := reflect.ValueOf(destination).Elem() + if reflectValue.CanInt() { + if reflectValue.Int() != 0 { + test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface()) + } + } else { + if reflectValue.Uint() != 0 { + test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface()) + } + } + if n != len(data) - 1 { + test.Fatalf("n not equal: %d != %d", n, len(data) - 1) + } + } + test.Log("- int8") + { var dest int8; cas(&dest) } + test.Log("- int16") + { var dest int16; cas(&dest) } + test.Log("- int32") + { var dest int32; cas(&dest) } + test.Log("- int64") + { var dest int64; cas(&dest) } + test.Log("- uint8") + { var dest uint8; cas(&dest) } + test.Log("- uint16") + { var dest uint16; cas(&dest) } + test.Log("- uint32") + { var dest uint32; cas(&dest) } + test.Log("- uint64") + { var dest uint64; cas(&dest) } + } + arrayCase := func(destination any) { + n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0])) + if err != nil { test.Fatalf("error: %v | n: %d", err, n) } + reflectDestination := reflect.ValueOf(destination) + reflectValue := reflectDestination.Elem() + if reflectValue.Len() != 0 { + test.Fatalf("len(destination) not zero: %v", reflectValue.Interface()) + } + if n != len(data) - 1 { + test.Fatalf("n not equal: %d != %d", n, len(data) - 1) + } + } + // SBA/LBA types should only assign to other SBA/LBA types + if index != 9 && index != 10 { + test.Log("- string") + { var dest string; arrayCase(&dest) } + test.Log("- []byte") + { var dest []byte; arrayCase(&dest) } + } + // arrays should only assign to other arrays + if index != 11 { + test.Log("- []string") + { var dest []string; arrayCase(&dest) } + } + // tables should only assign to other tables + if index != 12 { + test.Log("- map[uint16] any") + { var dest = map[uint16] any { }; arrayCase(&dest) } + } + } +} + func TestEncodeDecodeAnyTable(test *testing.T) { err := testEncodeDecodeAny(test, map[uint16] any { 0xF3B9: uint32(1),