Compare commits

..

No commits in common. "main" and "encode-signedness" have entirely different histories.

19 changed files with 388 additions and 1051 deletions

View File

@ -4,7 +4,6 @@ import "os"
import "fmt"
import "strings"
import "path/filepath"
import "git.tebibyte.media/sashakoshka/goparse"
import "git.tebibyte.media/sashakoshka/hopp/generate"
func main() {
@ -19,7 +18,7 @@ func main() {
input, err := os.Open(source)
handleErr(1, err)
defer input.Close()
protocol, err := generate.ParseReader(source, input)
protocol, err := generate.ParseReader(input)
handleErr(1, err)
absDestination, err := filepath.Abs(destination)
@ -31,18 +30,14 @@ func main() {
output, err := os.Create(destination)
handleErr(1, err)
generator := generate.Generator {
Output: output,
PackageName: packageName,
}
_, err = generator.Generate(protocol)
err = protocol.Generate(output, packageName)
handleErr(1, err)
fmt.Fprintf(os.Stderr, "%s: OK\n", name)
}
func handleErr(code int, err error) {
if err != nil {
fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], parse.Format(err))
fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], err)
os.Exit(code)
}
}

View File

@ -1,47 +0,0 @@
package hopp
import "fmt"
type anyInt16 interface { ~uint16 | ~int16 }
type anyInt64 interface { ~uint64 | ~int64 }
// decodeI16 decodes a 16 bit integer from the given data.
func decodeI16[T anyInt16](data []byte) (T, error) {
if len(data) != 2 { return 0, fmt.Errorf("decoding int16: %w", ErrWrongBufferLength) }
return T(data[0]) << 8 | T(data[1]), nil
}
// encodeI16 encodes a 16 bit integer into the given buffer.
func encodeI16[T anyInt16](buffer []byte, value T) error {
if len(buffer) != 2 { return fmt.Errorf("encoding int16: %w", ErrWrongBufferLength) }
buffer[0] = byte(value >> 8)
buffer[1] = byte(value)
return nil
}
// decodeI64 decodes a 64 bit integer from the given data.
func decodeI64[T anyInt64](data []byte) (T, error) {
if len(data) != 8 { return 0, fmt.Errorf("decoding int64: %w", ErrWrongBufferLength) }
return T(data[0]) << 56 |
T(data[1]) << 48 |
T(data[2]) << 40 |
T(data[3]) << 32 |
T(data[4]) << 24 |
T(data[5]) << 16 |
T(data[6]) << 8 |
T(data[7]), nil
}
// encodeI64 encodes a 64 bit integer into the given buffer.
func encodeI64[T anyInt64](buffer []byte, value T) error {
if len(buffer) != 8 { return fmt.Errorf("encoding int64: %w", ErrWrongBufferLength) }
buffer[0] = byte(value >> 56)
buffer[1] = byte(value >> 48)
buffer[2] = byte(value >> 40)
buffer[3] = byte(value >> 32)
buffer[4] = byte(value >> 24)
buffer[5] = byte(value >> 16)
buffer[6] = byte(value >> 8)
buffer[7] = byte(value)
return nil
}

View File

