Merge pull request 'any-type' (#20) from any-type into main

Reviewed-on: #20
This commit is contained in:
Sasha Koshka 2025-10-12 11:03:53 -06:00
commit f5de450c39
11 changed files with 454 additions and 190 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
/generate/test /generate/test
/debug

View File

@ -30,6 +30,7 @@ PDL allows defining a protocol using HOPP and TAPE.
| []\<TYPE\> | OTA | * | Array of any type[^1] | []\<TYPE\> | OTA | * | Array of any type[^1]
| Table | KTV | * | Table with undefined schema | Table | KTV | * | Table with undefined schema
| {...} | KTV | * | Table with defined schema | {...} | KTV | * | Table with defined schema
| Any | * | * | Value of an undefined type
[^1]: Excluding SI and SBA. I5 and U5 cannot be used in an array, but String and [^1]: Excluding SI and SBA. I5 and U5 cannot be used in an array, but String and
Buffer are simply forced to use their "long" variant. Buffer are simply forced to use their "long" variant.

View File

@ -235,17 +235,13 @@ func (this *Generator) generateMessage(method uint16, message Message) (n int, e
this.resolveMessageName(message.Name)) this.resolveMessageName(message.Name))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
this.push() this.push()
nn, err = this.iprintf("tag := ") tagVar, nn, err := this.generateTag(message.Type, "(*this)")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateTag(message.Type, "(*this)") nn, err = this.iprintf("nn, err := encoder.WriteTag(%s)\n", tagVar)
n += nn; if err != nil { return n, err }
nn, err = this.println()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err := encoder.WriteTag(tag)\n")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateEncodeValue(message.Type, "(*this)", "tag") nn, err = this.generateEncodeValue(message.Type, "(*this)", tagVar)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n") nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
@ -398,15 +394,11 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
this.push() this.push()
nn, err = this.iprintf("_ = item\n") nn, err = this.iprintf("_ = item\n")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.iprintf("tag := ") tagVar, nn, err := this.generateTag(typ.Element, "item")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateTag(typ.Element, "item") nn, err = this.iprintf("if %s.Is(tape.SBA) { continue }\n", tagVar)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.println() nn, err = this.iprintf("if %s.CN() > itemTag.CN() { itemTag = %s }\n", tagVar, tagVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if tag.Is(tape.SBA) { continue }\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if tag.CN() > itemTag.CN() { itemTag = tag }\n")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
this.pop() this.pop()
nn, err = this.iprintf("}\n") nn, err = this.iprintf("}\n")
@ -461,25 +453,19 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
nn, err = this.iprintf("{\n") nn, err = this.iprintf("{\n")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
this.push() this.push()
nn, err = this.iprintf("var tag tape.Tag\n")
n += nn; if err != nil { return n, err }
for key, field := range typ.Fields { for key, field := range typ.Fields {
nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key) nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.iprintf("tag = ")
n += nn; if err != nil { return n, err }
fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name) fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name)
nn, err = this.generateTag(field.Type, fieldSource) tagVar, nn, err := this.generateTag(field.Type, fieldSource)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.println() nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(%s))\n", tagVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(tag))\n")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateEncodeValue(field.Type, fieldSource, "tag") nn, err = this.generateEncodeValue(field.Type, fieldSource, tagVar)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
} }
this.pop() this.pop()
@ -491,6 +477,12 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case TypeAny:
// WHATEVER: [WHATEVER]
nn, err := this.iprintf("nn, err = tape.EncodeAny(encoder, %s, %s)\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
default: default:
panic(fmt.Errorf("unknown type: %T", typ)) panic(fmt.Errorf("unknown type: %T", typ))
} }
@ -627,7 +619,7 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
case TypeTable: case TypeTable:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)* // KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf( nn, err := this.iprintf(
"nn, err = tape.DecodeAny(decoder, %s, %s)\n", "nn, err = tape.DecodeAnyInto(decoder, %s, %s)\n",
valueSource, tagSource) valueSource, tagSource)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
@ -642,6 +634,12 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck() nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case TypeAny:
// WHATEVER: [WHATEVER]
nn, err := this.iprintf("*%s, nn, err = tape.DecodeAny(decoder, %s)\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
default: default:
panic(fmt.Errorf("unknown type: %T", typ)) panic(fmt.Errorf("unknown type: %T", typ))
} }
@ -924,49 +922,60 @@ func (this *Generator) generateErrorCheck() (n int, err error) {
return this.iprintf("n += nn; if err != nil { return n, err }\n") return this.iprintf("n += nn; if err != nil { return n, err }\n")
} }
func (this *Generator) generateBareErrorCheck() (n int, err error) {
return this.iprintf("if err != nil { return n, err }\n")
}
// generateTag generates the preferred TN and CN for the given type and value. // generateTag generates the preferred TN and CN for the given type and value.
// The generated code is INLINE. // The generated code is a BLOCK.
func (this *Generator) generateTag(typ Type, source string) (n int, err error) { func (this *Generator) generateTag(typ Type, source string) (tagVar string, n int, err error) {
tagVar = this.newTemporaryVar("tag")
switch typ := typ.(type) { switch typ := typ.(type) {
case TypeInt: case TypeInt:
if typ.Bits <= 5 { if typ.Bits <= 5 {
nn, err := this.printf("tape.SI.WithCN(int(%s))", source) nn, err := this.iprintf("%s := tape.SI.WithCN(int(%s))\n", tagVar, source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
} else if typ.Signed { } else if typ.Signed {
nn, err := this.printf("tape.LSI.WithCN(%d)", bitsToCN(typ.Bits)) nn, err := this.iprintf("%s := tape.LSI.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
} else { } else {
nn, err := this.printf("tape.LI.WithCN(%d)", bitsToCN(typ.Bits)) nn, err := this.iprintf("%s := tape.LI.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
} }
case TypeFloat: case TypeFloat:
nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits)) nn, err := this.iprintf("%s := tape.FP.WithCN(%d)\n", tagVar, bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
case TypeString: case TypeString:
nn, err := this.printf("tape.StringTag(string(%s))", source) nn, err := this.iprintf("%s := tape.StringTag(string(%s))\n", tagVar, source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
case TypeBuffer: case TypeBuffer:
nn, err := this.printf("tape.BufferTag([]byte(%s))", source) nn, err := this.iprintf("%s := tape.BufferTag([]byte(%s))\n", tagVar, source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
case TypeArray: case TypeArray:
nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source) nn, err := this.iprintf("%s := tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))\n", tagVar, source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
case TypeTable: case TypeTable:
nn, err := this.printf("tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))))", source) nn, err := this.iprintf("%s := tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))))\n", tagVar, source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
case TypeTableDefined: case TypeTableDefined:
nn, err := this.printf("tape.KTV.WithCN(%d)", tape.IntBytes(uint64(len(typ.Fields)))) nn, err := this.iprintf("%s := tape.KTV.WithCN(%d)\n", tagVar, tape.IntBytes(uint64(len(typ.Fields))))
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
case TypeNamed: case TypeNamed:
resolved, err := this.resolveTypeName(typ.Name) resolved, err := this.resolveTypeName(typ.Name)
if err != nil { return n, err } if err != nil { return tagVar, n, err }
nn, err := this.generateTag(resolved, source) subTagVar, nn, err := this.generateTag(resolved, source)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return tagVar, n, err }
tagVar = subTagVar
case TypeAny:
nn, err := this.iprintf("%s, err := tape.TagAny(%s)\n", tagVar, source)
n += nn; if err != nil { return tagVar, n, err }
nn, err = this.generateBareErrorCheck()
n += nn; if err != nil { return tagVar, n, err }
default: default:
panic(fmt.Errorf("unknown type: %T", typ)) panic(fmt.Errorf("unknown type: %T", typ))
} }
return n, nil return tagVar, n, nil
} }
// generateTN generates the appropriate TN for the given type. The generated // generateTN generates the appropriate TN for the given type. The generated
@ -1009,6 +1018,8 @@ func (this *Generator) generateTN(typ Type) (n int, err error) {
if err != nil { return n, err } if err != nil { return n, err }
nn, err := this.generateTN(resolved) nn, err := this.generateTN(resolved)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
default:
panic(fmt.Errorf("unknown type: %T", typ))
} }
return n, nil return n, nil
@ -1063,6 +1074,11 @@ func (this *Generator) generateType(typ Type) (n int, err error) {
case TypeNamed: case TypeNamed:
nn, err := this.print(typ.Name) nn, err := this.print(typ.Name)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
case TypeAny:
nn, err := this.print("any")
n += nn; if err != nil { return n, err }
default:
panic(fmt.Errorf("unknown type: %T", typ))
} }
return n, nil return n, nil
} }
@ -1092,12 +1108,17 @@ func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, er
// by tagSource can be assigned to a Go destination generated from typ. The // by tagSource can be assigned to a Go destination generated from typ. The
// generated code is INLINE. // generated code is INLINE.
func (this *Generator) generateCanAssign(typ Type, tagSource string) (n int, err error) { func (this *Generator) generateCanAssign(typ Type, tagSource string) (n int, err error) {
if _, ok := typ.(TypeAny); ok {
nn, err := this.printf("true")
n += nn; if err != nil { return n, err }
} else {
nn, err := this.printf("canAssign(") nn, err := this.printf("canAssign(")
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.generateTN(typ) nn, err = this.generateTN(typ)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
nn, err = this.printf(", %s)", tagSource) nn, err = this.printf(", %s)", tagSource)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
}
return n, nil return n, nil
} }

