diff --git a/tape/dynamic.go b/tape/dynamic.go index 86fdb1a..44c1a4b 100644 --- a/tape/dynamic.go +++ b/tape/dynamic.go @@ -172,35 +172,40 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err // underlying type is unsupported. See [EncodeAny] for a list of supported // types. func TagAny(value any) (Tag, error) { - // TODO use reflection for all of this to ignore type names + return tagAny(reflect.ValueOf(value)) +} + +func tagAny(reflectValue reflect.Value) (Tag, error) { // primitives - switch value := value.(type) { - case int: return LSI.WithCN(3), nil - case int8: return LSI.WithCN(0), nil - case int16: return LSI.WithCN(1), nil - case int32: return LSI.WithCN(3), nil - case int64: return LSI.WithCN(7), nil - case uint: return LI.WithCN(3), nil - case uint8: return LI.WithCN(0), nil - case uint16: return LI.WithCN(1), nil - case uint32: return LI.WithCN(3), nil - case uint64: return LI.WithCN(7), nil - case string: return bufferLenTag(len(value)), nil - case []byte: return bufferLenTag(len(value)), nil + switch reflectValue.Kind() { + case reflect.Int: return LSI.WithCN(3), nil + case reflect.Int8: return LSI.WithCN(0), nil + case reflect.Int16: return LSI.WithCN(1), nil + case reflect.Int32: return LSI.WithCN(3), nil + case reflect.Int64: return LSI.WithCN(7), nil + case reflect.Uint: return LI.WithCN(3), nil + case reflect.Uint8: return LI.WithCN(0), nil + case reflect.Uint16: return LI.WithCN(1), nil + case reflect.Uint32: return LI.WithCN(3), nil + case reflect.Uint64: return LI.WithCN(7), nil + case reflect.String: return bufferLenTag(reflectValue.Len()), nil + } + if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) { + return bufferLenTag(reflectValue.Len()), nil } // aggregates - reflectType := reflect.TypeOf(value) + reflectType := reflectValue.Type() switch reflectType.Kind() { - case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil + case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil case reflect.Array: return OTA.WithCN(reflectType.Len()), nil case reflect.Map: if reflectType.Key() == reflect.TypeOf(uint16(0)) { - return KTV.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil + return KTV.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil } - return 0, fmt.Errorf("cannot encode map key %T, key must be uint16", value) + return 0, fmt.Errorf("cannot encode map key %v, key must be uint16", reflectType.Key()) } - return 0, fmt.Errorf("cannot get tag of type %T", value) + return 0, fmt.Errorf("cannot get tag of type %v", reflectType) } func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) { @@ -209,11 +214,10 @@ func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) { nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1) n += nn; if err != nil { return n, err } reflectType := reflect.TypeOf(value) - oneTag, err := TagAny(reflect.Zero(reflectType.Elem()).Interface()) + oneTag, err := tagAny(reflect.Zero(reflectType.Elem())) if err != nil { return n, err } for index := 0; index < reflectValue.Len(); index += 1 { - item := reflectValue.Index(index).Interface() - itemTag, err := TagAny(item) + itemTag, err := tagAny(reflectValue.Index(index)) if err != nil { return n, err } if itemTag.CN() > oneTag.CN() { oneTag = itemTag } } @@ -235,11 +239,12 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) { n += nn; if err != nil { return n, err } iter := reflectValue.MapRange() for iter.Next() { - key := iter.Key().Interface().(uint16) - value := iter.Value().Interface() + reflectValue := iter.Value().Elem() + key := iter.Key().Interface().(uint16) + value := reflectValue.Interface() nn, err = encoder.WriteUint16(key) n += nn; if err != nil { return n, err } - itemTag, err := TagAny(value) + itemTag, err := tagAny(reflectValue) if err != nil { return n, err } nn, err = encoder.WriteUint8(uint8(itemTag)) n += nn; if err != nil { return n, err }