@ -2,7 +2,7 @@ package hopp
import "io"
import "net"
import "time"
// import "time"
const defaultSizeLimit int64 = 1024 * 1024 // 1 megabyte
@ -23,13 +23,8 @@ type Conn interface {
// be called in a loop to avoid the connection locking up.
AcceptTrans() (Trans, error)
// SetDeadline operates is [net.Conn.SetDeadline] but for OpenTrans
// and AcceptTrans calls.
SetDeadline(t time.Time) error
// SetSizeLimit sets a limit (in bytes) for how large messages can be.
// By default, this limit is 1 megabyte. Note that this is only
// enforced when sending and receiving byte slices, and it does not
// apply to [Trans.SendWriter] or [Trans.ReceiveReader].
// By default, this limit is 1 megabyte.
SetSizeLimit(limit int64)
}
@ -44,6 +39,8 @@ type Trans interface {
// ID returns the transaction ID. This must not change, and it must be
// unique within the connection. This method is safe for concurrent use.
ID() int64
// TODO: add methods for setting send and receive deadlines
// Send sends a message. This method is not safe for concurrent use.
Send(method uint16, data []byte) error
@ -60,12 +57,4 @@ type Trans interface {
// previously opened through this function will be discarded. This
// method is not safe for concurrent use, and neither is its result.
ReceiveReader() (method uint16, data io.Reader, err error)
// See the documentation for [net.Conn.SetDeadline].
SetDeadline(time.Time) error
// TODO
// // See the documentation for [net.Conn.SetReadDeadline].
// SetReadDeadline(t time.Time) error
// // See the documentation for [net.Conn.SetWriteDeadline].
// SetWriteDeadline(t time.Time) error
}

View File

@ -9,7 +9,6 @@ type Error string; const (
ErrIntegerOverflow Error = "integer overflow"
ErrMessageMalformed Error = "message is malformed"
ErrTablePairMissing Error = "required table pair is missing"
ErrWrongBufferLength Error = "wrong buffer length"
)
// Error implements the error interface.

View File

@ -41,8 +41,8 @@ type Message interface {
// the destination tag must come from the same (or hash-equivalent) PDL type.
func canAssign(destination, source tape.Tag) bool {
if destination.Is(source) { return true }
if (destination.Is(tape.SBA) || destination.Is(tape.LBA)) &&
(source.Is(tape.SBA) || source.Is(tape.LBA)) {
if (destination == tape.SBA || destination == tape.LBA) &&
(source == tape.SBA || source == tape.LBA) {
return true
}
return false
@ -317,22 +317,13 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
if typ.Signed {
prefix = "WriteInt"
}
nn, err := this.iprintf("nn, err = encoder.%s%d(", prefix, typ.Bits)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ) // TODO: cast like this for
// every type
n += nn; if err != nil { return n, err }
nn, err = this.printf("(%s))\n", valueSource)
nn, err := this.iprintf("nn, err = encoder.%s%d(%s)\n", prefix, typ.Bits, valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
case TypeFloat:
// FP: <value: FloatN>
nn, err := this.iprintf("nn, err = encoder.WriteFloat%d(", typ.Bits)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ)
n += nn; if err != nil { return n, err }
nn, err = this.printf("(%s))\n", valueSource)
nn, err := this.iprintf("nn, err = encoder.WriteFloat%d(%s)\n", typ.Bits, valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
@ -343,14 +334,7 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
case TypeBuffer:
// SBA: <data: U8>*
// LBA: <length: UN> <data: U8>*
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
nn, err := this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf(
@ -362,20 +346,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
case TypeArray:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
nn, err := this.iprintf(
"nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n",
valueSource, tagSource)
n += nn; if err != nil { return n, err }
@ -430,14 +408,7 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
n += nn; if err != nil { return n, err }
case TypeTable:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
nn, err := this.iprintf(
"nn, err = tape.EncodeAny(encoder, %s, %s)\n",
valueSource, tagSource)
n += nn; if err != nil { return n, err }
@ -445,14 +416,7 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
n += nn; if err != nil { return n, err }
case TypeTableDefined:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields))
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
nn, err := this.iprintf(
"nn, err = encoder.WriteUintN(%d, %s.CN())\n",
len(typ.Fields), tagSource)
n += nn; if err != nil { return n, err }
@ -518,57 +482,24 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
// LI/LSI: <value: IntN>
if typ.Bits <= 5 {
// SI stores the value in the tag
if typeName == "" {
nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s.CN())\n", valueSource, typeName, tagSource)
n += nn; if err != nil { return n, err }
}
nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
break
}
prefix := "ReadUint"
if typ.Signed {
prefix = "ReadInt"
}
destinationVar := this.newTemporaryVar("destination")
nn, err := this.iprintf("var %s ", destinationVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ)
n += nn; if err != nil { return n, err }
nn, err = this.print("\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits)
nn, err := this.iprintf("*%s, nn, err = decoder.%s%d()\n", valueSource, prefix, typ.Bits)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if typeName == "" {
nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
n += nn; if err != nil { return n, err }
}
case TypeFloat:
// FP: <value: FloatN>
destinationVar := this.newTemporaryVar("destination")
nn, err := this.iprintf("var %s ", destinationVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ)
n += nn; if err != nil { return n, err }
nn, err = this.print("\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadFloat%d()\n", destinationVar, typ.Bits)
nn, err := this.iprintf("*%s, nn, err = decoder.ReadFloat%d()\n", valueSource, typ.Bits)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if typeName == "" {
nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
n += nn; if err != nil { return n, err }
}
case TypeString, TypeBuffer:
// SBA: <data: U8>*
// LBA: <length: UN> <data: U8>*
@ -595,29 +526,17 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("buffer := make([]byte, %s)\n", lengthVar)
nn, err = this.iprintf("buffer := make([]byte, int(%s))\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if typeName == "" {
if _, ok := typ.(TypeString); ok {
nn, err = this.iprintf("*%s = string(buffer)\n", valueSource)
n += nn; if err != nil { return n, err }
} else {
nn, err = this.iprintf("*%s = buffer\n", valueSource)
n += nn; if err != nil { return n, err }
}
if _, ok := typ.(TypeString); ok {
nn, err = this.iprintf("*%s = string(buffer)\n", valueSource)
n += nn; if err != nil { return n, err }
} else {
nn, err = this.iprintf("*%s = %s(buffer)\n", valueSource, typeName)
nn, err = this.iprintf("*%s = buffer\n", valueSource)
n += nn; if err != nil { return n, err }
}
case TypeArray:
@ -704,13 +623,6 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
lengthVar := this.newTemporaryVar("length")
nn, err := this.iprintf("var %s uint64\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
@ -780,13 +692,6 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
lengthVar := this.newTemporaryVar("length")
nn, err := this.iprintf("var %s uint64\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
@ -799,7 +704,7 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
// problems
// read fields
nn, err = this.iprintf("for _ = range %s {\n", lengthVar)
nn, err = this.iprintf("for _ = range int(%s) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
// read field header
@ -943,10 +848,10 @@ func (this *Generator) generateTag(typ Type, source string) (n int, err error) {
nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err }
case TypeString:
nn, err := this.printf("tape.StringTag(string(%s))", source)
nn, err := this.printf("tape.StringTag(%s)", source)
n += nn; if err != nil { return n, err }
case TypeBuffer:
nn, err := this.printf("tape.BufferTag([]byte(%s))", source)
nn, err := this.printf("tape.BufferTag(%s)", source)
n += nn; if err != nil { return n, err }
case TypeArray:
nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source)

View File

@ -1,15 +1,151 @@
package generate
// import "fmt"
import "strings"
import "testing"
import "git.tebibyte.media/sashakoshka/goparse"
// TODO: once everything has been ironed out, test that the public API of the
// generator is equal to something specific
var testGenerateCorrect =
`package protocol
var exampleProtocol = defaultProtocol()
/* # Do not edit this package by hand!
*
* This file was automatically generated by the Holanet PDL compiler. The
* source file is located at input.pdl
* Please edit that file instead, and re-compile it to this location.
*
* HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz
*/
func init() {
exampleProtocol.Messages[0x0000] = Message {
import "git.tebibyte.media/sashakoshka/hopp/tape"
// Table is a KTV table with an undefined schema.
type Table map[uint16] any
// Message is any message that can be sent along this protocol.
type Message interface {
tape.Encodable
tape.Decodable
// Method returns the method code of the message.
Method() uint16
}
// User represents the protocol data type User.
type User struct {
Name string
Bio string
Followers uint32
}
// EncodeValue encodes the value of this type without the tag. The value is
// encoded according to the parameters specified by the tag, if possible.
func (this *User) EncodeValue(encoder *tape.Encoder) (n int, err error) {
nn, err := tape.WriteTableHeader(2)
n += nn; if err != nil { return n, err }
nn, err := encoder.WriteUint16(0x0000)
n += nn; if err != nil { return n, err }
nn, err := tape.WriteString(encoder, this.Name)
n += nn; if err != nil { return n, err }
nn, err := encoder.WriteUint16(0x0001)
n += nn; if err != nil { return n, err }
nn, err := tape.WriteString(encoder, this.Bio)
n += nn; if err != nil { return n, err }
return n, nil
}
// Decode replaces the data in this User with information from the decoder.
func (this *User) Decode(decoder *tape.Decoder) (n int, err error) {
pull, nn, err := tape.ReadTableHeader(decoder)
n += nn; if err != nil { return n, err }
for {
key, tag, end, nn, err := pull()
n += nn; if err != nil { return n, err }
if end { break }
switch key {
case 0x0000:
value, nn, err := tape.ReadString(decoder)
n += nn; if err != nil { return n, err }
this.Name = value
case 0x0001:
value, nn, err := tape.ReadString(decoder)
n += nn; if err != nil { return n, err }
this.Bio = value
}
}
return n, nil
}
// MessageConnect represents the protocol message M0000 Connect.
type MessageConnect struct {
Name string
Password string
}
// Method returns the method code, M0000.
func (this *MessageConnect) Method() uint16 {
return 0x0000
}
// Encode encodes the message to the encoder.
func (this *MessageConnect) Encode(encoder *tape.Encoder) (n int, err error) {
nn, err := tape.WriteTableHeader(2)
n += nn; if err != nil { return n, err }
nn, err := encoder.WriteUint16(0x0000)
n += nn; if err != nil { return n, err }
nn, err := tape.WriteString(encoder, this.Name)
n += nn; if err != nil { return n, err }
nn, err := encoder.WriteUint16(0x0001)
n += nn; if err != nil { return n, err }
nn, err := tape.WriteString(encoder, this.Password)
n += nn; if err != nil { return n, err }
return n, nil
}
// Decode replaces the data in this message with information from the decoder.
func (this *MessageConnect) Decode(decoder *tape.Decoder) (n int, err error) {
pull, nn, err := tape.ReadTableHeader(decoder)
n += nn; if err != nil { return n, err }
for {
key, tag, end, nn, err := pull()
n += nn; if err != nil { return n, err }
if end { break }
switch key {
case 0x0000:
value, nn, err := tape.ReadString(decoder)
n += nn; if err != nil { return n, err }
this.Name = value
case 0x0001:
value, nn, err := tape.ReadString(decoder)
n += nn; if err != nil { return n, err }
this.Password = value
}
}
return n, nil
}
// MessageUserList represents the protocol message M0001 UserList.
type MessageUserList struct {
Users []User
}
// Method returns the method code, M0001.
func (this *MessageUserList) Method() uint16 {
return 0x0001
}
// TODO methods
`
func TestGenerate(test *testing.T) {
protocol := defaultProtocol()
protocol.Messages[0x0000] = Message {
Name: "Connect",
Type: TypeTableDefined {
Fields: map[uint16] Field {
@ -18,7 +154,7 @@ func init() {
},
},
}
exampleProtocol.Messages[0x0001] = Message {
protocol.Messages[0x0001] = Message {
Name: "UserList",
Type: TypeTableDefined {
Fields: map[uint16] Field {
@ -26,7 +162,59 @@ func init() {
},
},
}
exampleProtocol.Messages[0x0002] = Message {
protocol.Types["User"] = TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeString { } },
0x0001: Field { Name: "Bio", Type: TypeString { } },
0x0002: Field { Name: "Followers", Type: TypeInt { Bits: 32 } },
},
}
correct := testGenerateCorrect
builder := strings.Builder { }
generator := Generator { Output: &builder }
/* TODO test n: */ _, err := generator.Generate(&protocol)
if err != nil { test.Fatal(parse.Format(err)) }
got := builder.String()
test.Log("CORRECT:")
test.Log(correct)
test.Log("GOT:")
test.Log(got)
if correct != got {
test.Error("not equal")
for index := range min(len(correct), len(got)) {
if correct[index] == got[index] { continue }
test.Log("C:", correct[max(0, index - 8):min(len(correct), index + 8)])
test.Log("G:", got[max(0, index - 8):min(len(got), index + 8)])
break
}
test.FailNow()
}
}
func TestGenerateRun(test *testing.T) {
protocol := defaultProtocol()
protocol.Messages[0x0000] = Message {
Name: "Connect",
Type: TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeString { } },
0x0001: Field { Name: "Password", Type: TypeString { } },
},
},
}
protocol.Messages[0x0001] = Message {
Name: "UserList",
Type: TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Users", Type: TypeArray { Element: TypeNamed { Name: "User" } } },
},
},
}
protocol.Messages[0x0002] = Message {
Name: "Pulse",
Type: TypeTableDefined {
Fields: map[uint16] Field {
@ -38,11 +226,11 @@ func init() {
},
},
}
exampleProtocol.Messages[0x0003] = Message {
protocol.Messages[0x0003] = Message {
Name: "NestedArray",
Type: TypeArray { Element: TypeArray { Element: TypeInt { Bits: 8 } } },
}
exampleProtocol.Messages[0x0004] = Message {
protocol.Messages[0x0004] = Message {
Name: "Integers",
Type: TypeTableDefined {
Fields: map[uint16] Field {
@ -62,25 +250,23 @@ func init() {
},
},
}
exampleProtocol.Types["User"] = TypeTableDefined {
protocol.Types["User"] = TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeString { } },
0x0001: Field { Name: "Bio", Type: TypeString { } },
0x0002: Field { Name: "Followers", Type: TypeInt { Bits: 32 } },
},
}
}
func TestGenerateRunEncodeDecode(test *testing.T) {
testGenerateRun(test, &exampleProtocol, "encode-decode", `
testGenerateRun(test, &protocol, `
// imports
`, `
// test case
log.Println("MessageConnect")
messageConnect := MessageConnect {
Name: "rarity",
Password: "gems",
}
testEncodeDecode(
testEncode(
&messageConnect,
tu.S(0xE1, 0x02).AddVar(
[]byte { 0x00, 0x00, 0x86, 'r', 'a', 'r', 'i', 't', 'y' },
@ -106,7 +292,7 @@ func TestGenerateRunEncodeDecode(test *testing.T) {
},
},
}
testEncodeDecode(
testEncode(
&messageUserList,
tu.S(0xE1, 0x01, 0x00, 0x00,
0xC1, 0x03, 0xE1,
@ -128,11 +314,11 @@ func TestGenerateRunEncodeDecode(test *testing.T) {
messagePulse := MessagePulse {
Index: 9,
Offset: -0x3521,
X: 45.375,
X: 45.389,
Y: 294.1,
Z: 384729384.234892034,
}
testEncodeDecode(
testEncode(
&messagePulse,
tu.S(0xE1, 0x05).AddVar(
[]byte { 0x00, 0x00, 0x09 },
@ -153,7 +339,7 @@ func TestGenerateRunEncodeDecode(test *testing.T) {
uint8s(6),
uint8s(35),
}
testEncodeDecode(
testEncode(
&messageNestedArray,
tu.S(0xC1, 0x02, 0xC1,
0x06, 0x20, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6,
@ -179,7 +365,7 @@ func TestGenerateRunEncodeDecode(test *testing.T) {
NI32: -0x10E134C9,
NI64: -0x639109BC10E134C9,
}
testEncodeDecode(
testEncode(
&messageIntegers,
tu.S(0xE1, 13).AddVar(
[]byte { 0x00, 0x00, 0x13 },
@ -198,184 +384,3 @@ func TestGenerateRunEncodeDecode(test *testing.T) {
))
`)
}
func TestGenerateRunDecodeWrongType(test *testing.T) {
protocol := defaultProtocol()
protocol.Messages[0x0000] = Message {
Name: "Uint5",
Type: TypeInt { Bits: 5 },
}
protocol.Messages[0x0001] = Message {
Name: "Uint8",
Type: TypeInt { Bits: 8 },
}
protocol.Messages[0x0002] = Message {
Name: "Uint16",
Type: TypeInt { Bits: 16 },
}
protocol.Messages[0x0003] = Message {
Name: "Uint32",
Type: TypeInt { Bits: 32 },
}
protocol.Messages[0x0004] = Message {
Name: "Uint64",
Type: TypeInt { Bits: 64 },
}
protocol.Messages[0x0005] = Message {
Name: "Int8",
Type: TypeInt { Bits: 8 },
}
protocol.Messages[0x0006] = Message {
Name: "Int16",
Type: TypeInt { Bits: 16 },
}
protocol.Messages[0x0007] = Message {
Name: "Int32",
Type: TypeInt { Bits: 32 },
}
protocol.Messages[0x0008] = Message {
Name: "Int64",
Type: TypeInt { Bits: 64 },
}
protocol.Messages[0x0009] = Message {
Name: "String",
Type: TypeString { },
}
protocol.Messages[0x000A] = Message {
Name: "Buffer",
Type: TypeBuffer { },
}
protocol.Messages[0x000B] = Message {
Name: "StringArray",
Type: TypeArray { Element: TypeString { } },
}
protocol.Messages[0x000C] = Message {
Name: "Table",
Type: TypeTable { },
}
protocol.Messages[0x000D] = Message {
Name: "TableDefined",
Type: TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeString { } },
0x0001: Field { Name: "Password", Type: TypeString { } },
},
},
}
testGenerateRun(test, &protocol, "decode-wrong-type", `
// imports
`, `
datas := [][]byte {
/* int8 */ []byte { byte(tape.LSI.WithCN(0)), 0x45 },
/* int16 */ []byte { byte(tape.LSI.WithCN(1)), 0x45, 0x67 },
/* int32 */ []byte { byte(tape.LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* int64 */ []byte { byte(tape.LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* uint5 */ []byte { byte(tape.SI.WithCN(12)) },
/* uint8 */ []byte { byte(tape.LI.WithCN(0)), 0x45 },
/* uint16 */ []byte { byte(tape.LI.WithCN(1)), 0x45, 0x67 },
/* uint32 */ []byte { byte(tape.LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* uint64 */ []byte { byte(tape.LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* string */ []byte { byte(tape.SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' },
/* []byte */ []byte { byte(tape.SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' },
/* []string */ []byte {
byte(tape.OTA.WithCN(0)), 2, byte(tape.LBA.WithCN(0)),
0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23,
0x05, 0x11, 0x11, 0x11, 0x11, 0x11,
},
/* map[uint16] any */ []byte {
byte(tape.KTV.WithCN(0)), 2,
0x02, 0x23, byte(tape.LSI.WithCN(1)), 0x45, 0x67,
0x02, 0x23, byte(tape.LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB,
},
}
for index, data := range datas {
log.Printf("data %2d %v [%s]", index, tape.Tag(data[0]), tu.HexBytes(data[1:]))
// integers should only assign to other integers
if index > 8 {
cas := func(destination Message) {
n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data)))
if err != nil { log.Fatalf("error: %v | n: %d", err, n) }
reflectValue := reflect.ValueOf(destination).Elem()
if reflectValue.CanInt() {
if reflectValue.Int() != 0 {
log.Fatalf(
"destination not zero: %v",
reflectValue.Elem().Interface())
}
} else {
if reflectValue.Uint() != 0 {
log.Fatalf(
"destination not zero: %v",
reflectValue.Elem().Interface())
}
}
if n != len(data) {
log.Fatalf("n not equal: %d != %d", n, len(data))
}
}
log.Println("- MessageInt8")
{ var dest MessageInt8; cas(&dest) }
log.Println("- MessageInt16")
{ var dest MessageInt16; cas(&dest) }
log.Println("- MessageInt32")
{ var dest MessageInt32; cas(&dest) }
log.Println("- MessageInt64")
{ var dest MessageInt64; cas(&dest) }
log.Println("- MessageUint8")
{ var dest MessageUint8; cas(&dest) }
log.Println("- MessageUint16")
{ var dest MessageUint16; cas(&dest) }
log.Println("- MessageUint32")
{ var dest MessageUint32; cas(&dest) }
log.Println("- MessageUint64")
{ var dest MessageUint64; cas(&dest) }
}
arrayCase := func(destination Message) {
n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data)),)
if err != nil { log.Fatalf("error: %v | n: %d", err, n) }
reflectDestination := reflect.ValueOf(destination)
reflectValue := reflectDestination.Elem()
if reflectValue.Len() != 0 {
log.Fatalf("len(destination) not zero: %v", reflectValue.Interface())
}
if n != len(data) {
log.Fatalf("n not equal: %d != %d", n, len(data))
}
}
anyCase := func(destination Message) {
n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data)),)
if err != nil { log.Fatalf("error: %v | n: %d", err, n) }
reflectDestination := reflect.ValueOf(destination)
reflectValue := reflectDestination.Elem()
if reflectValue == reflect.Zero(reflectValue.Type()) {
log.Fatalf("len(destination) not zero: %v", reflectValue.Interface())
}
if n != len(data) {
log.Fatalf("n not equal: %d != %d", n, len(data))
}
}
// SBA/LBA types should only assign to other SBA/LBA types
if index != 9 && index != 10 {
log.Println("- MessageString")
{ var dest MessageString; arrayCase(&dest) }
log.Println("- MessageBuffer")
{ var dest MessageBuffer; arrayCase(&dest) }
}
// arrays should only assign to other arrays
if index != 11 {
log.Println("- MessageStringArray")
{ var dest MessageStringArray; arrayCase(&dest) }
}
// tables should only assign to other tables
if index != 12 {
log.Println("- MessageTable")
{ var dest = make(MessageTable); arrayCase(&dest) }
log.Println("- MessageTableDefined")
{ var dest MessageTableDefined; anyCase(&dest) }
}
}
`)
}

View File

@ -6,9 +6,9 @@ import "os/exec"
import "testing"
import "path/filepath"
func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCase string) {
func testGenerateRun(test *testing.T, protocol *Protocol, imports string, testCase string) {
// reset data directory
dir := filepath.Join("test", title)
dir := "test/generate-run"
err := os.RemoveAll(dir)
if err != nil { test.Fatal(err) }
err = os.MkdirAll(dir, 0750)
@ -34,7 +34,6 @@ func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCa
imports = `
import "log"
import "bytes"
import "reflect"
import "git.tebibyte.media/sashakoshka/hopp/tape"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
` + imports
@ -57,70 +56,13 @@ func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCa
log.Fatalln("not equal at", n)
}
}
func testDecode(correct Message, data any) {
var flat []byte
switch data := data.(type) {
case []byte: flat = data
case tu.Snake: flat = data.Flatten()
}
message := reflect.New(reflect.ValueOf(correct).Elem().Type()).Interface().(Message)
log.Println("before: ", message)
decoder := tape.NewDecoder(bytes.NewBuffer(flat))
n, err := message.Decode(decoder)
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
log.Println("got: ", message)
log.Println("correct:", correct)
if n != len(flat) {
log.Fatalf("n incorrect: %d != %d\n", n, len(flat))
}
if !reflect.DeepEqual(message, correct) {
log.Fatalln("not equal")
}
}
// TODO: possibly combine the two above functions into this one,
// also take a data parameter here (snake)
func testEncodeDecode(message Message, data tu.Snake) {buffer := bytes.Buffer { }
log.Println("encoding:")
encoder := tape.NewEncoder(&buffer)
n, err := message.Encode(encoder)
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
encoder.Flush()
got := buffer.Bytes()
log.Printf("got: [%s]", tu.HexBytes(got))
log.Println("correct:", data)
if n != len(got) {
log.Fatalf("n incorrect: %d != %d\n", n, len(got))
}
if ok, n := data.Check(got); !ok {
log.Fatalln("not equal at", n)
}
log.Println("decoding:")
destination := reflect.New(reflect.ValueOf(message).Elem().Type()).Interface().(Message)
flat := data.Flatten()
log.Println("before: ", destination)
decoder := tape.NewDecoder(bytes.NewBuffer(flat))
n, err = destination.Decode(decoder)
if err != nil { log.Fatalf("at %d: %v\n", n, err) }
log.Println("got: ", destination)
log.Println("correct:", message)
if n != len(flat) {
log.Fatalf("n incorrect: %d != %d\n", n, len(flat))
}
if !reflect.DeepEqual(destination, message) {
log.Fatalln("not equal")
}
}
`
fmt.Fprintf(
mainFile, "package main\n%s\nfunc main() {\n%s\n%s\n%s\n}\n%s",
imports, setup, testCase, teardown, static)
// build and run test
command := exec.Command("go", "run", "./" + filepath.Join("generate", dir))
command := exec.Command("go", "run", "./generate/test/generate-run")
workingDirAbs, err := filepath.Abs("..")
if err != nil { test.Fatal(err) }
command.Dir = workingDirAbs

View File

@ -25,8 +25,8 @@ func defaultProtocol() Protocol {
}
}
func ParseReader(fileName string, reader io.Reader) (*Protocol, error) {
lx, err := Lex(fileName, reader)
func ParseReader(reader io.Reader) (*Protocol, error) {
lx, err := Lex("test.pdl", reader)
if err != nil { return nil, err }
return Parse(lx)
}

View File

@ -33,7 +33,7 @@ func TestParse(test *testing.T) {
}
test.Log("CORRECT:", &correct)
got, err := ParseReader("test.pdl", strings.NewReader(`
got, err := ParseReader(strings.NewReader(`
M0000 Connect {
0000 Name String,
0001 Password String,

View File

@ -64,18 +64,6 @@ func (sn Snake) Check(data []byte) (ok bool, n int) {
return true, n
}
// Flatten returns the snake flattened to a byte array. The result of this
// function always satisfies the snake.
func (sn Snake) Flatten() []byte {
flat := []byte { }
for _, sector := range sn {
for _, variation := range sector {
flat = append(flat, variation...)
}
}
return flat
}
func (sn Snake) String() string {
if len(sn) == 0 || len(sn[0]) == 0 || len(sn[0][0]) == 0 {
return "EMPTY"

52
message.go Normal file
View File

@ -0,0 +1,52 @@
package hopp
import "fmt"
import "encoding"
import "git.tebibyte.media/sashakoshka/hopp/tape"
// Message is any object that can be sent or received over a HOPP connection.
type Message interface {
// Method returns the method number of the message. This must be unique
// within the protocol, and should not change between calls.
Method() uint16
encoding.BinaryMarshaler
encoding.BinaryUnmarshaler
}
var _ Message = new(MessageData)
// MessageData represents a message that organizes its data into table pairs. It
// can be used to alter a protocol at runtime, transmit data with arbitrary
// keys, etc. Bear in mind that is less performant than generating code because
// it has to make extra memory allocations and such.
type MessageData struct {
// Methd holds the method number. This should only be set once.
Methd uint16
// Pairs maps tags to values.
Pairs map[uint16] []byte
}
// Method returns the message's method field.
func (this *MessageData) Method() uint16 {
return this.Methd
}
// MarshalBinary implements the [encoding.BinaryMarshaler] interface. The
// message is encoded using TAPE (Table Pair Encoding).
func (this *MessageData) MarshalBinary() ([]byte, error) {
buffer, err := tape.EncodePairs(this.Pairs)
if err != nil { return nil, fmt.Errorf("marshaling MessageData: %w", err) }
return buffer, nil
}
// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface. The
// message is decoded using TAPE (Table Pair Encoding).
func (this *MessageData) UnmarshalBinary(buffer []byte) error {
this.Pairs = make(map[uint16] []byte)
pairs, err := tape.DecodePairs(buffer)
if err != nil { return fmt.Errorf("unmarshaling MessageData: %w", err) }
for key, value := range pairs {
this.Pairs[key] = value
}
return nil
}

View File

@ -1,12 +1,11 @@
package hopp
import "io"
import "os"
import "fmt"
import "net"
import "sync"
import "time"
import "sync/atomic"
import "git.tebibyte.media/sashakoshka/hopp/tape"
import "git.tebibyte.media/sashakoshka/go-util/sync"
// TODO investigate why 30 never reaches the server, causing it to wait for ever
@ -110,10 +109,6 @@ func (this *a) AcceptTrans() (Trans, error) {
}
}
func (this *a) SetDeadline(t time.Time) error {
return this.underlying.SetDeadline(t)
}
func (this *a) SetSizeLimit(limit int64) {
this.sizeLimit = limit
}
@ -218,10 +213,6 @@ type transA struct {
currentWriter io.Closer
writeBuffer []byte
closed atomic.Bool
closeErr error
deadline *time.Timer
deadlineLock sync.Mutex
}
func (this *transA) Close() error {
@ -231,11 +222,6 @@ func (this *transA) Close() error {
return err
}
func (this *transA) closeWithError(err error) error {
this.closeErr = err
return this.Close()
}
func (this *transA) closeDontUnlist() (err error) {
// MUST be goroutine safe
this.incoming.Close()
@ -284,9 +270,9 @@ func (this *transA) Receive() (method uint16, data []byte, err error) {
}
func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
// if the transaction has been closed, return an appropriate error.
if err := this.errIfClosed(); err != nil {
return 0, nil, err
// if the transaction has been closed, return an io.EOF
if this.closed.Load() {
return 0, nil, io.EOF
}
// drain previous reader if necessary
@ -304,54 +290,6 @@ func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
return method, reader, nil
}
func (this *transA) SetDeadline(t time.Time) error {
this.deadlineLock.Lock()
defer this.deadlineLock.Unlock()
if t == (time.Time { }) {
if this.deadline != nil {
this.deadline.Stop()
}
return nil
}
until := time.Until(t)
if this.deadline == nil {
this.deadline.Reset(until)
return nil
}
this.deadline = time.AfterFunc(until, func () {
this.closeWithError(os.ErrDeadlineExceeded)
})
return nil
}
// TODO
// func (this *transA) SetReadDeadline(t time.Time) error {
// // TODO
// }
//
// func (this *transA) SetWriteDeadline(t time.Time) error {
// // TODO
// }
func (this *transA) errIfClosed() error {
if !this.closed.Load() {
return nil
}
return this.bestErr()
}
func (this *transA) bestErr() error {
if this.parent.err != nil {
return this.parent.err
}
if this.closeErr != nil {
return this.closeErr
}
return io.EOF
}
type readerA struct {
parent *transA
leftover []byte
@ -382,7 +320,11 @@ func (this *readerA) pull() (uint16, error) {
// close and return error on failure
this.eof = true
this.parent.Close()
return 0, fmt.Errorf("could not receive message: %w", this.parent.bestErr())
if this.parent.parent.err == nil {
return 0, fmt.Errorf("could not receive message: %w", io.EOF)
} else {
return 0, this.parent.parent.err
}
}
func (this *readerA) Read(buffer []byte) (int, error) {
@ -464,9 +406,9 @@ func encodeMessageA(
return ErrPayloadTooLarge
}
buffer := make([]byte, 18 + len(data))
encodeI64(buffer[:8], trans)
encodeI16(buffer[8:10], method)
encodeI64(buffer[10:18], uint64(len(data)))
tape.EncodeI64(buffer[:8], trans)
tape.EncodeI16(buffer[8:10], method)
tape.EncodeI64(buffer[10:18], uint64(len(data)))
copy(buffer[18:], data)
_, err := writer.Write(buffer)
return err
@ -485,11 +427,11 @@ func decodeMessageA(
headerBuffer := [18]byte { }
_, err = io.ReadFull(reader, headerBuffer[:])
if err != nil { return 0, 0, false, nil, err }
transID, err = decodeI64[int64](headerBuffer[:8])
transID, err = tape.DecodeI64[int64](headerBuffer[:8])
if err != nil { return 0, 0, false, nil, err }
method, err = decodeI16[uint16](headerBuffer[8:10])
method, err = tape.DecodeI16[uint16](headerBuffer[8:10])
if err != nil { return 0, 0, false, nil, err }
size, err := decodeI64[uint64](headerBuffer[10:18])
size, err := tape.DecodeI64[uint64](headerBuffer[10:18])
if err != nil { return 0, 0, false, nil, err }
chunked, size = splitCCBSize(size)
if size > uint64(sizeLimit) {

View File

@ -2,10 +2,10 @@ package hopp
import "io"
import "net"
import "time"
import "bytes"
import "errors"
import "context"
import "git.tebibyte.media/sashakoshka/hopp/tape"
// B implements METADAPT-B over a multiplexed stream-oriented transport such as
// QUIC.
@ -51,10 +51,6 @@ func (this *b) SetSizeLimit(limit int64) {
this.sizeLimit = limit
}
func (this *b) SetDeadline(t time.Time) error {
return this.underlying.SetDeadline(t)
}
func (this *b) newTrans(underlying Stream) *transB {
return &transB {
sizeLimit: this.sizeLimit,
@ -129,10 +125,6 @@ func (this *transB) receiveReader() (uint16, int64, io.Reader, error) {
return method, size, data, nil
}
func (this *transB) SetDeadline(t time.Time) error {
return this.underlying.SetDeadline(t)
}
type writerB struct {
parent *transB
buffer bytes.Buffer
@ -158,16 +150,12 @@ type MultiConn interface {
AcceptStream(context.Context) (Stream, error)
// OpenStream opens a new stream.
OpenStream() (Stream, error)
// See the documentation for [net.Conn.SetDeadline].
SetDeadline(time.Time) error
}
// Stream represents a single stream returned by a [MultiConn].
type Stream interface {
// See documentation for [net.Conn].
io.ReadWriteCloser
// See the documentation for [net.Conn.SetDeadline].
SetDeadline(time.Time) error
// ID returns the stream ID
ID() int64
}
@ -177,8 +165,8 @@ func encodeMessageB(writer io.Writer, sizeLimit int64, method uint16, data []byt
return ErrPayloadTooLarge
}
buffer := make([]byte, 10 + len(data))
encodeI16(buffer[:2], method)
encodeI64(buffer[2:10], uint64(len(data)))
tape.EncodeI16(buffer[:2], method)
tape.EncodeI64(buffer[2:10], uint64(len(data)))
copy(buffer[10:], data)
_, err := writer.Write(buffer)
return err
@ -199,9 +187,9 @@ func decodeMessageB(
if errors.Is(err, io.EOF) { return 0, 0, nil, io.ErrUnexpectedEOF }
return 0, 0, nil, err
}
method, err = decodeI16[uint16](headerBuffer[:2])
method, err = tape.DecodeI16[uint16](headerBuffer[:2])
if err != nil { return 0, 0, nil, err }
length, err := decodeI64[uint64](headerBuffer[2:10])
length, err := tape.DecodeI64[uint64](headerBuffer[2:10])
if err != nil { return 0, 0, nil, err }
if length > uint64(sizeLimit) {
return 0, 0, nil, ErrPayloadTooLarge

View File

@ -14,14 +14,6 @@ import "reflect"
var dummyMap map[uint16] any
var dummyBuffer []byte
type errCantAssign string
func (err errCantAssign) Error() string {
return string(err)
}
func errCantAssignf(format string, v ...any) errCantAssign {
return errCantAssign(fmt.Sprintf(format, v...))
}
// EncodeAny encodes an "any" value. Returns an error if the underlying type is
// unsupported. Supported types are:
//
@ -46,16 +38,9 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
case reflect.Uint32: return encoder.WriteUint32(uint32(reflectValue.Uint()))
case reflect.Int64: return encoder.WriteInt64(int64(reflectValue.Int()))
case reflect.Uint64: return encoder.WriteUint64(uint64(reflectValue.Uint()))
case reflect.String:
if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong
}
return EncodeAny(encoder, []byte(reflectValue.String()), tag)
case reflect.String: return EncodeAny(encoder, []byte(reflectValue.String()), tag)
}
if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) {
if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong
}
if tag.Is(LBA) {
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
n += nn; if err != nil { return n, err }
@ -71,13 +56,8 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
case reflect.Slice:
return encodeAnySlice(encoder, value, tag)
// case reflect.Array:
// TODO: we can encode arrays. but can we decode into them?
// that's the fucken question. maybe we just do the first
// return encodeAnySlice(encoder, reflect.ValueOf(value).Slice(0, reflectType.Len()).Interface(), tag)
case reflect.Map:
if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong
}
if reflectType.Key() == reflect.TypeOf(uint16(0)) {
return encodeAnyMap(encoder, value, tag)
}
@ -103,34 +83,20 @@ type unknownSlicePlaceholder struct { }
var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { })
// decodeAny is internal to [DecodeAny]. It takes in an addressable
// [reflect.Value] as the destination. If the decoded value cannot fit in the
// destination, it skims over the payload, leaves the destination empty, and
// returns without an error.
// [reflect.Value] as the destination.
func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
n, err = decodeAnyOrError(decoder, destination, tag)
if _, ok := err.(errCantAssign); ok {
if n > 0 { panic(fmt.Sprintf("decodeAnyOrError decoded more than it should: %d", n)) }
nn, err := Skim(decoder, tag)
n += nn; if err != nil { return n, err }
return n, nil
errWrongDestinationType := func(expected string) error {
panic(fmt.Errorf(
// return fmt.Errorf(
"expected %s destination, not %v",
expected, destination))
}
return n, err
}
// decodeAnyOrError is internal to [decodeAny]. It takes in an addressable
// [reflect.Value] as the destination. If the decoded value cannot fit in the
// destination, it decodes nothing and returns an error of type errCantAssign,
// except for the case of a mismatched OTA element tag, wherein it will skim
// over the rest of the payload, leave the destination empty, and return without
// an error.
func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
err = canSet(destination.Type(), tag)
if err != nil { return n, err }
switch tag.WithoutCN() {
case SI:
// SI: (none)
setInt(destination, uint64(tag.CN()))
err = setInt(destination, uint64(tag.CN()))
if err != nil { return n, err }
case LI:
// LI: <value: IntN>
nn, err := decodeAndSetUint(decoder, destination, tag.CN() + 1)
@ -145,68 +111,46 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
n += nn; if err != nil { return n, err }
case SBA:
// SBA: <data: U8>*
length := tag.CN()
if length > MaxStructureLength {
return 0, ErrTooLong
}
buffer := make([]byte, length)
buffer := make([]byte, tag.CN())
nn, err := decoder.Read(buffer)
n += nn; if err != nil { return n, err }
setByteArray(destination, buffer)
err = setByteArray(destination, buffer)
if err != nil { return n, err }
case LBA:
// LBA: <length: UN> <data: U8>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) {
return 0, ErrTooLong
}
buffer := make([]byte, length)
nn, err = decoder.Read(buffer)
n += nn; if err != nil { return n, err }
setByteArray(destination, buffer)
err = setByteArray(destination, buffer)
if err != nil { return n, err }
case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) {
return 0, ErrTooLong
}
lengthCast, err := Uint64ToIntSafe(length)
if err != nil { return n, err }
oneTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
if destination.Cap() < lengthCast {
destination.Grow(lengthCast - destination.Cap())
if destination.Kind() != reflect.Slice {
return n, errWrongDestinationType("slice")
}
// skip the rest of the array if the one tag doesn't
// match up with the destination
err = canSet(destination.Type().Elem(), oneTag)
if _, ok := err.(errCantAssign); ok {
for _ = range length {
nn, err := Skim(decoder, oneTag)
n += nn; if err != nil { return n, err }
}
break
if destination.Cap() < int(length) {
destination.Grow(int(length) - destination.Cap())
}
if err != nil { return n, err }
destination.SetLen(lengthCast)
destination.SetLen(int(length))
for index := range length {
nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag)
n += nn
if _, ok := err.(errCantAssign); ok {
continue
} else if err != nil {
return n, err
}
n += nn; if err != nil { return n, err }
}
case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
table := destination
if table.Type() != reflect.TypeOf(dummyMap) {
return n, errWrongDestinationType("map[uint16] any")
}
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) {
return 0, ErrTooLong
}
destination.Clear()
table.Clear()
for _ = range length {
key, nn, err := decoder.ReadUint16()
n += nn; if err != nil { return n, err }
@ -216,7 +160,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
if err != nil { return n, err }
nn, err = decodeAny(decoder, value.Elem(), itemTag)
n += nn; if err != nil { return n, err }
destination.SetMapIndex(reflect.ValueOf(key), value.Elem())
table.SetMapIndex(reflect.ValueOf(key), value.Elem())
}
default:
return n, fmt.Errorf("unknown TN %d", tag.TN())
@ -228,40 +172,35 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
// underlying type is unsupported. See [EncodeAny] for a list of supported
// types.
func TagAny(value any) (Tag, error) {
return tagAny(reflect.ValueOf(value))
}
func tagAny(reflectValue reflect.Value) (Tag, error) {
// TODO use reflection for all of this to ignore type names
// primitives
switch reflectValue.Kind() {
case reflect.Int: return LSI.WithCN(3), nil
case reflect.Int8: return LSI.WithCN(0), nil
case reflect.Int16: return LSI.WithCN(1), nil
case reflect.Int32: return LSI.WithCN(3), nil
case reflect.Int64: return LSI.WithCN(7), nil
case reflect.Uint: return LI.WithCN(3), nil
case reflect.Uint8: return LI.WithCN(0), nil
case reflect.Uint16: return LI.WithCN(1), nil
case reflect.Uint32: return LI.WithCN(3), nil
case reflect.Uint64: return LI.WithCN(7), nil
case reflect.String: return bufferLenTag(reflectValue.Len()), nil
}
if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) {
return bufferLenTag(reflectValue.Len()), nil
switch value := value.(type) {
case int: return LSI.WithCN(3), nil
case int8: return LSI.WithCN(0), nil
case int16: return LSI.WithCN(1), nil
case int32: return LSI.WithCN(3), nil
case int64: return LSI.WithCN(7), nil
case uint: return LI.WithCN(3), nil
case uint8: return LI.WithCN(0), nil
case uint16: return LI.WithCN(1), nil
case uint32: return LI.WithCN(3), nil
case uint64: return LI.WithCN(7), nil
case string: return bufferLenTag(len(value)), nil
case []byte: return bufferLenTag(len(value)), nil
}
// aggregates
reflectType := reflectValue.Type()
reflectType := reflect.TypeOf(value)
switch reflectType.Kind() {
case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil
case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil
case reflect.Array: return OTA.WithCN(reflectType.Len()), nil
case reflect.Map:
if reflectType.Key() == reflect.TypeOf(uint16(0)) {
return KTV.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil
return KTV.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil
}
return 0, fmt.Errorf("cannot encode map key %v, key must be uint16", reflectType.Key())
return 0, fmt.Errorf("cannot encode map key %T, key must be uint16", value)
}
return 0, fmt.Errorf("cannot get tag of type %v", reflectType)
return 0, fmt.Errorf("cannot get tag of type %T", value)
}
func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) {
@ -270,10 +209,11 @@ func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) {
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
n += nn; if err != nil { return n, err }
reflectType := reflect.TypeOf(value)
oneTag, err := tagAny(reflect.Zero(reflectType.Elem()))
oneTag, err := TagAny(reflect.Zero(reflectType.Elem()).Interface())
if err != nil { return n, err }
for index := 0; index < reflectValue.Len(); index += 1 {
itemTag, err := tagAny(reflectValue.Index(index))
item := reflectValue.Index(index).Interface()
itemTag, err := TagAny(item)
if err != nil { return n, err }
if itemTag.CN() > oneTag.CN() { oneTag = itemTag }
}
@ -295,12 +235,11 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
n += nn; if err != nil { return n, err }
iter := reflectValue.MapRange()
for iter.Next() {
reflectValue := iter.Value().Elem()
key := iter.Key().Interface().(uint16)
value := reflectValue.Interface()
key := iter.Key().Interface().(uint16)
value := iter.Value().Interface()
nn, err = encoder.WriteUint16(key)
n += nn; if err != nil { return n, err }
itemTag, err := tagAny(reflectValue)
itemTag, err := TagAny(value)
if err != nil { return n, err }
nn, err = encoder.WriteUint8(uint8(itemTag))
n += nn; if err != nil { return n, err }
@ -310,79 +249,53 @@ func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
return n, nil
}
func canSet(destination reflect.Type, tag Tag) error {
switch tag.WithoutCN() {
case SI, LI, LSI:
switch destination.Kind() {
case
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
default:
return errCantAssignf("cannot assign integer to %v", destination)
}
case FP:
switch destination.Kind() {
case reflect.Float32, reflect.Float64:
default:
return errCantAssignf("cannot assign float to %v", destination)
}
case SBA, LBA:
if destination.Kind() != reflect.Slice {
return errCantAssignf("cannot assign byte array to %v", destination)
}
if destination.Elem() != reflect.TypeOf(byte(0)) {
return errCantAssignf("cannot convert %v to *[]byte", destination)
}
case OTA:
if destination.Kind() != reflect.Slice {
return errCantAssignf("cannot assign array to %v", destination)
}
case KTV:
if destination != reflect.TypeOf(dummyMap) {
return errCantAssignf("cannot assign table to %v", destination)
}
default:
return fmt.Errorf("unknown TN %d", tag.TN())
}
return nil
}
// setInt expects a settable destination.
func setInt[T int64 | uint64](destination reflect.Value, value T) {
func setInt[T int64 | uint64](destination reflect.Value, value T) error {
switch {
case destination.CanInt():
destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type()))
case destination.CanUint():
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
default:
panic("setInt called on an unsupported type")
return fmt.Errorf("cannot assign integer to %T", destination.Interface())
}
return nil
}
// setFloat expects a settable destination.
func setFloat(destination reflect.Value, value float64) {
func setFloat(destination reflect.Value, value float64) error {
if !destination.CanFloat() {
return fmt.Errorf("cannot assign float to %T", destination.Interface())
}
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
return nil
}
// setByteArrayexpects a settable destination.
func setByteArray(destination reflect.Value, value []byte) {
func setByteArray(destination reflect.Value, value []byte) error {
typ := destination.Type()
if typ.Kind() != reflect.Slice {
return fmt.Errorf("cannot assign %T to ", value)
}
if typ.Elem() != reflect.TypeOf(byte(0)) {
return fmt.Errorf("cannot convert %T to *[]byte", value)
}
destination.Set(reflect.ValueOf(value))
return nil
}
// decodeAndSetInt expects a settable destination.
func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadIntN(bytes)
n += nn; if err != nil { return n, err }
setInt(destination, value)
return n, nil
return n, setInt(destination, value)
}
// decodeAndSetUint expects a settable destination.
func decodeAndSetUint(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadUintN(bytes)
n += nn; if err != nil { return n, err }
setInt(destination, value)
return n, nil
return n, setInt(destination, value)
}
// decodeAndSetInt expects a settable destination.
@ -391,15 +304,13 @@ func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (
case 8:
value, nn, err := decoder.ReadFloat64()
n += nn; if err != nil { return n, err }
setFloat(destination, float64(value))
return n, nil
return n, setFloat(destination, float64(value))
case 4:
value, nn, err := decoder.ReadFloat32()
n += nn; if err != nil { return n, err }
setFloat(destination, float64(value))
return n, nil
return n, setFloat(destination, float64(value))
}
return n, errCantAssignf("unsupported bit width float%d", bytes * 8)
return n, fmt.Errorf("cannot decode float%d", bytes * 8)
}
// skeletonValue returns a pointer value. In order for it to be set, it must be

View File

@ -6,8 +6,6 @@ import "testing"
import "reflect"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
type userDefinedInteger int16
func TestEncodeAnyInt(test *testing.T) {
err := testEncodeAny(test, uint8(0xCA), LI.WithCN(0), tu.S(0xCA))
if err != nil { test.Fatal(err) }
@ -25,8 +23,7 @@ func TestEncodeAnyTable(test *testing.T) {
0xFFFF: []uint16 { 0xBEE5, 0x7777 },
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
0x2345: [][]int16 { []int16 { 0x5 }, []int16 { 0x17, -0xAAA } },
0x3456: userDefinedInteger(0x3921),
}, KTV.WithCN(0), tu.S(7).AddVar(
}, KTV.WithCN(0), tu.S(6).AddVar(
[]byte {
0xF3, 0xB9,
byte(LSI.WithCN(3)),
@ -65,110 +62,10 @@ func TestEncodeAnyTable(test *testing.T) {
0, 0x17,
0xF5, 0x56,
},
[]byte {
0x34, 0x56,
byte(LSI.WithCN(1)),
0x39, 0x21,
},
))
if err != nil { test.Fatal(err) }
}
func TestDecodeWrongType(test *testing.T) {
datas := [][]byte {
/* int8 */ []byte { byte(LSI.WithCN(0)), 0x45 },
/* int16 */ []byte { byte(LSI.WithCN(1)), 0x45, 0x67 },
/* int32 */ []byte { byte(LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* int64 */ []byte { byte(LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* uint5 */ []byte { byte(SI.WithCN(12)) },
/* uint8 */ []byte { byte(LI.WithCN(0)), 0x45 },
/* uint16 */ []byte { byte(LI.WithCN(1)), 0x45, 0x67 },
/* uint32 */ []byte { byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB },
/* uint64 */ []byte { byte(LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 },
/* string */ []byte { byte(SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' },
/* []byte */ []byte { byte(SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' },
/* []string */ []byte {
byte(OTA.WithCN(0)), 2, byte(LBA.WithCN(0)),
0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23,
0x05, 0x11, 0x11, 0x11, 0x11, 0x11,
},
/* map[uint16] any */ []byte {
byte(KTV.WithCN(0)), 2,
0x02, 0x23, byte(LSI.WithCN(1)), 0x45, 0x67,
0x02, 0x23, byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB,
},
}
for index, data := range datas {
test.Logf("data %2d %v [%s]", index, Tag(data[0]), tu.HexBytes(data[1:]))
// integers should only assign to other integers
if index > 8 {
cas := func(destination any) {
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0]))
if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
reflectValue := reflect.ValueOf(destination).Elem()
if reflectValue.CanInt() {
if reflectValue.Int() != 0 {
test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface())
}
} else {
if reflectValue.Uint() != 0 {
test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface())
}
}
if n != len(data) - 1 {
test.Fatalf("n not equal: %d != %d", n, len(data) - 1)
}
}
test.Log("- int8")
{ var dest int8; cas(&dest) }
test.Log("- int16")
{ var dest int16; cas(&dest) }
test.Log("- int32")
{ var dest int32; cas(&dest) }
test.Log("- int64")
{ var dest int64; cas(&dest) }
test.Log("- uint8")
{ var dest uint8; cas(&dest) }
test.Log("- uint16")
{ var dest uint16; cas(&dest) }
test.Log("- uint32")
{ var dest uint32; cas(&dest) }
test.Log("- uint64")
{ var dest uint64; cas(&dest) }
}
arrayCase := func(destination any) {
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0]))
if err != nil { test.Fatalf("error: %v | n: %d", err, n) }
reflectDestination := reflect.ValueOf(destination)
reflectValue := reflectDestination.Elem()
if reflectValue.Len() != 0 {
test.Fatalf("len(destination) not zero: %v", reflectValue.Interface())
}
if n != len(data) - 1 {
test.Fatalf("n not equal: %d != %d", n, len(data) - 1)
}
}
// SBA/LBA types should only assign to other SBA/LBA types
if index != 9 && index != 10 {
test.Log("- string")
{ var dest string; arrayCase(&dest) }
test.Log("- []byte")
{ var dest []byte; arrayCase(&dest) }
}
// arrays should only assign to other arrays
if index != 11 {
test.Log("- []string")
{ var dest []string; arrayCase(&dest) }
}
// tables should only assign to other tables
if index != 12 {
test.Log("- map[uint16] any")
{ var dest = map[uint16] any { }; arrayCase(&dest) }
}
}
}
func TestEncodeDecodeAnyTable(test *testing.T) {
err := testEncodeDecodeAny(test, map[uint16] any {
0xF3B9: uint32(1),
@ -269,10 +166,10 @@ func testEncodeAny(test *testing.T, value any, correctTag Tag, correctBytes tu.S
test.Log("got: ", tu.HexBytes(bytes))
test.Log("correct:", correctBytes)
if tag != correctTag {
return fmt.Errorf("tag not equal: %v != %v", tag, correctTag)
return fmt.Errorf("tag not equal")
}
if ok, n := correctBytes.Check(bytes); !ok {
return fmt.Errorf("bytes not equal at index %d", n)
return fmt.Errorf("bytes not equal: %d", n)
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))

View File

@ -1,12 +0,0 @@
package tape
// Error enumerates common errors in this package.
type Error string; const (
ErrTooLong Error = "data structure too long"
ErrTooLarge Error = "number too large"
)
// Error implements the error interface.
func (err Error) Error() string {
return string(err)
}

View File

@ -1,26 +0,0 @@
package tape
// MaxStructureLength determines how long a TAPE data structure can be. This
// applies to:
//
// - OTA
// - SBA/LBA
// - KTV
//
// By default it is set at 2^20 (about a million).
// You shouldn't need to change this. If you do, it should only be set once at
// the start of the program.
var MaxStructureLength = 1024 * 1024
// MaxInt is the maximum value an int can hold. This varies depending on the
// system.
const MaxInt int = int(^uint(0) >> 1)
// Uint64ToIntSafe casts the input to an int if it can be done without overflow,
// or returns an error otherwise.
func Uint64ToIntSafe(input uint64) (int, error) {
if input > uint64(MaxInt) {
return 0, ErrTooLarge
}
return int(input), nil
}

View File

@ -1,54 +0,0 @@
package tape
import "fmt"
// Skim uses up data from a decoder to "skim" over one value (and all else
// contained within it) without actually putting the data anywhere.
func Skim(decoder *Decoder, tag Tag) (n int, err error) {
switch tag.WithoutCN() {
case SI:
// SI: (none)
return n, nil
case LI, LSI, FP:
// LI: <value: IntN>
// LSI: <value: IntN>
// FP: <value: FloatN>
nn, err := decoder.Discard(tag.CN() + 1)
n += nn; if err != nil { return n, err }
case SBA:
// SBA: <data: U8>*
nn, err := decoder.Discard(tag.CN())
n += nn; if err != nil { return n, err }
case LBA:
// LBA: <length: UN> <data: U8>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
nn, err = decoder.Discard(int(length))
n += nn; if err != nil { return n, err }
case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
oneTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
for _ = range length {
nn, err := Skim(decoder, oneTag)
n += nn; if err != nil { return n, err }
}
case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
for _ = range length {
nn, err := decoder.Discard(2)
n += nn; if err != nil { return n, err }
itemTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
nn, err = Skim(decoder, itemTag)
n += nn; if err != nil { return n, err }
}
default:
return n, fmt.Errorf("unknown TN %d", tag.TN())
}
return n, nil
}

View File

@ -1,137 +0,0 @@
package tape
import "bytes"
import "testing"
func TestSkimInteger(test *testing.T) {
data := []byte {
0x12, 0x45, 0x23, 0xF9,
}
mainDataLen := len(data)
// extra junk
data = append(data, 0x00, 0x01, 0x02, 0x03,)
n, err := Skim(NewDecoder(bytes.NewBuffer(data)), LI.WithCN(3))
if err != nil {
test.Fatal(err)
}
if got, correct := n, mainDataLen; got != correct {
test.Fatalf("n not equal: %d != %d", got, correct)
}
}
func TestSkimArray(test *testing.T) {
data := []byte {
2, byte(LI.WithCN(1)),
0xBE, 0xE5, 0x77, 0x77,
}
mainDataLen := len(data)
// extra junk
data = append(data, 0x00, 0x01, 0x02, 0x03,)
n, err := Skim(NewDecoder(bytes.NewBuffer(data)), OTA.WithCN(0))
if err != nil {
test.Fatal(err)
}
if got, correct := n, mainDataLen; got != correct {
test.Fatalf("n not equal: %d != %d", got, correct)
}
}
func TestSkimNestedArray(test *testing.T) {
data := []byte {
2, byte(OTA.WithCN(0)),
1, byte(LSI.WithCN(1)),
0, 0x5,
2, byte(LSI.WithCN(1)),
0, 0x17,
0xF5, 0x56,
}
mainDataLen := len(data)
// extra junk
data = append(data, 0x00, 0x01, 0x02, 0x03,)
n, err := Skim(NewDecoder(bytes.NewBuffer(data)), OTA.WithCN(0))
if err != nil {
test.Fatal(err)
}
if got, correct := n, mainDataLen; got != correct {
test.Fatalf("n not equal: %d != %d", got, correct)
}
}
func TestSkimTable(test *testing.T) {
data := []byte {
2,
0xF3, 0xB9,
byte(LSI.WithCN(3)),
0, 0, 0, 1,
0x01, 0x02,
byte(LSI.WithCN(3)),
0, 0, 0, 2,
}
mainDataLen := len(data)
// extra junk
data = append(data, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03)
n, err := Skim(NewDecoder(bytes.NewBuffer(data)), KTV.WithCN(0))
if got, correct := n, mainDataLen; got != correct {
test.Fatalf("n not equal: %d != %d ... (%d)", got, correct, len(data))
}
if err != nil {
test.Fatal(err)
}
}
func TestSkimTableComplex(test *testing.T) {
data := []byte {
7,
0xF3, 0xB9,
byte(LSI.WithCN(3)),
0, 0, 0, 1,
0x01, 0x02,
byte(LSI.WithCN(3)),
0, 0, 0, 2,
0, 0,
byte(SBA.WithCN(3)),
'h', 'i', '!',
0xFF, 0xFF,
byte(OTA.WithCN(0)), 2, byte(LI.WithCN(1)),
0xBE, 0xE5, 0x77, 0x77,
0x12, 0x34,
byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)),
1, byte(LI.WithCN(1)),
0, 0x5,
2, byte(LI.WithCN(1)),
0, 0x17,
0xAA, 0xAA,
0x23, 0x45,
byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)),
1, byte(LSI.WithCN(1)),
0, 0x5,
2, byte(LSI.WithCN(1)),
0, 0x17,
0xF5, 0x56,
0x34, 0x56,
byte(LSI.WithCN(1)),
0x39, 0x21,
}
mainDataLen := len(data)
// extra junk
data = append(data, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03)
n, err := Skim(NewDecoder(bytes.NewBuffer(data)), KTV.WithCN(0))
if got, correct := n, mainDataLen; got != correct {
test.Fatalf("n not equal: %d != %d ... (%d)", got, correct, len(data))
}
if err != nil {
test.Fatal(err)
}
}