View File

@ -62,6 +62,27 @@ func init() {
}, },
}, },
} }
exampleProtocol.Messages[0x0005] = Message {
Name: "Dynamic",
Type: TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "AU8", Type: TypeAny { } },
0x0001: Field { Name: "AU16", Type: TypeAny { } },
0x0002: Field { Name: "AU32", Type: TypeAny { } },
0x0003: Field { Name: "AU64", Type: TypeAny { } },
0x0004: Field { Name: "AI8", Type: TypeAny { } },
0x0005: Field { Name: "AI16", Type: TypeAny { } },
0x0006: Field { Name: "AI32", Type: TypeAny { } },
0x0007: Field { Name: "AI64", Type: TypeAny { } },
0x0008: Field { Name: "AF32", Type: TypeAny { } },
0x0009: Field { Name: "AF64", Type: TypeAny { } },
0x000A: Field { Name: "AString", Type: TypeAny { } },
0x000B: Field { Name: "AArray", Type: TypeAny { } },
0x000C: Field { Name: "ATable", Type: TypeAny { } },
0x000D: Field { Name: "T0", Type: TypeTable { } },
},
},
}
exampleProtocol.Types["User"] = TypeTableDefined { exampleProtocol.Types["User"] = TypeTableDefined {
Fields: map[uint16] Field { Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeString { } }, 0x0000: Field { Name: "Name", Type: TypeString { } },
@ -196,6 +217,53 @@ func TestGenerateRunEncodeDecode(test *testing.T) {
[]byte { 0x00, 0x0D, 0x43, 0xEF, 0x1E, 0xCB, 0x37 }, []byte { 0x00, 0x0D, 0x43, 0xEF, 0x1E, 0xCB, 0x37 },
[]byte { 0x00, 0x0E, 0x47, 0x9C, 0x6E, 0xF6, 0x43, 0xEF, 0x1E, 0xCB, 0x37 }, []byte { 0x00, 0x0E, 0x47, 0x9C, 0x6E, 0xF6, 0x43, 0xEF, 0x1E, 0xCB, 0x37 },
)) ))
log.Println("MessageDynamic")
messageDynamic := MessageDynamic {
AU8: uint8(0x23),
AU16: uint16(0x3247),
AU32: uint32(0x87324523),
AU64: uint64(0x3284029034098234),
AI8: int8(0x23),
AI16: int16(0x3247),
AI32: int32(0x57324523),
AI64: int64(0x3284029034098234),
AF32: float32(2342.2378),
AF64: float64(324.8899992),
AString: "fox bed",
AArray: []int16 { 0x7, 0x6, 0x5, 0x4 },
ATable: map[uint16] any {
0x0001: int8(0x8),
0x0002: float64(4.4),
},
T0: map[uint16] any {
0x0001: float32(489.5),
0x0002: "hi",
0x0003: uint16(0x3992),
},
}
testEncodeDecode(
&messageDynamic,
tu.S(0xE1, 14).AddVar(
[]byte { 0x00, 0x00, 0x20, 0x23 },
[]byte { 0x00, 0x01, 0x21, 0x32, 0x47 },
[]byte { 0x00, 0x02, 0x23, 0x87, 0x32, 0x45, 0x23 },
[]byte { 0x00, 0x03, 0x27, 0x32, 0x84, 0x02, 0x90, 0x34, 0x09, 0x82, 0x34 },
[]byte { 0x00, 0x04, 0x40, 0x23 },
[]byte { 0x00, 0x05, 0x41, 0x32, 0x47 },
[]byte { 0x00, 0x06, 0x43, 0x57, 0x32, 0x45, 0x23 },
[]byte { 0x00, 0x07, 0x47, 0x32, 0x84, 0x02, 0x90, 0x34, 0x09, 0x82, 0x34 },
[]byte { 0x00, 0x08, 0x63, 0x45, 0x12, 0x63, 0xCE },
[]byte { 0x00, 0x09, 0x67, 0x40, 0x74, 0x4E, 0x3D, 0x6F, 0xCD, 0x17, 0x75 },
[]byte { 0x00, 0x0A, 0x87, 'f', 'o', 'x', ' ', 'b', 'e', 'd' },
[]byte { 0x00, 0x0B, 0xC4, 0x00, 0x07, 0x00, 0x06, 0x00, 0x05, 0x00, 0x04 },
[]byte { 0x00, 0x0C, 0xE1, 0x02,
0x00, 0x01, 0x20, 0x08,
0x00, 0x02, 0x67, 0x40, 0x11, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9A },
[]byte { 0x00, 0x0D, 0xE1, 0x03,
0x00, 0x01, 0x63, 0x43, 0xF4, 0xC0, 0x00,
0x00, 0x02, 0x82, 'h', 'i',
0x00, 0x03, 0x21, 0x39, 0x92 },
))
`) `)
} }

