message-size-increase #3

Open
sashakoshka wants to merge 209 commits from message-size-increase into main
2 changed files with 196 additions and 44 deletions
Showing only changes of commit a4da33536c - Show all commits

View File

@ -14,6 +14,14 @@ import "reflect"
var dummyMap map[uint16] any
var dummyBuffer []byte
type errCantAssign string
func (err errCantAssign) Error() string {
return string(err)
}
func errCantAssignf(format string, v ...any) errCantAssign {
return errCantAssign(fmt.Sprintf(format, v...))
}
// EncodeAny encodes an "any" value. Returns an error if the underlying type is
// unsupported. Supported types are:
//
@ -83,20 +91,34 @@ type unknownSlicePlaceholder struct { }
var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { })
// decodeAny is internal to [DecodeAny]. It takes in an addressable
// [reflect.Value] as the destination.
// [reflect.Value] as the destination. If the decoded value cannot fit in the
// destination, it skims over the payload, leaves the destination empty, and
// returns without an error.
func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
errWrongDestinationType := func(expected string) error {
panic(fmt.Errorf(
// return fmt.Errorf(
"expected %s destination, not %v",
expected, destination))
n, err = decodeAnyOrError(decoder, destination, tag)
if _, ok := err.(errCantAssign); ok {
if n > 0 { panic(fmt.Sprintf("decodeAnyOrError decoded more than it should: %d", n)) }
nn, err := Skim(decoder, tag)
n += nn; if err != nil { return n, err }
return n, nil
}
return n, err
}
// decodeAnyOrError is internal to [decodeAny]. It takes in an addressable
// [reflect.Value] as the destination. If the decoded value cannot fit in the
// destination, it decodes nothing and returns an error of type errCantAssign,
// except for the case of a mismatched OTA element tag, wherein it will skim
// over the rest of the payload, leave the destination empty, and return without
// an error.
func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
err = canSet(destination.Type(), tag)
if err != nil { return n, err }
switch tag.WithoutCN() {
case SI:
// SI: (none)
err = setInt(destination, uint64(tag.CN()))
if err != nil { return n, err }
setInt(destination, uint64(tag.CN()))
case LI:
// LI: <value: IntN>
nn, err := decodeAndSetUint(decoder, destination, tag.CN() + 1)
@ -114,8 +136,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
buffer := make([]byte, tag.CN())
nn, err := decoder.Read(buffer)
n += nn; if err != nil { return n, err }
err = setByteArray(destination, buffer)
if err != nil { return n, err }
setByteArray(destination, buffer)
case LBA:
// LBA: <length: UN> <data: U8>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
@ -123,34 +144,42 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
buffer := make([]byte, length)
nn, err = decoder.Read(buffer)
n += nn; if err != nil { return n, err }
err = setByteArray(destination, buffer)
if err != nil { return n, err }
setByteArray(destination, buffer)
case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
oneTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
if destination.Kind() != reflect.Slice {
return n, errWrongDestinationType("slice")
}
if destination.Cap() < int(length) {
destination.Grow(int(length) - destination.Cap())
}
// skip the rest of the array if the one tag doesn't
// match up with the destination
err = canSet(destination.Type().Elem(), oneTag)
if _, ok := err.(errCantAssign); ok {
for _ = range length {
nn, err := Skim(decoder, oneTag)
n += nn; if err != nil { return n, err }
}
break
}
if err != nil { return n, err }
destination.SetLen(int(length))
for index := range length {
nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag)
n += nn; if err != nil { return n, err }
n += nn
if _, ok := err.(errCantAssign); ok {
continue
} else if err != nil {
return n, err
}
}
case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
table := destination
if table.Type() != reflect.TypeOf(dummyMap) {
return n, errWrongDestinationType("map[uint16] any")
}
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
table.Clear()
destination.Clear()
for _ = range length {
key, nn, err := decoder.ReadUint16()
n += nn; if err != nil { return n, err }
@ -160,7 +189,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
if err != nil { return n, err }
nn, err = decodeAny(decoder, value.Elem(), itemTag)
n += nn; if err != nil { return n, err }
table.SetMapIndex(reflect.ValueOf(key), value.Elem())
destination.SetMapIndex(reflect.ValueOf(key), value.Elem())
}
default:
return n, fmt.Errorf("unknown TN %d", tag.TN())
@ -254,53 +283,79 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
return n, nil
}
func canSet(destination reflect.Type, tag Tag) error {
switch tag.WithoutCN() {
case SI, LI, LSI:
switch destination.Kind() {
case
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
default:
return errCantAssignf("cannot assign integer to %v", destination)
}
case FP:
switch destination.Kind() {
case reflect.Float32, reflect.Float64:
default:
return errCantAssignf("cannot assign float to %v", destination)
}
case SBA, LBA:
if destination.Kind() != reflect.Slice {
return errCantAssignf("cannot assign byte array to %v", destination)
}
if destination.Elem() != reflect.TypeOf(byte(0)) {
return errCantAssignf("cannot convert %v to *[]byte", destination)
}
case OTA:
if destination.Kind() != reflect.Slice {
return errCantAssignf("cannot assign array to %v", destination)
}
case KTV:
if destination != reflect.TypeOf(dummyMap) {
return errCantAssignf("cannot assign table to %v", destination)
}
default:
return fmt.Errorf("unknown TN %d", tag.TN())
}
return nil
}
// setInt expects a settable destination.
func setInt[T int64 | uint64](destination reflect.Value, value T) error {
func setInt[T int64 | uint64](destination reflect.Value, value T) {
switch {
case destination.CanInt():
destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type()))
case destination.CanUint():
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
default:
return fmt.Errorf("cannot assign integer to %T", destination.Interface())
panic("setInt called on an unsupported type")
}
return nil
}
// setFloat expects a settable destination.
func setFloat(destination reflect.Value, value float64) error {
if !destination.CanFloat() {
return fmt.Errorf("cannot assign float to %T", destination.Interface())
}
func setFloat(destination reflect.Value, value float64) {
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
return nil
}
// setByteArrayexpects a settable destination.
func setByteArray(destination reflect.Value, value []byte) error {
typ := destination.Type()
if typ.Kind() != reflect.Slice {
return fmt.Errorf("cannot assign %T to ", value)
}
if typ.Elem() != reflect.TypeOf(byte(0)) {
return fmt.Errorf("cannot convert %T to *[]byte", value)
}
func setByteArray(destination reflect.Value, value []byte) {
destination.Set(reflect.ValueOf(value))
return nil
}
// decodeAndSetInt expects a settable destination.
func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadIntN(bytes)
n += nn; if err != nil { return n, err }
return n, setInt(destination, value)
setInt(destination, value)
return n, nil
}
// decodeAndSetUint expects a settable destination.
func decodeAndSetUint(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadUintN(bytes)
n += nn; if err != nil { return n, err }
return n, setInt(destination, value)
setInt(destination, value)
return n, nil
}
// decodeAndSetInt expects a settable destination.
@ -309,13 +364,15 @@ func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (
case 8:
value, nn, err := decoder.ReadFloat64()
n += nn; if err != nil { return n, err }
return n, setFloat(destination, float64(value))
setFloat(destination, float64(value))
return n, nil
case 4:
value, nn, err := decoder.ReadFloat32()
n += nn; if err != nil { return n, err }
return n, setFloat(destination, float64(value))
setFloat(destination, float64(value))
return n, nil
}
return n, fmt.Errorf("cannot decode float%d", bytes * 8)
return n, errCantAssignf("unsupported bit width float%d", bytes * 8)
}
// skeletonValue returns a pointer value. In order for it to be set, it must be

