3 Commits

2 changed files with 41 additions and 28 deletions

View File

@@ -172,35 +172,40 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
// underlying type is unsupported. See [EncodeAny] for a list of supported // underlying type is unsupported. See [EncodeAny] for a list of supported
// types. // types.
func TagAny(value any) (Tag, error) { func TagAny(value any) (Tag, error) {
// TODO use reflection for all of this to ignore type names return tagAny(reflect.ValueOf(value))
}
func tagAny(reflectValue reflect.Value) (Tag, error) {
// primitives // primitives
switch value := value.(type) { switch reflectValue.Kind() {
case int: return LSI.WithCN(3), nil case reflect.Int: return LSI.WithCN(3), nil
case int8: return LSI.WithCN(0), nil case reflect.Int8: return LSI.WithCN(0), nil
case int16: return LSI.WithCN(1), nil case reflect.Int16: return LSI.WithCN(1), nil
case int32: return LSI.WithCN(3), nil case reflect.Int32: return LSI.WithCN(3), nil
case int64: return LSI.WithCN(7), nil case reflect.Int64: return LSI.WithCN(7), nil
case uint: return LI.WithCN(3), nil case reflect.Uint: return LI.WithCN(3), nil
case uint8: return LI.WithCN(0), nil case reflect.Uint8: return LI.WithCN(0), nil
case uint16: return LI.WithCN(1), nil case reflect.Uint16: return LI.WithCN(1), nil
case uint32: return LI.WithCN(3), nil case reflect.Uint32: return LI.WithCN(3), nil
case uint64: return LI.WithCN(7), nil case reflect.Uint64: return LI.WithCN(7), nil
case string: return bufferLenTag(len(value)), nil case reflect.String: return bufferLenTag(reflectValue.Len()), nil
case []byte: return bufferLenTag(len(value)), nil }
if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) {
return bufferLenTag(reflectValue.Len()), nil
} }
// aggregates // aggregates
reflectType := reflect.TypeOf(value) reflectType := reflectValue.Type()
switch reflectType.Kind() { switch reflectType.Kind() {
case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil
case reflect.Array: return OTA.WithCN(reflectType.Len()), nil case reflect.Array: return OTA.WithCN(reflectType.Len()), nil
case reflect.Map: case reflect.Map:
if reflectType.Key() == reflect.TypeOf(uint16(0)) { if reflectType.Key() == reflect.TypeOf(uint16(0)) {
return KTV.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil return KTV.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil
} }
return 0, fmt.Errorf("cannot encode map key %T, key must be uint16", value) return 0, fmt.Errorf("cannot encode map key %v, key must be uint16", reflectType.Key())
} }
return 0, fmt.Errorf("cannot get tag of type %T", value) return 0, fmt.Errorf("cannot get tag of type %v", reflectType)
} }
func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) { func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) {
@@ -209,11 +214,10 @@ func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) {
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1) nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
reflectType := reflect.TypeOf(value) reflectType := reflect.TypeOf(value)
oneTag, err := TagAny(reflect.Zero(reflectType.Elem()).Interface()) oneTag, err := tagAny(reflect.Zero(reflectType.Elem()))
if err != nil { return n, err } if err != nil { return n, err }
for index := 0; index < reflectValue.Len(); index += 1 { for index := 0; index < reflectValue.Len(); index += 1 {
item := reflectValue.Index(index).Interface() itemTag, err := tagAny(reflectValue.Index(index))
itemTag, err := TagAny(item)
if err != nil { return n, err } if err != nil { return n, err }
if itemTag.CN() > oneTag.CN() { oneTag = itemTag } if itemTag.CN() > oneTag.CN() { oneTag = itemTag }
} }
@@ -235,11 +239,12 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
iter := reflectValue.MapRange() iter := reflectValue.MapRange()
for iter.Next() { for iter.Next() {
key := iter.Key().Interface().(uint16) reflectValue := iter.Value().Elem()
value := iter.Value().Interface() key := iter.Key().Interface().(uint16)
value := reflectValue.Interface()
nn, err = encoder.WriteUint16(key) nn, err = encoder.WriteUint16(key)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
itemTag, err := TagAny(value) itemTag, err := tagAny(reflectValue)
if err != nil { return n, err } if err != nil { return n, err }
nn, err = encoder.WriteUint8(uint8(itemTag)) nn, err = encoder.WriteUint8(uint8(itemTag))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }

View File

@@ -6,6 +6,8 @@ import "testing"
import "reflect" import "reflect"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil" import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
type userDefinedInteger int16
func TestEncodeAnyInt(test *testing.T) { func TestEncodeAnyInt(test *testing.T) {
err := testEncodeAny(test, uint8(0xCA), LI.WithCN(0), tu.S(0xCA)) err := testEncodeAny(test, uint8(0xCA), LI.WithCN(0), tu.S(0xCA))
if err != nil { test.Fatal(err) } if err != nil { test.Fatal(err) }
@@ -23,7 +25,8 @@ func TestEncodeAnyTable(test *testing.T) {
0xFFFF: []uint16 { 0xBEE5, 0x7777 }, 0xFFFF: []uint16 { 0xBEE5, 0x7777 },
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} }, 0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
0x2345: [][]int16 { []int16 { 0x5 }, []int16 { 0x17, -0xAAA } }, 0x2345: [][]int16 { []int16 { 0x5 }, []int16 { 0x17, -0xAAA } },
}, KTV.WithCN(0), tu.S(6).AddVar( 0x3456: userDefinedInteger(0x3921),
}, KTV.WithCN(0), tu.S(7).AddVar(
[]byte { []byte {
0xF3, 0xB9, 0xF3, 0xB9,
byte(LSI.WithCN(3)), byte(LSI.WithCN(3)),
@@ -62,6 +65,11 @@ func TestEncodeAnyTable(test *testing.T) {
0, 0x17, 0, 0x17,
0xF5, 0x56, 0xF5, 0x56,
}, },
[]byte {
0x34, 0x56,
byte(LSI.WithCN(1)),
0x39, 0x21,
},
)) ))
if err != nil { test.Fatal(err) } if err != nil { test.Fatal(err) }
} }
@@ -166,10 +174,10 @@ func testEncodeAny(test *testing.T, value any, correctTag Tag, correctBytes tu.S
test.Log("got: ", tu.HexBytes(bytes)) test.Log("got: ", tu.HexBytes(bytes))
test.Log("correct:", correctBytes) test.Log("correct:", correctBytes)
if tag != correctTag { if tag != correctTag {
return fmt.Errorf("tag not equal") return fmt.Errorf("tag not equal: %v != %v", tag, correctTag)
} }
if ok, n := correctBytes.Check(bytes); !ok { if ok, n := correctBytes.Check(bytes); !ok {
return fmt.Errorf("bytes not equal: %d", n) return fmt.Errorf("bytes not equal at index %d", n)
} }
if n != len(bytes) { if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes)) return fmt.Errorf("n not equal: %d != %d", n, len(bytes))