View File

@ -116,6 +116,7 @@ func (this *parser) parseType() (Type, error) {
case "String": return TypeString { }, this.Next() case "String": return TypeString { }, this.Next()
case "Buffer": return TypeBuffer { }, this.Next() case "Buffer": return TypeBuffer { }, this.Next()
case "Table": return TypeTable { }, this.Next() case "Table": return TypeTable { }, this.Next()
case "Any": return TypeAny { }, this.Next()
} }
return this.parseTypeNamed() return this.parseTypeNamed()
case TokenLBracket: case TokenLBracket:

View File

@ -31,6 +31,7 @@ func TestParse(test *testing.T) {
0x0002: Field { Name: "Followers", Type: TypeInt { Bits: 32 } }, 0x0002: Field { Name: "Followers", Type: TypeInt { Bits: 32 } },
}, },
} }
correct.Types["Anything"] = TypeAny { }
test.Log("CORRECT:", &correct) test.Log("CORRECT:", &correct)
got, err := ParseReader("test.pdl", strings.NewReader(` got, err := ParseReader("test.pdl", strings.NewReader(`
@ -48,6 +49,8 @@ func TestParse(test *testing.T) {
0001 Bio String, 0001 Bio String,
0002 Followers U32, 0002 Followers U32,
} }
Anything Any
`)) `))
if err != nil { test.Fatal(parse.Format(err)) } if err != nil { test.Fatal(parse.Format(err)) }
test.Log("GOT: ", got) test.Log("GOT: ", got)

View File