View File

@ -74,6 +74,101 @@ func TestEncodeAnyTable(test *testing.T) {
if err != nil { test.Fatal(err) }
}
func TestDecodeWrongType(test *testing.T) {
datas := [][]byte {
/* int8 */ []byte { byte(LSI.WithCN(0)), 0x45 },
/* int16 */ []byte { byte(LSI.WithCN(1)), 0x45, 0x67 },
/* int32 */ []byte { byte(LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* int64 */ []byte { byte(LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* uint5 */ []byte { byte(SI.WithCN(12)) },
/* uint8 */ []byte { byte(LI.WithCN(0)), 0x45 },
/* uint16 */ []byte { byte(LI.WithCN(1)), 0x45, 0x67 },
/* uint32 */ []byte { byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* uint64 */ []byte { byte(LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* string */ []byte { byte(SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' },
/* []byte */ []byte { byte(SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' },
/* []string */ []byte {
byte(OTA.WithCN(0)), 2, byte(LBA.WithCN(0)),
0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23,
0x05, 0x11, 0x11, 0x11, 0x11, 0x11,
},
/* map[uint16] any */ []byte {
byte(KTV.WithCN(0)), 2,
0x02, 0x23, byte(LSI.WithCN(1)), 0x45, 0x67,
0x02, 0x23, byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB,
},
}
for index, data := range datas {
test.Logf("data %2d %v [%s]", index, Tag(data[0]), tu.HexBytes(data[1:]))
// integers should only assign to other integers
if index > 8 {
cas := func(destination any) {
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0]))
if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
reflectValue := reflect.ValueOf(destination).Elem()
if reflectValue.CanInt() {
if reflectValue.Int() != 0 {
test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface())
}
} else {
if reflectValue.Uint() != 0 {
test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface())
}
}
if n != len(data) - 1 {
test.Fatalf("n not equal: %d != %d", n, len(data) - 1)
}
}
test.Log("- int8")
{ var dest int8; cas(&dest) }
test.Log("- int16")
{ var dest int16; cas(&dest) }
test.Log("- int32")
{ var dest int32; cas(&dest) }
test.Log("- int64")
{ var dest int64; cas(&dest) }
test.Log("- uint8")
{ var dest uint8; cas(&dest) }
test.Log("- uint16")
{ var dest uint16; cas(&dest) }
test.Log("- uint32")
{ var dest uint32; cas(&dest) }
test.Log("- uint64")
{ var dest uint64; cas(&dest) }
}
arrayCase := func(destination any) {
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0]))
if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
reflectDestination := reflect.ValueOf(destination)
reflectValue := reflectDestination.Elem()
if reflectValue.Len() != 0 {
test.Fatalf("len(destination) not zero: %v", reflectValue.Interface())
}
if n != len(data) - 1 {
test.Fatalf("n not equal: %d != %d", n, len(data) - 1)
}
}
// SBA/LBA types should only assign to other SBA/LBA types
if index != 9 && index != 10 {
test.Log("- string")
{ var dest string; arrayCase(&dest) }
test.Log("- []byte")
{ var dest []byte; arrayCase(&dest) }
}
// arrays should only assign to other arrays
if index != 11 {
test.Log("- []string")
{ var dest []string; arrayCase(&dest) }
}
// tables should only assign to other tables
if index != 12 {
test.Log("- map[uint16] any")
{ var dest = map[uint16] any { }; arrayCase(&dest) }
}
}
}
func TestEncodeDecodeAnyTable(test *testing.T) {
err := testEncodeDecodeAny(test, map[uint16] any {
0xF3B9: uint32(1),