@ -99,6 +99,12 @@ func (typ TypeNamed) String() string {
return typ.Name return typ.Name
} }
type TypeAny struct { }
func (typ TypeAny) String() string {
return "Any"
}
func HashType(typ Type) [16]byte { func HashType(typ Type) [16]byte {
// TODO: if we ever want to make the compiler more efficient, this would // TODO: if we ever want to make the compiler more efficient, this would
// be a good place to start, complex string concatenation in a hot path // be a good place to start, complex string concatenation in a hot path

View File

@ -119,6 +119,10 @@ type describer struct {
} }
func (this *describer) describe(value reflect.Value) { func (this *describer) describe(value reflect.Value) {
if !value.IsValid() {
this.printf("<invalid>")
return
}
value = reflect.ValueOf(value.Interface()) value = reflect.ValueOf(value.Interface())
switch value.Kind() { switch value.Kind() {
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:

View File

@ -8,6 +8,17 @@ package tape
// TODO: test all of these smaller functions individually // TODO: test all of these smaller functions individually
// For an explanation as to why this package always treats LBA/SBA as strings,
// refer to https://go.dev/blog/strings:
//
// Its important to state right up front that a string holds arbitrary
// bytes. It is not required to hold Unicode text, UTF-8 text, or any other
// predefined format. As far as the content of a string is concerned, it is
// exactly equivalent to a slice of bytes.
//
// Arbitrary byte slices and blobs won't be as common of a use case as text
// data, and if you need that anyway you can just cast it to a byte slice.
import "fmt" import "fmt"
import "reflect" import "reflect"
@ -46,6 +57,8 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
case reflect.Uint32: return encoder.WriteUint32(uint32(reflectValue.Uint())) case reflect.Uint32: return encoder.WriteUint32(uint32(reflectValue.Uint()))
case reflect.Int64: return encoder.WriteInt64(int64(reflectValue.Int())) case reflect.Int64: return encoder.WriteInt64(int64(reflectValue.Int()))
case reflect.Uint64: return encoder.WriteUint64(uint64(reflectValue.Uint())) case reflect.Uint64: return encoder.WriteUint64(uint64(reflectValue.Uint()))
case reflect.Float32: return encoder.WriteFloat32(float32(reflectValue.Float()))
case reflect.Float64: return encoder.WriteFloat64(float64(reflectValue.Float()))
case reflect.String: case reflect.String:
if reflectValue.Len() > MaxStructureLength { if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong return 0, ErrTooLong
@ -86,9 +99,10 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
return n, fmt.Errorf("cannot encode type %T", value) return n, fmt.Errorf("cannot encode type %T", value)
} }
// DecodeAny decodes data and places it into destination, which must be a // DecodeAnyInto decodes data and places it into destination, which must be a
// pointer to a supported type. See [EncodeAny] for a list of supported types. // pointer to a supported type. See [EncodeAny] for a list of supported types.
func DecodeAny(decoder *Decoder, destination any, tag Tag) (n int, err error) { // The head of the decoder must be at the start of the payload.
func DecodeAnyInto(decoder *Decoder, destination any, tag Tag) (n int, err error) {
reflectDestination := reflect.ValueOf(destination) reflectDestination := reflect.ValueOf(destination)
if reflectDestination.Kind() != reflect.Pointer { if reflectDestination.Kind() != reflect.Pointer {
return n, fmt.Errorf("expected pointer destination, not %v", destination) return n, fmt.Errorf("expected pointer destination, not %v", destination)
@ -96,6 +110,17 @@ func DecodeAny(decoder *Decoder, destination any, tag Tag) (n int, err error) {
return decodeAny(decoder, reflectDestination.Elem(), tag) return decodeAny(decoder, reflectDestination.Elem(), tag)
} }
// DecodeAny is like [DecodeAnyInto], but it automatically creates the
// destination from the tag and data. The head of the decoder must be at the
// start of the payload.
func DecodeAny(decoder *Decoder, tag Tag) (value any, n int, err error) {
destination, err := skeletonPointer(decoder, tag)
if err != nil { return nil, n, err }
nn, err := DecodeAnyInto(decoder, destination, tag)
n += nn; if err != nil { return nil, n, err }
return destination, n, err
}
// unknownSlicePlaceholder is inserted by skeletonValue and informs the program // unknownSlicePlaceholder is inserted by skeletonValue and informs the program
// that the destination for the slice needs to be generated based on the item // that the destination for the slice needs to be generated based on the item
// tag in the OTA. // tag in the OTA.
@ -105,7 +130,8 @@ var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { })
// decodeAny is internal to [DecodeAny]. It takes in an addressable // decodeAny is internal to [DecodeAny]. It takes in an addressable
// [reflect.Value] as the destination. If the decoded value cannot fit in the // [reflect.Value] as the destination. If the decoded value cannot fit in the
// destination, it skims over the payload, leaves the destination empty, and // destination, it skims over the payload, leaves the destination empty, and
// returns without an error. // returns without an error. The head of the decoder must be at the start of the
// payload.
func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
n, err = decodeAnyOrError(decoder, destination, tag) n, err = decodeAnyOrError(decoder, destination, tag)
if _, ok := err.(errCantAssign); ok { if _, ok := err.(errCantAssign); ok {
@ -122,7 +148,7 @@ func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err
// destination, it decodes nothing and returns an error of type errCantAssign, // destination, it decodes nothing and returns an error of type errCantAssign,
// except for the case of a mismatched OTA element tag, wherein it will skim // except for the case of a mismatched OTA element tag, wherein it will skim
// over the rest of the payload, leave the destination empty, and return without // over the rest of the payload, leave the destination empty, and return without
// an error. // an error. The head of the decoder must be at the start of the payload.
func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
err = canSet(destination.Type(), tag) err = canSet(destination.Type(), tag)
if err != nil { return n, err } if err != nil { return n, err }
@ -130,7 +156,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI: case SI:
// SI: (none) // SI: (none)
setInt(destination, uint64(tag.CN())) setUint(destination, uint64(tag.CN()), 1)
case LI: case LI:
// LI: <value: IntN> // LI: <value: IntN>
nn, err := decodeAndSetUint(decoder, destination, tag.CN() + 1) nn, err := decodeAndSetUint(decoder, destination, tag.CN() + 1)
@ -152,7 +178,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
buffer := make([]byte, length) buffer := make([]byte, length)
nn, err := decoder.Read(buffer) nn, err := decoder.Read(buffer)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
setByteArray(destination, buffer) setString(destination, string(buffer))
case LBA: case LBA:
// LBA: <length: UN> <data: U8>* // LBA: <length: UN> <data: U8>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
@ -163,9 +189,16 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
buffer := make([]byte, length) buffer := make([]byte, length)
nn, err = decoder.Read(buffer) nn, err = decoder.Read(buffer)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
setByteArray(destination, buffer) setString(destination, string(buffer))
case OTA: case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>* // OTA: <length: UN> <elementTag: tape.Tag> <values>*
oldDestination := destination
if isTypeAny(destination.Type()) {
// need a skeleton value if we are assigning to any.
value, err := skeletonValue(decoder, tag)
if err != nil { return n, err }
destination = value
}
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) { if length > uint64(MaxStructureLength) {
@ -199,6 +232,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
return n, err return n, err
} }
} }
oldDestination.Set(destination)
case KTV: case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)* // KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
length, nn, err := decoder.ReadUintN(tag.CN() + 1) length, nn, err := decoder.ReadUintN(tag.CN() + 1)
@ -206,13 +240,21 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
if length > uint64(MaxStructureLength) { if length > uint64(MaxStructureLength) {
return 0, ErrTooLong return 0, ErrTooLong
} }
lengthCast, err := Uint64ToIntSafe(length)
if err != nil { return n, err }
if isTypeAny(destination.Type()) {
// need a skeleton value if we are assigning to any.
value := reflect.MakeMapWithSize(reflect.TypeOf(dummyMap), lengthCast)
destination.Set(value)
destination = value
}
destination.Clear() destination.Clear()
for _ = range length { for _ = range lengthCast {
key, nn, err := decoder.ReadUint16() key, nn, err := decoder.ReadUint16()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
itemTag, nn, err := decoder.ReadTag() itemTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
value, err := skeletonValue(decoder, itemTag) value, err := skeletonPointer(decoder, itemTag)
if err != nil { return n, err } if err != nil { return n, err }
nn, err = decodeAny(decoder, value.Elem(), itemTag) nn, err = decodeAny(decoder, value.Elem(), itemTag)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
@ -244,6 +286,8 @@ func tagAny(reflectValue reflect.Value) (Tag, error) {
case reflect.Uint16: return LI.WithCN(1), nil case reflect.Uint16: return LI.WithCN(1), nil
case reflect.Uint32: return LI.WithCN(3), nil case reflect.Uint32: return LI.WithCN(3), nil
case reflect.Uint64: return LI.WithCN(7), nil case reflect.Uint64: return LI.WithCN(7), nil
case reflect.Float32: return FP.WithCN(3), nil
case reflect.Float64: return FP.WithCN(7), nil
case reflect.String: return bufferLenTag(reflectValue.Len()), nil case reflect.String: return bufferLenTag(reflectValue.Len()), nil
} }
if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) { if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) {
@ -311,6 +355,10 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
} }
func canSet(destination reflect.Type, tag Tag) error { func canSet(destination reflect.Type, tag Tag) error {
// anything can be assigned to `any`
if isTypeAny(destination) {
return nil
}
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI, LI, LSI: case SI, LI, LSI:
switch destination.Kind() { switch destination.Kind() {
@ -327,6 +375,7 @@ func canSet(destination reflect.Type, tag Tag) error {
return errCantAssignf("cannot assign float to %v", destination) return errCantAssignf("cannot assign float to %v", destination)
} }
case SBA, LBA: case SBA, LBA:
if destination.Kind() == reflect.String { return nil }
if destination.Kind() != reflect.Slice { if destination.Kind() != reflect.Slice {
return errCantAssignf("cannot assign byte array to %v", destination) return errCantAssignf("cannot assign byte array to %v", destination)
} }
@ -348,17 +397,43 @@ func canSet(destination reflect.Type, tag Tag) error {
} }
// setInt expects a settable destination. // setInt expects a settable destination.
func setInt[T int64 | uint64](destination reflect.Value, value T) { func setInt(destination reflect.Value, value int64, bytes int) {
switch { switch {
case destination.CanInt(): case destination.CanInt():
destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type())) destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type()))
case destination.CanUint(): case destination.CanUint():
destination.Set(reflect.ValueOf(value).Convert(destination.Type())) destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
case isTypeAny(destination.Type()):
switch {
case bytes > 4: destination.Set(reflect.ValueOf(int64(value)))
case bytes > 2: destination.Set(reflect.ValueOf(int32(value)))
case bytes > 1: destination.Set(reflect.ValueOf(int16(value)))
default: destination.Set(reflect.ValueOf(int8(value)))
}
default: default:
panic("setInt called on an unsupported type") panic("setInt called on an unsupported type")
} }
} }
// setUint expects a settable destination.
func setUint(destination reflect.Value, value uint64, bytes int) {
switch {
case destination.CanInt():
destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type()))
case destination.CanUint():
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
case isTypeAny(destination.Type()):
switch {
case bytes > 4: destination.Set(reflect.ValueOf(uint64(value)))
case bytes > 2: destination.Set(reflect.ValueOf(uint32(value)))
case bytes > 1: destination.Set(reflect.ValueOf(uint16(value)))
default: destination.Set(reflect.ValueOf(uint8(value)))
}
default:
panic("setUint called on an unsupported type")
}
}
// setFloat expects a settable destination. // setFloat expects a settable destination.
func setFloat(destination reflect.Value, value float64) { func setFloat(destination reflect.Value, value float64) {
destination.Set(reflect.ValueOf(value).Convert(destination.Type())) destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
@ -369,11 +444,16 @@ func setByteArray(destination reflect.Value, value []byte) {
destination.Set(reflect.ValueOf(value)) destination.Set(reflect.ValueOf(value))
} }
// setString exepctes a settable destination
func setString(destination reflect.Value, value string) {
destination.Set(reflect.ValueOf(value))
}
// decodeAndSetInt expects a settable destination. // decodeAndSetInt expects a settable destination.
func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadIntN(bytes) value, nn, err := decoder.ReadIntN(bytes)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
setInt(destination, value) setInt(destination, value, bytes)
return n, nil return n, nil
} }
@ -381,7 +461,7 @@ func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n
func decodeAndSetUint(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { func decodeAndSetUint(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadUintN(bytes) value, nn, err := decoder.ReadUintN(bytes)
n += nn; if err != nil { return n, err } n += nn; if err != nil { return n, err }
setInt(destination, value) setUint(destination, value, bytes)
return n, nil return n, nil
} }
@ -402,16 +482,26 @@ func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (
return n, errCantAssignf("unsupported bit width float%d", bytes * 8) return n, errCantAssignf("unsupported bit width float%d", bytes * 8)
} }
// skeletonValue returns a pointer value. In order for it to be set, it must be // skeletonValue returns an addressable value. It can be set directly. The head
// dereferenced using Elem(). // of the decoder must be at the start of the payload when calling.
func skeletonValue(decoder *Decoder, tag Tag) (reflect.Value, error) { func skeletonValue(decoder *Decoder, tag Tag) (reflect.Value, error) {
ptr, err := skeletonPointer(decoder, tag)
if err != nil { return reflect.Value { }, err }
return ptr.Elem(), nil
}
// skeletonPointer returns a pointer value. In order for it to be set, it must
// be dereferenced using Elem(). The head of the decoder must be at the start of
// the payload when calling.
func skeletonPointer(decoder *Decoder, tag Tag) (reflect.Value, error) {
typ, err := typeOf(decoder, tag) typ, err := typeOf(decoder, tag)
if err != nil { return reflect.Value { }, err } if err != nil { return reflect.Value { }, err }
return reflect.New(typ), nil return reflect.New(typ), nil
} }
// typeOf returns the type of the current tag being decoded. It does not use up // typeOf returns the type of the current tag being decoded. It does not use up
// the decoder, it only peeks. // the decoder, it only peeks. The head of the decoder must be at the start of
// the payload when calling.
func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) { func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) {
switch tag.WithoutCN() { switch tag.WithoutCN() {
case SI: case SI:
@ -438,8 +528,8 @@ func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) {
case 7: return reflect.TypeOf(float64(0)), nil case 7: return reflect.TypeOf(float64(0)), nil
} }
return nil, fmt.Errorf("unknown CN %d for FP", tag.CN()) return nil, fmt.Errorf("unknown CN %d for FP", tag.CN())
case SBA: return reflect.SliceOf(reflect.TypeOf(byte(0))), nil case SBA: return reflect.TypeOf(""), nil
case LBA: return reflect.SliceOf(reflect.TypeOf(byte(0))), nil case LBA: return reflect.TypeOf(""), nil
case OTA: case OTA:
elemTag, dimension, err := peekSlice(decoder, tag) elemTag, dimension, err := peekSlice(decoder, tag)
if err != nil { return nil, err } if err != nil { return nil, err }
@ -455,6 +545,12 @@ func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) {
return nil, fmt.Errorf("unknown TN %d", tag.TN()) return nil, fmt.Errorf("unknown TN %d", tag.TN())
} }
// isTypeAny returns whether the given reflect.Type is an interface with no
// methods.
func isTypeAny(typ reflect.Type) bool {
return typ.Kind() == reflect.Interface && typ.NumMethod() == 0
}
// peekSlice returns the element tag and dimension count of the OTA currently // peekSlice returns the element tag and dimension count of the OTA currently
// being decoded. It does not use up the decoder, it only peeks. // being decoded. It does not use up the decoder, it only peeks.
func peekSlice(decoder *Decoder, tag Tag) (Tag, int, error) { func peekSlice(decoder *Decoder, tag Tag) (Tag, int, error) {

View File

@ -1,11 +1,56 @@
package tape package tape
import "fmt"
import "bytes" import "bytes"
import "testing" import "testing"
import "reflect" import "reflect"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil" import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
var samplePayloads = [][]byte {
/* int8 */ []byte { byte(LSI.WithCN(0)), 0x45 },
/* int16 */ []byte { byte(LSI.WithCN(1)), 0x45, 0x67 },
/* int32 */ []byte { byte(LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* int64 */ []byte { byte(LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* uint5 */ []byte { byte(SI.WithCN(12)) },
/* uint8 */ []byte { byte(LI.WithCN(0)), 0x45 },
/* uint16 */ []byte { byte(LI.WithCN(1)), 0x45, 0x67 },
/* uint32 */ []byte { byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* uint64 */ []byte { byte(LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* string */ []byte { byte(SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' },
/* []byte */ []byte { byte(SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' },
/* []string */ []byte {
byte(OTA.WithCN(0)), 2, byte(LBA.WithCN(0)),
0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23,
0x05, 0x11, 0x11, 0x11, 0x11, 0x11,
},
/* map[uint16] any */ []byte {
byte(KTV.WithCN(0)), 2,
0x02, 0x23, byte(LSI.WithCN(1)), 0x45, 0x67,
0x02, 0x24, byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB,
},
}
var sampleValues = []any {
/* int8 */ int8(0x45),
/* int16 */ int16(0x4567),
/* int32 */ int32(0x456789AB),
/* int64 */ int64(0x456789ABCDEF0123),
/* uint5 */ uint8(12),
/* uint8 */ uint8(0x45),
/* uint16 */ uint16(0x4567),
/* uint32 */ uint32(0x456789AB),
/* uint64 */ uint64(0x456789ABCDEF0123),
/* string */ "pupever",
/* []byte */ "blarg",
/* []string */ []string {
"\x45\x67\x89\xAB\xCD\xEF\x01\x23",
"\x11\x11\x11\x11\x11",
},
/* map[uint16] any */ map[uint16] any {
0x0223: int16(0x4567),
0x0224: uint32(0x456789AB),
},
}
type userDefinedInteger int16 type userDefinedInteger int16
func TestEncodeAnyInt(test *testing.T) { func TestEncodeAnyInt(test *testing.T) {
@ -26,7 +71,9 @@ func TestEncodeAnyTable(test *testing.T) {
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} }, 0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
0x2345: [][]int16 { []int16 { 0x5 }, []int16 { 0x17, -0xAAA } }, 0x2345: [][]int16 { []int16 { 0x5 }, []int16 { 0x17, -0xAAA } },
0x3456: userDefinedInteger(0x3921), 0x3456: userDefinedInteger(0x3921),
}, KTV.WithCN(0), tu.S(7).AddVar( 0x1F1F: float32(67.26),
0x0F0F: float64(5.3),
}, KTV.WithCN(0), tu.S(9).AddVar(
[]byte { []byte {
0xF3, 0xB9, 0xF3, 0xB9,
byte(LSI.WithCN(3)), byte(LSI.WithCN(3)),
@ -70,41 +117,27 @@ func TestEncodeAnyTable(test *testing.T) {
byte(LSI.WithCN(1)), byte(LSI.WithCN(1)),
0x39, 0x21, 0x39, 0x21,
}, },
[]byte {
0x1F, 0x1F,
byte(FP.WithCN(3)),
0x42, 0x86, 0x85, 0x1F,
},
[]byte {
0x0F, 0x0F,
byte(FP.WithCN(7)),
0x40, 0x15, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33,
},
)) ))
if err != nil { test.Fatal(err) } if err != nil { test.Fatal(err) }
} }
func TestDecodeWrongType(test *testing.T) { func TestDecodeWrongType(test *testing.T) {
datas := [][]byte { for index, data := range samplePayloads {
/* int8 */ []byte { byte(LSI.WithCN(0)), 0x45 },
/* int16 */ []byte { byte(LSI.WithCN(1)), 0x45, 0x67 },
/* int32 */ []byte { byte(LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* int64 */ []byte { byte(LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* uint5 */ []byte { byte(SI.WithCN(12)) },
/* uint8 */ []byte { byte(LI.WithCN(0)), 0x45 },
/* uint16 */ []byte { byte(LI.WithCN(1)), 0x45, 0x67 },
/* uint32 */ []byte { byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* uint64 */ []byte { byte(LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* string */ []byte { byte(SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' },
/* []byte */ []byte { byte(SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' },
/* []string */ []byte {
byte(OTA.WithCN(0)), 2, byte(LBA.WithCN(0)),
0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23,
0x05, 0x11, 0x11, 0x11, 0x11, 0x11,
},
/* map[uint16] any */ []byte {
byte(KTV.WithCN(0)), 2,
0x02, 0x23, byte(LSI.WithCN(1)), 0x45, 0x67,
0x02, 0x23, byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB,
},
}
for index, data := range datas {
test.Logf("data %2d %v [%s]", index, Tag(data[0]), tu.HexBytes(data[1:])) test.Logf("data %2d %v [%s]", index, Tag(data[0]), tu.HexBytes(data[1:]))
// integers should only assign to other integers // integers should only assign to other integers
if index > 8 { if index > 8 {
cas := func(destination any) { cas := func(destination any) {
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0])) n, err := DecodeAnyInto(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0]))
if err != nil { test.Fatalf("error: %v | n: %d", err, n) } if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
reflectValue := reflect.ValueOf(destination).Elem() reflectValue := reflect.ValueOf(destination).Elem()
if reflectValue.CanInt() { if reflectValue.CanInt() {
@ -138,7 +171,7 @@ func TestDecodeWrongType(test *testing.T) {
{ var dest uint64; cas(&dest) } { var dest uint64; cas(&dest) }
} }
arrayCase := func(destination any) { arrayCase := func(destination any) {
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0])) n, err := DecodeAnyInto(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0]))
if err != nil { test.Fatalf("error: %v | n: %d", err, n) } if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
reflectDestination := reflect.ValueOf(destination) reflectDestination := reflect.ValueOf(destination)
reflectValue := reflectDestination.Elem() reflectValue := reflectDestination.Elem()
@ -175,13 +208,36 @@ func TestEncodeDecodeAnyTable(test *testing.T) {
0x0102: uint32(2), 0x0102: uint32(2),
0x0103: int64(23432), 0x0103: int64(23432),
0x0104: int64(-88777), 0x0104: int64(-88777),
0x0000: []byte("hi!"), 0x0000: "hi!",
0xFFFF: []uint16 { 0xBEE5, 0x7777 }, 0xFFFF: []uint16 { 0xBEE5, 0x7777 },
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} }, 0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
0x1F1F: float32(67.26),
0x0F0F: float64(5.3),
}, nil) }, nil)
if err != nil { test.Fatal(err) } if err != nil { test.Fatal(err) }
} }
func TestEncodeDecodeAnyDestination(test *testing.T) {
var destination any
for index, data := range samplePayloads {
tag := Tag(data[0])
payload := data[1:]
test.Logf("data %2d %v [%s]", index, tag, tu.HexBytes(payload))
n, err := DecodeAnyInto(NewDecoder(bytes.NewBuffer(payload)), &destination, tag)
if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
got := destination
correct := sampleValues[index]
test.Log("got: ", tu.Describe(got))
test.Log("correct:", tu.Describe(correct))
if !reflect.DeepEqual(got, correct) {
test.Fatalf("values not equal")
}
if n != len(payload) {
test.Fatalf("n not equal: %d != %d", n, len(payload))
}
}
}
func TestPeekSlice(test *testing.T) { func TestPeekSlice(test *testing.T) {
buffer := bytes.NewBuffer([]byte { buffer := bytes.NewBuffer([]byte {
2, byte(OTA.WithCN(3)), 2, byte(OTA.WithCN(3)),
@ -240,71 +296,3 @@ func TestPeekSliceOnce(test *testing.T) {
test.Fatalf("wrong n: %d != %d", got, correct) test.Fatalf("wrong n: %d != %d", got, correct)
} }
} }
func encAny(value any) ([]byte, Tag, int, error) {
tag, err := TagAny(value)
if err != nil { return nil, 0, 0, err }
buffer := bytes.Buffer { }
encoder := NewEncoder(&buffer)
n, err := EncodeAny(encoder, value, tag)
if err != nil { return nil, 0, n, err }
encoder.Flush()
return buffer.Bytes(), tag, n, nil
}
func decAny(data []byte) (Tag, any, int, error) {
destination := map[uint16] any { }
tag, err := TagAny(destination)
if err != nil { return 0, nil, 0, err }
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data)), &destination, tag)
if err != nil { return 0, nil, n, err }
return tag, destination, n, nil
}
func testEncodeAny(test *testing.T, value any, correctTag Tag, correctBytes tu.Snake) error {
bytes, tag, n, err := encAny(value)
if err != nil { return err }
test.Log("n: ", n)
test.Log("tag: ", tag)
test.Log("got: ", tu.HexBytes(bytes))
test.Log("correct:", correctBytes)
if tag != correctTag {
return fmt.Errorf("tag not equal: %v != %v", tag, correctTag)
}
if ok, n := correctBytes.Check(bytes); !ok {
return fmt.Errorf("bytes not equal at index %d", n)
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
return nil
}
func testEncodeDecodeAny(test *testing.T, value, correctValue any) error {
if correctValue == nil {
correctValue = value
}
test.Log("encoding...")
bytes, tag, n, err := encAny(value)
if err != nil { return err }
test.Log("n: ", n)
test.Log("tag:", tag)
test.Log("got:", tu.HexBytes(bytes))
test.Log("decoding...", tag)
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
_, decoded, n, err := decAny(bytes)
if err != nil { return err }
test.Log("got: ", tu.Describe(decoded))
test.Log("correct:", tu.Describe(correctValue))
if !reflect.DeepEqual(decoded, correctValue) {
return fmt.Errorf("values not equal")
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
return nil
}

75
tape/misc_test.go Normal file
View File

@ -0,0 +1,75 @@
package tape
import "fmt"
import "bytes"
import "testing"
import "reflect"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
func encAny(value any) ([]byte, Tag, int, error) {
tag, err := TagAny(value)
if err != nil { return nil, 0, 0, err }
buffer := bytes.Buffer { }
encoder := NewEncoder(&buffer)
n, err := EncodeAny(encoder, value, tag)
if err != nil { return nil, 0, n, err }
encoder.Flush()
return buffer.Bytes(), tag, n, nil
}
func decAny(data []byte) (Tag, any, int, error) {
destination := map[uint16] any { }
tag, err := TagAny(destination)
if err != nil { return 0, nil, 0, err }
n, err := DecodeAnyInto(NewDecoder(bytes.NewBuffer(data)), &destination, tag)
if err != nil { return 0, nil, n, err }
return tag, destination, n, nil
}
func testEncodeAny(test *testing.T, value any, correctTag Tag, correctBytes tu.Snake) error {
bytes, tag, n, err := encAny(value)
if err != nil { return err }
test.Log("n: ", n)
test.Log("tag: ", tag)
test.Log("got: ", tu.HexBytes(bytes))
test.Log("correct:", correctBytes)
if tag != correctTag {
return fmt.Errorf("tag not equal: %v != %v", tag, correctTag)
}
if ok, n := correctBytes.Check(bytes); !ok {
return fmt.Errorf("bytes not equal at index %d", n)
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
return nil
}
func testEncodeDecodeAny(test *testing.T, value, correctValue any) error {
if correctValue == nil {
correctValue = value
}
test.Log("encoding...")
bytes, tag, n, err := encAny(value)
if err != nil { return err }
test.Log("n: ", n)
test.Log("tag:", tag)
test.Log("got:", tu.HexBytes(bytes))
test.Log("decoding...", tag)
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
_, decoded, n, err := decAny(bytes)
if err != nil { return err }
test.Log("got: ", tu.Describe(decoded))
test.Log("correct:", tu.Describe(correctValue))
if !reflect.DeepEqual(decoded, correctValue) {
return fmt.Errorf("values not equal")
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
return nil
}