diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6f9c924 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/generate/test diff --git a/cmd/hopp-generate/main.go b/cmd/hopp-generate/main.go index 3e38b0d..78ecb56 100644 --- a/cmd/hopp-generate/main.go +++ b/cmd/hopp-generate/main.go @@ -4,6 +4,7 @@ import "os" import "fmt" import "strings" import "path/filepath" +import "git.tebibyte.media/sashakoshka/goparse" import "git.tebibyte.media/sashakoshka/hopp/generate" func main() { @@ -18,7 +19,7 @@ func main() { input, err := os.Open(source) handleErr(1, err) defer input.Close() - protocol, err := generate.ParseReader(input) + protocol, err := generate.ParseReader(source, input) handleErr(1, err) absDestination, err := filepath.Abs(destination) @@ -30,14 +31,18 @@ func main() { output, err := os.Create(destination) handleErr(1, err) - err = protocol.Generate(output, packageName) + generator := generate.Generator { + Output: output, + PackageName: packageName, + } + _, err = generator.Generate(protocol) handleErr(1, err) fmt.Fprintf(os.Stderr, "%s: OK\n", name) } func handleErr(code int, err error) { if err != nil { - fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], err) + fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], parse.Format(err)) os.Exit(code) } } diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..08426e1 --- /dev/null +++ b/codec.go @@ -0,0 +1,47 @@ +package hopp + +import "fmt" + +type anyInt16 interface { ~uint16 | ~int16 } +type anyInt64 interface { ~uint64 | ~int64 } + +// decodeI16 decodes a 16 bit integer from the given data. +func decodeI16[T anyInt16](data []byte) (T, error) { + if len(data) != 2 { return 0, fmt.Errorf("decoding int16: %w", ErrWrongBufferLength) } + return T(data[0]) << 8 | T(data[1]), nil +} + +// encodeI16 encodes a 16 bit integer into the given buffer. +func encodeI16[T anyInt16](buffer []byte, value T) error { + if len(buffer) != 2 { return fmt.Errorf("encoding int16: %w", ErrWrongBufferLength) } + buffer[0] = byte(value >> 8) + buffer[1] = byte(value) + return nil +} + +// decodeI64 decodes a 64 bit integer from the given data. +func decodeI64[T anyInt64](data []byte) (T, error) { + if len(data) != 8 { return 0, fmt.Errorf("decoding int64: %w", ErrWrongBufferLength) } + return T(data[0]) << 56 | + T(data[1]) << 48 | + T(data[2]) << 40 | + T(data[3]) << 32 | + T(data[4]) << 24 | + T(data[5]) << 16 | + T(data[6]) << 8 | + T(data[7]), nil +} + +// encodeI64 encodes a 64 bit integer into the given buffer. +func encodeI64[T anyInt64](buffer []byte, value T) error { + if len(buffer) != 8 { return fmt.Errorf("encoding int64: %w", ErrWrongBufferLength) } + buffer[0] = byte(value >> 56) + buffer[1] = byte(value >> 48) + buffer[2] = byte(value >> 40) + buffer[3] = byte(value >> 32) + buffer[4] = byte(value >> 24) + buffer[5] = byte(value >> 16) + buffer[6] = byte(value >> 8) + buffer[7] = byte(value) + return nil +} diff --git a/connection.go b/connection.go index adefbdf..07eb829 100644 --- a/connection.go +++ b/connection.go @@ -1,7 +1,10 @@ package hopp +import "io" import "net" -// import "time" +import "time" + +const defaultSizeLimit int64 = 1024 * 1024 // 1 megabyte // Conn is a HOPP connection. type Conn interface { @@ -19,22 +22,50 @@ type Conn interface { // AcceptTrans accepts a transaction from the other party. This must // be called in a loop to avoid the connection locking up. AcceptTrans() (Trans, error) + + // SetDeadline operates is [net.Conn.SetDeadline] but for OpenTrans + // and AcceptTrans calls. + SetDeadline(t time.Time) error + // SetSizeLimit sets a limit (in bytes) for how large messages can be. + // By default, this limit is 1 megabyte. Note that this is only + // enforced when sending and receiving byte slices, and it does not + // apply to [Trans.SendWriter] or [Trans.ReceiveReader]. + SetSizeLimit(limit int64) } -// Trans is a HOPP transaction. +// Trans is a HOPP transaction. Methods of this interface are not safe for +// concurrent use with the exception of the Close and ID methods. The +// recommended use case is one goroutine per transaction. type Trans interface { // Close closes the transaction. Any blocked operations will be - // unblocked and return errors. + // unblocked and return errors. This method is safe for concurrent use. Close() error // ID returns the transaction ID. This must not change, and it must be - // unique within the connection. + // unique within the connection. This method is safe for concurrent use. ID() int64 - - // TODO: add methods for setting send and receive deadlines - // Send sends a message. + // Send sends a message. This method is not safe for concurrent use. Send(method uint16, data []byte) error - // Receive receives a message. + // SendWriter sends data written to an [io.Writer]. The writer must be + // closed after use. Closing the writer flushes any data that hasn't + // been written yet. Any writer previously opened through this function + // will be discarded. This method is not safe for concurrent use, and + // neither is its result. + SendWriter(method uint16) (io.WriteCloser, error) + // Receive receives a message. This method is not safe for concurrent + // use. Receive() (method uint16, data []byte, err error) + // ReceiveReader receives a message as an [io.Reader]. Any reader + // previously opened through this function will be discarded. This + // method is not safe for concurrent use, and neither is its result. + ReceiveReader() (method uint16, data io.Reader, err error) + + // See the documentation for [net.Conn.SetDeadline]. + SetDeadline(time.Time) error + // TODO + // // See the documentation for [net.Conn.SetReadDeadline]. + // SetReadDeadline(t time.Time) error + // // See the documentation for [net.Conn.SetWriteDeadline]. + // SetWriteDeadline(t time.Time) error } diff --git a/design/branched-generated-encoder.md b/design/branched-generated-encoder.md new file mode 100644 index 0000000..5bacc24 --- /dev/null +++ b/design/branched-generated-encoder.md @@ -0,0 +1,128 @@ +# Branched Generated Decoder + +Pasted here because Tebitea is down + +## The problem + +TAPE is designed so that the decoder can gloss over data it does not understand. +Technically the protocol allows for this, but I completely forgot to implement +this in the generated decoder, oops. This would be trivial if TAPE messages were +still flat tables, but they aren't, because those aren't useful enough. So, +let's analyze the problem. + +## When it happens + +There are two reasons something might not match up with the expected data: + +The first and most obvious is unrecognized keys. If the key is not in the set of +recognized keys for a KTV, it should leave the corresponding struct field blank. +Once #6 has been implemented, throw an error if the data was not optional. + +The second is wrong types. If we are expecting KTV and get SBA, we should leave +the data as empty. The aforementioned concern about #6 also applies here. We +don't need to worry about special cases at the structure root, because it would +be technically possible to make the structure root an option, so it really is +just a normal value. Until #6, we will leave that blank too. + +## Preliminary ideas + +The first is going to be pretty simple. All we need to do is have a skimmer +function that skims over TAPE data very, and then call that on the KTV value +each time we run into a mystery key. It should only return an error if the +structure of the data is malformed in such a way that it cannot continue to the +next one. This should be stored in the tape package alongside the dynamic +decoding functions, because they will essentially function the same way and +could probably share lots of code. + +The second is a bit more complicated because of the existence of KTV and OTA +because they are aggregate types. Go types work a bit differently, as if you +have an array of an array of an array of ints, that information is represented +in one place, whereas TAPE doesn't really do that. All of that information is +sort of buried within the data structure, so we don't know what we will be +decoding before we actually do it. Whenever we encounter a type we don't expect, +we would need to abort decoding of the entire data structure, and then skim over +whatever detritus is left, which would literally be in a half-decoded state. The +fact that the code is generated flat and thus cannot use return or defer +statements contributes to the complexity of this problem. We need to go up, but +we can't. There is no up, only forward. + +Of course, the dynamic decoder does not have this problem in the first place +because it doesn't expect anything, and constructs the destination to fit +whatever it sees in the TAPE structure as it is decoding it. KTVs are completely +dynamic because they are implemented as maps, so the only time it needs to +completely comprehend a type is with OTAs. There is a function called typeOf +that gets the type of the current tag and returns it as a reflect.Type, which +necessitates recursion and peeking at OTAs and their elements. + +We could try to do the same thing in the generated decoder, comparing the +determined type against the expected type to try to figure out whether we should +decode an array or a table, etc. This is immediately problematic as it requires +memory to be allocated, both for the peek buffer and the resulting tree of type +information. If we end up with some crazy way to keep track of the types, that's +only one half of the allocation problem and we would still be spending extra +cycles going over all of that twice. + +## Performance constraints + +The generated decoder is supposed to blaze through data, and it can't do that if +it does all the singing and dancing that the dynamic decoder does. It's time for +some performance constraints: + +- No allocations, except as required to build the destination for the data +- No redundant work +- So, no freaking peeking +- It should take well under 500 lines of generated code to decode one message of +reasonable size (i.e. be careful not to bloat the binary) + +I'm not really going to do my usual thing here of making a slow version and +speeding it up over time based on evidence and experimentation because these +constraints inform the design so much it would be impossible to continue without +them. I am 99% confident that these constraints will allow for an acceptable +baseline of performance (for generated code) and we can still profile and +micro-optimize later. This is good enough for me. +Heavy solution + +There is a solution that might work very well which involves completely redoing +the generated decoding code. We could create a function for every source type to +destination type mapping that exists in protocol, and then compose them all +together. The decoding methods for each message or type would be wrappers around +the correct function for their root TAPE -> Go type mapping. The main benefit of +this is it would make this problem a lot more manageable because the interface +points between the data would be represented by function boundaries. This would +allow the use of return and defer statements, and would allow more code sharing, +producing a smaller binary. Go would probably inline these where needed. + +Would this work? Probably. More investigation is required to make sure. I want +to stop re-writing things I don't need to. On the other hand, it is just the +decoder. + +## Light solution + +TODO: find a solution that satisfies the performance constraints, keeps the same +identical interface, and works off the same code. I am convinced this is doable, +and it might even allow us to extract more data from an unexpected structure. +However, continuing this way might introduce unmanageable complexity. It is +already a little unmanageable and I am just one pony (kind of). + +## Implementation + +Heavy solution is going to work here, applied to only the points of +`Generator.generateDecodeValue` where it decodes an aggregate data structure. +That way, only minimal amounts of code need to be redone. + +Whenever a branch needs to happen, a call shall be generated, a deferred +implementation request shall be added to a special FIFO queue within the +generator. After generating data structures and their root decoding functions, +the generator shall pick away at this queue until no requests remain. The +generator shall accept new items during this process, so that recursion is +possible. This is all to ensure it is only ever writing one function at a time + +The functions shall take a pointer to a type that accepts any type like (~) the +destination's base type. We should also probably just call +`Generator.generateDecodeValue` directly on user defined types this way, keeping +their public `Decode` methods just for convenience. + +The tape package shall contain a skimming function that takes a decoder and a +tag, and recursively consumes the decoder given the context of the tag. This +shall be utilized by the decoder functions to skip over values if their tags +or keys do not match up with what is expected. diff --git a/design/pdl-compiler.md b/design/pdl-compiler.md new file mode 100644 index 0000000..b87e3f3 --- /dev/null +++ b/design/pdl-compiler.md @@ -0,0 +1,109 @@ +# PDL Compiler Specification + +Given one or more PDL files representing a protocol, the compiler shall generate +a Go package named "protocol", which shall contain types for message and type +definitions, as well as encoding and decoding methods. + +## Static Section + +The compiler shall write a static section alongside the generated code. It +shall contain this text: + +```go +// Table is a KTV table with an undefined schema. +type Table map[uint16] any + +// Message is any message that can be sent along this protocol. +type Message interface { + codec.Encodable + codec.Decodable + + // Method returns the method code of the message. + Method() uint16 +} +``` + +## Preamble + +At the start of each file but after the package name, the compiler shall emit +this text: + +```go +/* # Do not edit this package by hand! + * + * This file was automatically generated by the Holanet PDL compiler. The + * source file is located at + * Please edit that file instead, and re-compile it to this location. + * + * HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz + */ +``` + +Where `` is the path of the protocol definition file relative to the +generated file. + +## Message Definitions + +For each defined message, the compiler shall generate a Go type named +`MessageName`, where `Name` is the name of the message as written in its +definition. The message shall be encodable, and shall have `Encode` and `Decode` +methods as described below. + +All messages shall satisfy a `Message` interface, which is defined in the +static section. + +## Type Definitions + +For each defined type, the compiler shall generate a Go type with the same name +as written in its definition. The Go type shall be encodable, and shall have +`EncodeValue`, `DecodeValue`, and `Tag` methods as described below. + +## Encoding and Decoding Methods + +Each message shall be given an `Encode` method and a `Decode` method, +which shall take in a `codec.Encoder` and a `codec.Decoder` respectively. Both +return an `(int, error)` pair describing the amount of bytes written and an +error if the write stopped early. `Encode` shall encode the data within the +message to the given encoder, and `Decode` shall decode data from the given +decoder and place it in the type's value. The methods shall not retain or close +any encoders or decoders they are given. Both methods shall have pointer +receivers. In effect, these methods shall satisfy `codec.Encodable` and +`codec.Decodable`. + +Each defined type shall be given an `EncodeValue` method and a `DecodeValue` +method, which shall both take in a `tape.Tag`, then a `codec.Encoder` and a +`codec.Decoder` respectively. These methods shall encode and decode the value +according to the CN given by the tag. The TN shall be ignored. The message shall +also have a method `Tag` that takes no arguments and returns the preferred tag +of the type including the TN and CN. + +## Connection + +The compiler shall generate a `Conn` struct which embeds a `hopp.Conn`, which +is the real "porcelain" of the generated code. It shall provide methods to +create and accept transactions. Each transaction shall be a struct which embeds +a `hopp.Trans`, and shall have methods for sending and receiving messages. + +### Sending + +To send a message along a transaction, the program shall: + +1. Obtain the method code from the message +3. Obtain a writer from the connection using the method code +4. Wrap the writer in a `codec.Encoder` +5. Use the encoder to encode the message +6. Close the writer + +### Receiving + +To receiving a message from a transaction, the program shall: + +1. Obtain a method code and reader from the connection +2. Wrap the reader in a `codec.Decoder` +3. Switch on the method code, and decode the correct message using the decoder +4. Return the message to the caller as a value + +The recieve function must return the message as a value instead of a pointer in +order to avoid making an allocation. Because of this, the return value must be +`any` instead of `Message`. The caller must then use a type switch to figure out +what message was sent. diff --git a/design/pdl-language.md b/design/pdl-language.md new file mode 100644 index 0000000..27195bd --- /dev/null +++ b/design/pdl-language.md @@ -0,0 +1,104 @@ +# PDL Language Definition + +PDL allows defining a protocol using HOPP and TAPE. + +## Data Types + +| Syntax | TN | CN | Description +| ---------- | ------- | -: | ----------- +| I5 | SI | | +| I8 | LSI | 0 | +| I16 | LSI | 1 | +| I32 | LSI | 3 | +| I64 | LSI | 7 | +| I128[^2] | LSI | 15 | +| I256[^2] | LSI | 31 | +| U5 | SI | | +| U8 | LI | 0 | +| U16 | LI | 1 | +| U32 | LI | 3 | +| U64 | LI | 7 | +| U128[^2] | LI | 15 | +| U256[^2] | LI | 31 | +| F16 | FP | 1 | +| F32 | FP | 3 | +| F64 | FP | 7 | +| F128[^2] | FP | 15 | +| F256[^2] | FP | 31 | +| String | SBA/LBA | * | UTF-8 string +| Buffer | SBA/LBA | * | Byte array +| []\ | OTA | * | Array of any type[^1] +| Table | KTV | * | Table with undefined schema +| {...} | KTV | * | Table with defined schema + +[^1]: Excluding SI and SBA. I5 and U5 cannot be used in an array, but String and +Buffer are simply forced to use their "long" variant. + +[^2]: Some systems may lack support for this. + +## Tokens + +PDL files are divided into tokens, which assemble together into larger language +structures. They are separated by whitespace. + +| Name | Syntax | Description +| -------- | ------------------ | ----------- +| Method | `M[0-9A-Fa-f]{4}` | A 16-bit hexadecimal method code. +| Key | `[0-9A-Fa-f]{4}` | A 16-bit hexadecimal table key. +| Ident | `[A-Z][A-Za-z0-9]` | An identifier. +| Comma | `,` | A comma separator. +| LBrace | `{` | A left curly brace. +| RBrace | `}` | A right curly brace. +| LBracket | `[` | A left square bracket. +| RBracket | `]` | A right square bracket. + +## Syntax + +Types are expressed with an Ident. A table can be used by either writing the +name of the type (Table), or by defining a schema with curly braces. Arrays must +be expressed using two matching square brackets before their element type. + +A table schema contains comma-separated fields in-between its braces. Each field +has three parts: the key number (Key), the field name (Ident), and the field +type. Tables, Arrays, etc. can be nested. + +Files directly contain messages and types, which start with a Method token and +an Ident token respectively. A message consists of the method code (Method), the +message name (Ident), and the message's root type. This is usually a table, but +can be anything. + +Here is an example of all that: + +``` +M0000 Connect { + 0000 Name String, + 0001 Password String, +} + +M0001 UserList { + 0000 Users []User, +} + +User { + 0000 Name String, + 0001 Bio String, + 0002 Followers U32, +} +``` + +## EBNF Description + +Below is an EBNF description of the language. + +``` + -> ( | -> /M[0-9A-Fa-f]{4}/ + -> /[0-9A-Fa-f]{4}/ + -> /[A-Z][A-Za-z0-9]/ + -> + -> + | "[" "]" + | "{" ( ",")* [] "}" + -> + -> +``` diff --git a/design/protocol.md b/design/protocol.md index ca37998..f1a107e 100644 --- a/design/protocol.md +++ b/design/protocol.md @@ -18,12 +18,10 @@ dependant on which transport is being used. A message refers to a block of octets sent within a transaction, paired with an unsigned 16-bit method code. The order of messages within a given transaction is preserved, but the order of messages accross the entire connection is not -guaranteed. - -The message payload must be 65,535 (unsigned 16-bit integer limit) octets or -smaller in length. This does not include the method code. Applications are free -to send whatever data they wish as the payload, but TAPE is recommended for -encoding it. +guaranteed. There is no functional limit on the size of a message payload, but +there may be one depending on which +[METADAPT sub-protocol](#message-and-transaction-demarcation-protocol-metadapt) +is in use. Method codes should be written in upper-case base 16 with the prefix "M" in logs, error messages, documentation, etc. For example, the method code 62,206 in @@ -37,100 +35,92 @@ fucking with you. ## Table Pair Encoding (TAPE) The Table Pair Encoding (TAPE) scheme is a method for encoding structured data within HOPP messages. It defines standard binary encoding methods for common -data types, as well as a corruption-resistant table structure that maps numeric -IDs to values. It is designed to allow applications to be presented with data -they are not equipped to handle while continuing to function normally. This -enables backwards compatibile application protocol changes. +data types, as well as aggregate data types such as tables and arrays. It is +designed to allow applications to be presented with data they are not equipped +to handle while continuing to function normally. This enables backwards +compatibile application protocol changes. -### Table Structure -A table is divided into two sections: the header, and the values. The header -begins with the number (U16) of pairs in the table, which is then followed by -that many tag-offset pairs. A tag-offset pair consists of a numerical (U16) tag, -followed the position (U16) of the value relative to the start of the values -section. The values section contains the value data for each pair, where the -start of each value is determined by its offset, and the end is determined by -the offset of the next value, or the end of the message if there is no value -after it. +TAPE expresses types using tags. A tag is 8 bits in size, and is divided into +two parts: the Type Number (TN), and the Configuration Number (CN). The TN is 3 +bits, and the CN is 5 bits. Both are interpreted as unsigned integers. Both +sides of the connection must agree on the semantic meaning of the values and +their arrangement. -Both sections must be in the same order, and because of this, each value offset -must be greater than or equal to the last. If a message has erratic structure -(such as unordered or out-of-bounds offsets), implementations may opt to discard -only the erratic pairs, as well as the pairs directly before those. +A TAPE structure begins with one root, which consists of a tag followed by a +payload. This is usually an aggregate data structure such as KTV to allow for +several different values. + +TAPE is based on an encoding method previously developed by silt. ### Data Value Types -The table below lists all data value types supported by TAPE. +The table below lists all data value types supported by TAPE. They are discussed +in detail in the following sections. -| Name | Size | Description | Encoding Method -| ----------- | --------------: | --------------------------- | --------------- -| I8 | 1 | A signed 8-bit integer | BETC -| I16 | 2 | A signed 16-bit integer | BETC -| I32 | 4 | A signed 32-bit integer | BETC -| I64 | 8 | A signed 64-bit integer | BETC -| U8 | 1 | An unsigned 8-bit integer | BEU -| U16 | 2 | An unsigned 16-bit integer | BEU -| U32 | 4 | An unsigned 32-bit integer | BEU -| U64 | 8 | An unsigned 64-bit integer | BEU -| Array[^1] | SOP[^2] | An array of any above type | PASTA -| String | N/A | A UTF-8 string | UTF-8 -| StringArray | n * 2 + SOP[^2] | An array the String type | VILA +| TN | Bits | Name | Description +| -: | ---: | ---- | ----------- +| 0 | 000 | SI | Small integer +| 1 | 001 | LI | Large integer +| 2 | 010 | FP | Floating point +| 3 | 011 | SBA | Small byte array +| 4 | 100 | LBA | Large byte array +| 5 | 101 | OTA | One-tag array +| 6 | 110 | KTV | Key-tag-value table +| 7 | 111 | N/A | Reserved -[^1]: Array types are written as Array, where is the element type. For -example, an array of I32 would be written as I32Array. StringArray still follows -this rule, even though it is encoded differently from other arrays. Nesting -arrays inside of arrays is prohibited. This problem can be avoided in most cases -by effectively utilizing the table structure, or by improving the design of -your protocol. +#### Small Integer (SI) +SI encodes an integer of up to 5 bits, which are stored in the CN. It has no +payload. Whether the bits are interpreted as unsigned or as signed two's +complement is semantic information and must be agreed upon by both sides of the +connection. Thus, the value may range from 0 to 31 if unsigned, and from -16 to +17 if signed. -[^2]: SOP (sum of parts) refers to the sum of the size of every item in a data -structure. +#### Large Integer (LI) +LI encodes an integer of up to 256 bits, which are stored in the payload. The CN +determine the length of the payload in bytes. The integer is big-endian. Whether +the payload is interpreted as unsigned or as signed two's complement is semantic +information and must be agreed upon by both sides of the connection. Thus, the +value may range from 0 to 31 if unsigned, and from -16 to 17 if signed. -### Encoding Methods -Below are all encoding methods supported by TAPE. +#### Floating Point (FP) +FP encodes an IEEE 754 floating point number of up to 256 bits, which are stored +in the payload. The CN determines the length of the payload in bytes, and it may +only be one of these values: 16, 32, 64, 128, or 256. -#### BETC -Big-Endian, Two's Complement signed integer. The size is defined as the least -amount of whole octets which can fit all bits in the integer, regardless if the -bits are on or off. Therefore, the size cannot change at runtime. +#### Small Byte Array (SBA) +SBA encodes an array of up to 32 bytes, which are stored in the paylod. The +CN determines the length of the payload in bytes. -#### BEU -Big-Endian, Unsigned integer. The size is defined as the least amount of whole -octets which can fit all bits in the integer, regardless if the bits are on or -off. Therefore, the size cannot change at runtime. +#### Large Byte Array (LBA) +LBA encodes an array of up to 2^256 bytes, which are stored in the second part +of the payload, directly after the length. The length of the data length field +in bytes is determined by the CN. -#### PASTA -Packed Single-Type Array. The size is defined as the size of an individual item -times the number of items. Items are placed one after the other with no gaps -in-between them, except as required to align the start of each item to the -nearest whole octet. Items should be of the same type and must be of the same -size. +#### One-Tag Array (OTA) +OTA encodes an array of up to 2^256 items, which are stored in the payload after +the length field and the item tag, where the length field comes first. Each item +must be the same length, as they all share the same tag. The length of the data +length field in bytes is determined by the CN. -#### UTF-8 -UTF-8 string. The size is defined as the least amount of whole octets which can -fit all bits in the string, regardless if the bits are on or off. The size of -this type is not fixed and may change at runtime, so this needs to be accounted -for during use. - -#### VILA -Variable Item Length Array. The size is defined as the least amount of whole -octets which can fit each item plus one U16 per item. The size of this type is -not fixed and may change at runtime, so this needs to be accounted for during -use. The amount of items must be greater than zero. Items are each prefixed by -their size (in octets) encoded as a U16, and they are placed one after the other -with no gaps in-between them, except as required to align the start of each item -to the nearest whole octet. Items should be of the same type but do not need to -be of the same size. +#### Key-Tag-Value Table (KTV) +KTV encodes a table of up to 2^256 key/value pairs, which are stored in the +payload after the length field. The pairs themselves consist of a 16-bit +unsigned big-endian key followed by a tag and then the payload. Pair values can +be of different types and sizes. The order of the pairs is not significant and +should never be treated as such. ## Transports A transport is a protocol that HOPP connections can run on top of. HOPP currently supports the QUIC transport protocol for communicating between -machines, and UNIX domain sockets for quicker communication among applications -on the same machine. Both protocols are supported through METADAPT. +machines, TCP/TLS for legacy systems that do not support QUIC, and UNIX domain +sockets for faster communication among applications on the same machine. Both +protocols are supported through METADAPT. ## Message and Transaction Demarcation Protocol (METADAPT) The Message and Transaction Demarcation Protocol is used to break one or more reliable data streams into transactions, which are broken down further into -messages. A message, as well as its associated metadata (length, transaction, -method, etc.) together is referred to as METADAPT Message Block (MMB). +messages. The representation of a message (or a part thereof) on the protocol, +including its associated metadata (length, transaction, method, etc.) is +referred to as METADAPT Message Block (MMB). For transports that offer multiple multiplexed data streams that can be created and destroyed on-demand (such as QUIC) each stream is used as a transaction. If @@ -145,8 +135,12 @@ METADAPT-A requires a transport which offers a single full-duplex data stream that persists for the duration of the connection. All transactions are multiplexed onto this single stream. Each MMB contains a 12-octet long header, with the transaction ID, then the method, and then the payload size (in octets). -The transaction ID is encoded as an I64, and the method and payload size are -both encoded as U16s. The remainder of the message is the payload. Since each +The transaction ID is encoded as an I64, the method is encoded as a U16 and the +and payload size is encoded as a U64. Only the 63 least significant bits of the +payload size describe the actual size, the most significant bit controlling +chunking. See the section on chunking for more information. + +The remainder of the message is the payload. Since each MMB is self-describing, they are sent sequentially with no gaps in-between them. Transactions "open" when the first message with a given transaction ID is sent. @@ -162,13 +156,24 @@ used up, the connection must fail. Don't worry about this though, because the sun will have expanded to swallow earth by then. Your connection will not last that long. +#### Message Chunking +The most significant bit of the payload size field of an MMB is called the Chunk +Control Bit (CCB). If the CCB of a given MMB is zero, the represented message is +interpreted as being self-contained and the data is processed immediately. If +the CCB is one, the message is interpreted as being chunked, with the data of +the current MMB being the first chunk. The data of further MMBs sent along the +transaction will be appended to the message until an MMB is read with a zero +CCB, in which case the MMB will be the last chunk and any more MMBs will be +interpreted as normal. + ### METADAPT-B METADAPT-B requires a transport which offers multiple multiplexed full-duplex data streams per connection that can be created and destroyed on-demand. Each data stream is used as an individual transaction. Each MMB contains a 4-octet -long header with the method and then the payload size (in octets) both encoded -as U16s. The remainder of the message is the payload. Since each MMB is -self-describing, they are sent sequentially with no gaps in-between them. +long header with the method and then the payload size (in octets) encoded as a +U16 and U64 respectively. The remainder of the message is the payload. Since +each MMB is self-describing, they are sent sequentially with no gaps in-between +them. The ID of any transaction will reflect the ID of its corresponding stream. The lifetime of the transaction is tied to the lifetime of the stream, that is to diff --git a/dial.go b/dial.go index 95a24c9..9b9d3d8 100644 --- a/dial.go +++ b/dial.go @@ -1,9 +1,9 @@ package hopp import "net" +import "errors" import "context" import "crypto/tls" -import "github.com/quic-go/quic-go" // Dial opens a connection to a server. The network must be one of "quic", // "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and @@ -19,9 +19,8 @@ type Dialer struct { } // Dial opens a connection to a server. The network must be one of "quic", -// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and -// "quic6" don't do anything as the quic-go package doesn't seem to support this -// behavior. +// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, quic is not +// supported. func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) { switch network { case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address) @@ -31,12 +30,7 @@ func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, er } func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) { - // sorry i fucking lied to you about the network parameter. for all - // quic-go's bullshit bloat, it doesnt even support that. not even when - // instantiating a transport. go figure :/ - conn, err := quic.DialAddr(ctx, address, tlsConfig(diale.TLSConfig), quicConfig()) - if err != nil { return nil, err } - return AdaptB(quicMultiConn { underlying: conn }), nil + return nil, errors.New("quic is not yet implemented") } func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) { @@ -60,15 +54,6 @@ func tlsConfig(conf *tls.Config) *tls.Config { return conf } -func quicConfig() *quic.Config { - return &quic.Config { - // TODO: perhaps we might want to put something here - // the quic config shouldn't be exported, just set up - // automatically. we can't have that strangely built quic-go - // package be part of the API, or any third-party packages for - // that matter. it must all be abstracted away. - } -} func quicNetworkToUDPNetwork(network string) (string, error) { switch network { diff --git a/error.go b/error.go index 20bf6b3..3a78a4d 100644 --- a/error.go +++ b/error.go @@ -9,6 +9,7 @@ type Error string; const ( ErrIntegerOverflow Error = "integer overflow" ErrMessageMalformed Error = "message is malformed" ErrTablePairMissing Error = "required table pair is missing" + ErrWrongBufferLength Error = "wrong buffer length" ) // Error implements the error interface. diff --git a/generate/generate.go b/generate/generate.go index 65cbe15..5538075 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -2,301 +2,1182 @@ package generate import "io" import "fmt" -import "bufio" +import "maps" +import "math" +import "slices" import "strings" +import "encoding/hex" +import "git.tebibyte.media/sashakoshka/hopp/tape" -const send = -`// Send sends one message along a transaction. -func Send(trans hopp.Trans, message hopp.Message) error { - buffer, err := message.MarshalBinary() - if err != nil { return err } - return trans.Send(message.Method(), buffer) -} +const imports = +` +import "git.tebibyte.media/sashakoshka/hopp/tape" +` + +const preamble = ` +// Code generated by the Holanet PDL compiler. DO NOT EDIT. +// The source file is located at +// Please edit that file instead, and re-compile it to this location. +// HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz ` -// ResolveType resolves a HOPP type name to a Go type. For now, it supports all -// data types defined in TAPE. -func (this *Protocol) ResolveType(hopp string) (string, error) { - switch hopp { - case "I8": return "int8", nil - case "I16": return "int16", nil - case "I32": return "int32", nil - case "I64": return "int64", nil - case "U8": return "uint8", nil - case "U16": return "uint16", nil - case "U32": return "uint32", nil - case "U64": return "uint64", nil - case "I8Array": return "[]int8", nil - case "I16Array": return "[]int16", nil - case "I32Array": return "[]int32", nil - case "I64Array": return "[]int64", nil - case "U8Array": return "[]uint8", nil - case "U16Array": return "[]uint16", nil - case "U32Array": return "[]uint32", nil - case "U64Array": return "[]uint64", nil - case "String": return "string", nil - case "StringArray": return "[]string", nil - default: return "", fmt.Errorf("unknown type: %s", hopp) - } +const static = ` +// Table is a KTV table with an undefined schema. +type Table map[uint16] any + +// Message is any message that can be sent along this protocol. +type Message interface { + tape.Encodable + tape.Decodable + + // Method returns the method code of the message. + Method() uint16 } -// Generate turns this protocol into code. The package name for the generated -// code must be specified. -func (this *Protocol) Generate(writer io.Writer, packag string) error { - out := bufio.NewWriter(writer) - defer out.Flush() +// canAssign determines if data from the given source tag can be assigned to +// a Go type represented by destination. It is designed to receive destination +// values from [generate.Generator.generateCanAssign]. The eventual Go type and +// the destination tag must come from the same (or hash-equivalent) PDL type. +func canAssign(destination, source tape.Tag) bool { + if destination.Is(source) { return true } + if (destination.Is(tape.SBA) || destination.Is(tape.LBA)) && + (source.Is(tape.SBA) || source.Is(tape.LBA)) { + return true + } + return false +} +` + +// Generator converts protocols into Go code. +type Generator struct { + // Output is where the generated code will be sent. + Output io.Writer + // PackageName is the package name that will be used in the file. If + // left empty, the default is "protocol". + PackageName string - fmt.Fprintf(out, "package %s\n\n", packag) - fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n") - fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n") + nestingLevel int + temporaryVar int + protocol *Protocol + + decodeBranchRequestQueue []decodeBranchRequest +} + +type decodeBranchRequest struct { + hash [16]byte + typ Type + name string +} + +func (this *Generator) Generate(protocol *Protocol) (n int, err error) { + this.nestingLevel = 0 + this.protocol = protocol + defer func() { this.protocol = nil }() + + // preamble and static section + packageName := "protocol" + if this.PackageName != "" { + packageName = this.PackageName + } + nn, err := this.iprintf("package %s\n", packageName) + n += nn; if err != nil { return n, err } + nn, err = this.print(preamble) + n += nn; if err != nil { return n, err } + nn, err = this.print(imports) + n += nn; if err != nil { return n, err } + nn, err = this.print(static) + n += nn; if err != nil { return n, err } + + // type definitions + for _, name := range slices.Sorted(maps.Keys(protocol.Types)) { + nn, err := this.generateTypedef(name, protocol.Types[name]) + n += nn; if err != nil { return n, err } + } + + // messages + for _, method := range slices.Sorted(maps.Keys(protocol.Messages)) { + nn, err := this.generateMessage(method, protocol.Messages[method]) + n += nn; if err != nil { return n, err } + } + + // request queue + for { + hash, typ, name, ok := this.pullDecodeBranchRequest() + if !ok { break } + nn, err := this.generateDecodeBranch(hash, typ, name) + n += nn; if err != nil { return n, err } + } + + return n, nil +} + +func (this *Generator) generateTypedef(name string, typ Type) (n int, err error) { + // type definition + nn, err := this.iprintf( + "\n// %s represents the protocol data type %s.\n", + name, name) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("type %s ", name) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ) + n += nn; if err != nil { return n, err } + nn, err = this.println() + n += nn; if err != nil { return n, err } + + // 'Tag' method + // to be honest we probably don't need this method at all + // nn, err = this.iprintf("\n// Tag returns the preferred TAPE tag.\n") + // n += nn; if err != nil { return n, err } + // nn, err = this.iprintf("func (this *%s) Tag() tape.Tag {\n", name) + // n += nn; if err != nil { return n, err } + // this.push() + // nn, err = this.iprintf("return ") + // n += nn; if err != nil { return n, err } + // nn, err = this.generateTag(typ, "(*this)") + // n += nn; if err != nil { return n, err } + // nn, err = this.println() + // n += nn; if err != nil { return n, err } + // this.pop() + // nn, err = this.iprintf("}\n") + // n += nn; if err != nil { return n, err } + + // EncodeValue method + nn, err = this.iprintf( + "\n// EncodeValue encodes the value of this type without the " + + "tag. The value is\n// encoded according to the parameters " + + "specified by the tag, if possible.\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "func (this *%s) EncodeValue(encoder *tape.Encoder, tag tape.Tag) (n int, err error) {\n", + name) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("var nn int\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateEncodeValue(typ, "(*this)", "tag") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } - fmt.Fprintf(out, send) - this.receive(out) + // DecodeValue method + nn, err = this.iprintf( + "\n // DecodeValue decodes the value of this type without " + + "the tag. The value is\n// decoded according to the " + + "parameters specified by the tag, if possible.\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "func (this *%s) DecodeValue(decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n", + name) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("var nn int\n") + n += nn; if err != nil { return n, err } + + nn, err = this.iprintf("if !(") + n += nn; if err != nil { return n, err } + nn, err = this.generateCanAssign(typ, "tag") + n += nn; if err != nil { return n, err } + nn, err = this.printf(") {\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("nn, err = tape.Skim(decoder, tag)\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + + nn, err = this.generateDecodeValue(typ, name, "this", "tag") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } - for _, message := range this.Messages { - err := this.defineMessage(out, message) - if err != nil { return err } - err = this.marshalMessage(out, message) - if err != nil { return err } - err = this.unmarshalMessage(out, message) - if err != nil { return err } - } - - return nil + return n, nil } -func (this *Protocol) receive(out io.Writer) error { - fmt.Fprintf(out, "// Receive receives one message from a transaction.\n") - fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n") - fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n") - fmt.Fprintf(out, "\tif err != nil { return nil, err }\n") - fmt.Fprintf(out, "\tswitch method {\n") - for _, message := range this.Messages { - fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method) - fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name) - fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n") - fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n") - fmt.Fprintf(out, "\t\treturn message, nil\n") - } - fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n") - fmt.Fprintf(out, "\t}\n") - fmt.Fprintf(out, "}\n\n") - return nil -} +// generateMessage generates the structure, as well as encoding decoding +// functions for the given message. +func (this *Generator) generateMessage(method uint16, message Message) (n int, err error) { + nn, err := this.iprintf( + "\n// %s represents the protocol message M%04X %s.\n", + message.Name, method, message.Name) + nn, err = this.iprintf("type %s ", this.resolveMessageName(message.Name)) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(message.Type) + n += nn; if err != nil { return n, err } + nn, err = this.println() + n += nn; if err != nil { return n, err } -func (this *Protocol) defineMessage(out io.Writer, message Message) error { - fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc))) - fmt.Fprintf(out, "type Message%s struct {\n", message.Name) - for _, field := range message.Fields { - typ, err := this.ResolveType(field.Type) - if err != nil { return err } - if field.Doc != "" { - fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc)) - } - if field.Optional { - typ = fmt.Sprintf("hopp.Option[%s]", typ) - } - fmt.Fprintf( - out, "\t/* %d */ %s %s\n", - field.Tag, field.Name, typ) - } - fmt.Fprintf(out, "}\n\n") + // Method method + nn, err = this.iprintf("\n// Method returns the message's method number.\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "func(this *%s) Method() uint16 { return 0x%04X }\n", + this.resolveMessageName(message.Name), + method) + n += nn; if err != nil { return n, err } + + // Encode method + nn, err = this.iprintf("\n// Encode encodes this message's tag and value.\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "func(this *%s) Encode(encoder *tape.Encoder) (n int, err error) {\n", + this.resolveMessageName(message.Name)) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("tag := ") + n += nn; if err != nil { return n, err } + nn, err = this.generateTag(message.Type, "(*this)") + n += nn; if err != nil { return n, err } + nn, err = this.println() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("nn, err := encoder.WriteTag(tag)\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.generateEncodeValue(message.Type, "(*this)", "tag") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + + // Decode method + nn, err = this.iprintf("\n// Decode decodes this message's tag and value.\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "func(this *%s) Decode(decoder *tape.Decoder) (n int, err error) {\n", + this.resolveMessageName(message.Name)) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("tag, nn, err := decoder.ReadTag()\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + + nn, err = this.iprintf("if !(") + n += nn; if err != nil { return n, err } + nn, err = this.generateCanAssign(message.Type, "tag") + n += nn; if err != nil { return n, err } + nn, err = this.printf(") {\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("nn, err = tape.Skim(decoder, tag)\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + + nn, err = this.generateDecodeValue(message.Type, this.resolveMessageName(message.Name), "this", "tag") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } - fmt.Fprintf(out, "// Method returns the method number of the message.\n") - fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name) - fmt.Fprintf(out, "\treturn %d\n", message.Method) - fmt.Fprintf(out, "}\n\n") - return nil + return n, nil } -func (this *Protocol) marshalMessage(out io.Writer, message Message) error { - fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n") - fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) - requiredCount := 0 - for _, field := range message.Fields { - if !field.Optional { requiredCount ++ } - } - fmt.Fprintf(out, "\tsize := 0\n") - fmt.Fprintf(out, "\tcount := %d\n", requiredCount) - for _, field := range message.Fields { - fmt.Fprintf(out, "\toffset%s := size\n", field.Name) - if field.Optional { - fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) - fmt.Fprintf(out, "\t\tcount ++\n") - fmt.Fprintf(out, "\t\t") - err := this.marshalSizeOf(out, field) - if err != nil { return err } - fmt.Fprintf(out, " }\n") - } else { - fmt.Fprintf(out, "\t{") - fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) - fmt.Fprintf(out, "\t\t") - err := this.marshalSizeOf(out, field) - if err != nil { return err } - fmt.Fprintf(out, " }\n") +// generateEncodeValue generates code to encode a value of a specified type. It +// pulls from the variable (or parenthetical statement) specified by +// valueSource, and the value will be encoded according to the tag stored in +// the variable (or parenthetical statement) specified by tagSource. +// the code generated is a BLOCK and expects these variables to be defined: +// +// - encoder *tape.Encoder +// - n int +// - err error +// - nn int +func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource string) (n int, err error) { + switch typ := typ.(type) { + case TypeInt: + // SI: (none) + // LI/LSI: + if typ.Bits <= 5 { + // SI stores the value in the tag, so we write nothing here + break } - } - fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") - fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") - fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n") - fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n") - for _, field := range message.Fields { - if field.Optional { - fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) - fmt.Fprintf(out, "\t\t") - err := this.marshalField(out, field) - if err != nil { return err } - fmt.Fprintf(out, "}\n") - } else { - fmt.Fprintf(out, "\t{") - fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) - fmt.Fprintf(out, "\t\t") - err := this.marshalField(out, field) - if err != nil { return err } - fmt.Fprintf(out, "}\n") + prefix := "WriteUint" + if typ.Signed { + prefix = "WriteInt" } - } - fmt.Fprintf(out, "\treturn buffer, nil\n") - fmt.Fprintf(out, "}\n\n") - return nil -} - -func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error { - switch field.Type { - case "I8": fmt.Fprintf(out, "size += 1; _ = value") - case "I16": fmt.Fprintf(out, "size += 2; _ = value") - case "I32": fmt.Fprintf(out, "size += 4; _ = value") - case "I64": fmt.Fprintf(out, "size += 8; _ = value") - case "U8": fmt.Fprintf(out, "size += 1; _ = value") - case "U16": fmt.Fprintf(out, "size += 2; _ = value") - case "U32": fmt.Fprintf(out, "size += 4; _ = value") - case "U64": fmt.Fprintf(out, "size += 8; _ = value") - case "I8Array": fmt.Fprintf(out, "size += len(value)") - case "I16Array": fmt.Fprintf(out, "size += len(value) * 2") - case "I32Array": fmt.Fprintf(out, "size += len(value) * 4") - case "I64Array": fmt.Fprintf(out, "size += len(value) * 8") - case "U8Array": fmt.Fprintf(out, "size += len(value)") - case "U16Array": fmt.Fprintf(out, "size += len(value) * 2") - case "U32Array": fmt.Fprintf(out, "size += len(value) * 4") - case "U64Array": fmt.Fprintf(out, "size += len(value) * 8") - case "String": fmt.Fprintf(out, "size += len(value)") - case "StringArray": - fmt.Fprintf( - out, - "for _, el := range value { size += 2 + len(el) }") - default: - return fmt.Errorf("unknown type: %s", field.Type) - } - return nil -} - -func (this *Protocol) marshalField(out io.Writer, field Field) error { - switch field.Type { - case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) - case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) - case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) - case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) - case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name) - case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name) - case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name) - case "U64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name) - case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) - case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) - case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) - case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) - case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name) - case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name) - case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name) - case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name) - case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name) - case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name) - default: - return fmt.Errorf("unknown type: %s", field.Type) - } - return nil -} - -func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error { - fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n") - fmt.Fprintf(out, - "func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n", - message.Name) - if len(message.Fields) < 1 { - fmt.Fprintf(out, "\t// no fields\n") - fmt.Fprintf(out, "\treturn nil\n") - fmt.Fprintf(out, "}\n\n") - return nil - } - fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n") - fmt.Fprintf(out, "\tif err != nil { return err }\n") - requiredTotal := 0 - for _, field := range message.Fields { - if field.Optional { - requiredTotal ++ - } - } - if requiredTotal > 0 { - fmt.Fprintf(out, "\tfoundRequired := 0\n") - } - fmt.Fprintf(out, "\tfor tag, data := range pairs {\n") - fmt.Fprintf(out, "\t\tswitch tag {\n") - for _, field := range message.Fields { - fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag) - fmt.Fprintf(out, "\t\t\t") - err := this.unmarshalField(out, field) - if err != nil { return err } - fmt.Fprintf(out, "\n") - fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n") - if field.Optional { - fmt.Fprintf(out, "\t\t\tmsg.%s = hopp.O(value)\n", field.Name) - } else { - fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name) - if requiredTotal > 0 { - fmt.Fprintf(out, "\t\t\tfoundRequired ++\n") + nn, err := this.iprintf("nn, err = encoder.%s%d(", prefix, typ.Bits) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ) // TODO: cast like this for + // every type + n += nn; if err != nil { return n, err } + nn, err = this.printf("(%s))\n", valueSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + case TypeFloat: + // FP: + nn, err := this.iprintf("nn, err = encoder.WriteFloat%d(", typ.Bits) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ) + n += nn; if err != nil { return n, err } + nn, err = this.printf("(%s))\n", valueSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + case TypeString: + // see TypeBuffer + nn, err := this.generateEncodeValue(TypeBuffer { }, valueSource, tagSource) + n += nn; if err != nil { return n, err } + case TypeBuffer: + // SBA: * + // LBA: * + nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf( + "nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n", + valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + case TypeArray: + // OTA: * + nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n", + valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("{\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("itemTag := ") + n += nn; if err != nil { return n, err } + nn, err = this.generateTN(typ.Element) + n += nn; if err != nil { return n, err } + nn, err = this.println() + n += nn; if err != nil { return n, err } + // TODO: we don't have to do this for loop for some + // types such as integers because the CN will be the + // same + nn, err = this.iprintf("for _, item := range %s {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("_ = item\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("tag := ") + n += nn; if err != nil { return n, err } + nn, err = this.generateTag(typ.Element, "item") + n += nn; if err != nil { return n, err } + nn, err = this.println() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if tag.Is(tape.SBA) { continue }\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if tag.CN() > itemTag.CN() { itemTag = tag }\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if itemTag.Is(tape.SBA) { itemTag += 1 << 5 }\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("nn, err = encoder.WriteTag(itemTag)\n") + n += nn; if err != nil { return n, err } + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + nn, err = this.iprintf("for _, item := range %s {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.generateEncodeValue(typ.Element, "item", "itemTag") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + case TypeTable: + // KTV: ( )* + nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "nn, err = tape.EncodeAny(encoder, %s, %s)\n", + valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + case TypeTableDefined: + // KTV: ( )* + nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields)) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( + "nn, err = encoder.WriteUintN(%d, %s.CN())\n", + len(typ.Fields), tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("{\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("var tag tape.Tag\n") + n += nn; if err != nil { return n, err } + for key, field := range typ.Fields { + nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("tag = ") + n += nn; if err != nil { return n, err } + fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name) + nn, err = this.generateTag(field.Type, fieldSource) + n += nn; if err != nil { return n, err } + nn, err = this.println() + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(tag))\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + nn, err = this.generateEncodeValue(field.Type, fieldSource, "tag") + n += nn; if err != nil { return n, err } } - } - } - fmt.Fprintf(out, "\t\t}\n") - fmt.Fprintf(out, "\t}\n") - if requiredTotal > 0 { - fmt.Fprintf(out, - "\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n", - requiredTotal) - } - fmt.Fprintf(out, "\treturn nil\n") - fmt.Fprintf(out, "}\n\n") - return nil -} - -func (this *Protocol) unmarshalField(out io.Writer, field Field) error { - typ, err := this.ResolveType(field.Type) - if err != nil { return err } - switch field.Type { - case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) - case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) - case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) - case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) - case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ) - case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ) - case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ) - case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ) - case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) - case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) - case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) - case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) - case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ) - case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ) - case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ) - case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ) - case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ) - case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ) + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + case TypeNamed: + // WHATEVER: [WHATEVER] + nn, err := this.iprintf("nn, err = %s.EncodeValue(encoder, %s)\n", valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } default: - return fmt.Errorf("unknown type: %s", field.Type) + panic(fmt.Errorf("unknown type: %T", typ)) } - return nil + + return n, nil } -func comment(prefix, text string) string { - return prefix + " " + strings.ReplaceAll(strings.TrimSpace(text), "\n", "\n" + prefix + " ") +// generateDencodeValue generates code to decode a value of a specified type. It +// overwrites memory pointed to by the variable (or parenthetical statement) +// specified by valueSource, and the value will be encoded according to the tag +// stored in the variable (or parenthetical statement) specified by tagSource. +// the code generated is a BLOCK and expects these variables to be defined: +// +// - decoder *tape.Decoder +// - n int +// - err error +// - nn int +// +// The typeName paramterer is handled in the way described in the documentation +// for [Generator.generateDecodeBranch]. +func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagSource string) (n int, err error) { + switch typ := typ.(type) { + case TypeInt: + // SI: (none) + // LI/LSI: + if typ.Bits <= 5 { + // SI stores the value in the tag + if typeName == "" { + nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource) + n += nn; if err != nil { return n, err } + } else { + nn, err := this.iprintf("*%s = %s(%s.CN())\n", valueSource, typeName, tagSource) + n += nn; if err != nil { return n, err } + } + break + } + prefix := "ReadUint" + if typ.Signed { + prefix = "ReadInt" + } + destinationVar := this.newTemporaryVar("destination") + nn, err := this.iprintf("var %s ", destinationVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ) + n += nn; if err != nil { return n, err } + nn, err = this.print("\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + if typeName == "" { + nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar) + n += nn; if err != nil { return n, err } + } else { + nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar) + n += nn; if err != nil { return n, err } + } + case TypeFloat: + // FP: + destinationVar := this.newTemporaryVar("destination") + nn, err := this.iprintf("var %s ", destinationVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ) + n += nn; if err != nil { return n, err } + nn, err = this.print("\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.ReadFloat%d()\n", destinationVar, typ.Bits) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + if typeName == "" { + nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar) + n += nn; if err != nil { return n, err } + } else { + nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar) + n += nn; if err != nil { return n, err } + } + case TypeString, TypeBuffer: + // SBA: * + // LBA: * + lengthVar := this.newTemporaryVar("length") + nn, err := this.iprintf("var %s uint64\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf( + "%s, nn, err = decoder.ReadUintN(int(%s.CN()))\n", + lengthVar, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("} else {\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf( + "%s = uint64(%s.CN())\n", + lengthVar, tagSource) + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("buffer := make([]byte, %s)\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n") + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + if typeName == "" { + if _, ok := typ.(TypeString); ok { + nn, err = this.iprintf("*%s = string(buffer)\n", valueSource) + n += nn; if err != nil { return n, err } + } else { + nn, err = this.iprintf("*%s = buffer\n", valueSource) + n += nn; if err != nil { return n, err } + } + } else { + nn, err = this.iprintf("*%s = %s(buffer)\n", valueSource, typeName) + n += nn; if err != nil { return n, err } + } + case TypeArray: + // OTA: * + nn, err := this.generateDecodeBranchCall(typ, typeName, valueSource, tagSource) + n += nn; if err != nil { return n, err } + case TypeTable: + // KTV: ( )* + nn, err := this.iprintf( + "nn, err = tape.DecodeAny(decoder, %s, %s)\n", + valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + case TypeTableDefined: + // KTV: ( )* + nn, err := this.generateDecodeBranchCall(typ, typeName, valueSource, tagSource) + n += nn; if err != nil { return n, err } + case TypeNamed: + // WHATEVER: [WHATEVER] + nn, err := this.iprintf("nn, err = %s.DecodeValue(decoder, %s)\n", valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + default: + panic(fmt.Errorf("unknown type: %T", typ)) + } + + return n, nil +} + +// generateDecodeBranchCall generates code to call an aggregate decoder function, +// for a specified type. The definition of the function is deferred so no +// duplicates are created. The function overwrites memory pointed to by the +// variable (or parenthetical statement) specified by valueSource, and the value +// will be encoded according to the tag stored in the variable (or parenthetical +// statement) specified by tagSource. the code generated is a BLOCK and expects +// these variables to be defined: +// +// - decoder *tape.Decoder +// - n int +// - err error +// - nn int +// +// The typeName paramterer is handled in the way described in the documentation +// for [Generator.generateDecodeBranch]. +func (this *Generator) generateDecodeBranchCall(typ Type, typeName, valueSource, tagSource string) (n int, err error) { + hash := HashType(typ) + nn, err := this.iprintf( + "nn, err = %s(%s, decoder, %s)\n", + this.decodeBranchName(hash, typeName), valueSource, tagSource) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + this.pushDecodeBranchRequest(hash, typ, typeName) + return n, nil +} + +// generateDecodeBranch generates an aggregate decoder function definition for a +// specified type. It assumes that hash == HashType(typ). If typeName is not +// empty, it will be used as the type in the argument list instead of the result +// of [Generator.generateType]. +func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName string) (n int, err error) { + nn, err := this.iprintf("\nfunc %s(this *", this.decodeBranchName(hash, typeName)) + n += nn; if err != nil { return n, err } + if typeName == "" { + nn, err = this.generateType(typ) + n += nn; if err != nil { return n, err } + } else { + nn, err = this.print(typeName) + n += nn; if err != nil { return n, err } + } + nn, err = this.printf(", decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n") + n += nn; if err != nil { return n, err } + this.push() + + nn, err = this.iprintf("var nn int\n") + n += nn; if err != nil { return n, err } + + switch typ := typ.(type) { + case TypeArray: + // OTA: * + // read header + lengthVar := this.newTemporaryVar("length") + nn, err := this.iprintf("var %s uint64\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + elementTagVar := this.newTemporaryVar("elementTag") + nn, err = this.iprintf("var %s tape.Tag\n", elementTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.ReadTag()\n", elementTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + + // abort macro + abort := func() (n int, err error) { + // skim entire array + nn, err = this.iprintf("for _ = range %s {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("nn, err = tape.Skim(decoder, %s)\n", elementTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("return n, nil\n") + n += nn; if err != nil { return n, err } + return n, nil + } + + // validate header + // TODO: here, validate that length is less than the + // max, whatever that is configured to be. the reason we + // want to read it here is that we would have to skip + // the tag anyway so why not. + nn, err = this.iprintf("if !(") + n += nn; if err != nil { return n, err } + nn, err = this.generateCanAssign(typ.Element, elementTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.printf(") {\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = abort() + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + + // decode payloads + nn, err = this.iprintf("*this = make(") + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ) + n += nn; if err != nil { return n, err } + nn, err = this.printf(", %s)\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("for index := range int(%s) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.generateDecodeValue(typ.Element, "", "(&(*this)[index])", elementTagVar) + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + case TypeTableDefined: + // KTV: ( )* + // read header + lengthVar := this.newTemporaryVar("length") + nn, err := this.iprintf("var %s uint64\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + + // validate header + // TODO: here, validate that length is less than the + // max, whatever that is configured to be. if not, stop + // ALL decoding. skimming huge big ass data could cause + // problems + + // read fields + nn, err = this.iprintf("for _ = range %s {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + // read field header + fieldKeyVar := this.newTemporaryVar("fieldKey") + nn, err = this.iprintf("var %s uint16\n", fieldKeyVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.ReadUint16()\n", fieldKeyVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + fieldTagVar := this.newTemporaryVar("fieldTag") + nn, err = this.iprintf("var %s tape.Tag\n", fieldTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("%s, nn, err = decoder.ReadTag()\n", fieldTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.generateErrorCheck() + n += nn; if err != nil { return n, err } + + // abort field macro + abortField := func() (n int, err error) { + nn, err = this.iprintf("tape.Skim(decoder, %s)\n", fieldTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("continue\n") + n += nn; if err != nil { return n, err } + return n, nil + } + + // switch on tag + nn, err = this.iprintf("switch %s {\n", fieldKeyVar) + n += nn; if err != nil { return n, err } + for _, key := range slices.Sorted(maps.Keys(typ.Fields)) { + field := typ.Fields[key] + nn, err = this.iprintf("case 0x%04X:\n", key) + n += nn; if err != nil { return n, err } + this.push() + + // validate field header + nn, err = this.iprintf("if !(") + n += nn; if err != nil { return n, err } + nn, err = this.generateCanAssign(field.Type, fieldTagVar) + n += nn; if err != nil { return n, err } + nn, err = this.printf(") {\n") + n += nn; if err != nil { return n, err } + this.push() + nn, err = abortField() + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + + // decode payload + nn, err = this.generateDecodeValue( + field.Type, "", + fmt.Sprintf("(&(this.%s))", field.Name), fieldTagVar) + n += nn; if err != nil { return n, err } + this.pop() + } + nn, err = this.iprintf("default:\n") + n += nn; if err != nil { return n, err } + this.push() + abortField() + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + + // TODO once options are implemented, have a set of + // bools for each non-optional field, and check here + // that they are all true. a counter will not work + // because if someone specifies a non-optional field + // twice, they can neglect to specify another + // non-optional field and we won't even know because the + // count will still be even. we shouldn't use a map + // either because its an allocation and its way more + // memory than just, like 5 bools (on the stack no less) + default: return n, fmt.Errorf("unexpected type: %T", typ) + } + + nn, err = this.iprintf("return n, nil\n") + + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + return n, nil +} + +func (this *Generator) decodeBranchName(hash [16]byte, name string) string { + if name == "" { + return fmt.Sprintf("decodeBranch_%s", hex.EncodeToString(hash[:])) + } else { + return fmt.Sprintf("decodeBranch_%s_%s", hex.EncodeToString(hash[:]), name) + } +} + +// pushDecodeBranchRequest pushes a new branch decode function request to the +// back of the queue, if it is not already in the queue. +func (this *Generator) pushDecodeBranchRequest(hash [16]byte, typ Type, name string) { + for _, item := range this.decodeBranchRequestQueue { + if item.hash == hash && item.name == name { return } + } + this.decodeBranchRequestQueue = append(this.decodeBranchRequestQueue, decodeBranchRequest { + hash: hash, + typ: typ, + name: name, + }) +} + +// pullDecodeBranchRequest pulls a branch decode function request from the front +// of the queue. +func (this *Generator) pullDecodeBranchRequest() (hash [16]byte, typ Type, name string, ok bool) { + if len(this.decodeBranchRequestQueue) < 1 { + return [16]byte { }, nil, "", false + } + request := this.decodeBranchRequestQueue[0] + this.decodeBranchRequestQueue = this.decodeBranchRequestQueue[1:] + return request.hash, request.typ, request.name, true +} + +func (this *Generator) generateErrorCheck() (n int, err error) { + return this.iprintf("n += nn; if err != nil { return n, err }\n") +} + +// generateTag generates the preferred TN and CN for the given type and value. +// The generated code is INLINE. +func (this *Generator) generateTag(typ Type, source string) (n int, err error) { + switch typ := typ.(type) { + case TypeInt: + if typ.Bits <= 5 { + nn, err := this.printf("tape.SI.WithCN(int(%s))", source) + n += nn; if err != nil { return n, err } + } else if typ.Signed { + nn, err := this.printf("tape.LSI.WithCN(%d)", bitsToCN(typ.Bits)) + n += nn; if err != nil { return n, err } + } else { + nn, err := this.printf("tape.LI.WithCN(%d)", bitsToCN(typ.Bits)) + n += nn; if err != nil { return n, err } + } + case TypeFloat: + nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits)) + n += nn; if err != nil { return n, err } + case TypeString: + nn, err := this.printf("tape.StringTag(string(%s))", source) + n += nn; if err != nil { return n, err } + case TypeBuffer: + nn, err := this.printf("tape.BufferTag([]byte(%s))", source) + n += nn; if err != nil { return n, err } + case TypeArray: + nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source) + n += nn; if err != nil { return n, err } + case TypeTable: + nn, err := this.printf("tape.KTV.WithCN(tape.IntBytes(uint64(len(%s))))", source) + n += nn; if err != nil { return n, err } + case TypeTableDefined: + nn, err := this.printf("tape.KTV.WithCN(%d)", tape.IntBytes(uint64(len(typ.Fields)))) + n += nn; if err != nil { return n, err } + case TypeNamed: + resolved, err := this.resolveTypeName(typ.Name) + if err != nil { return n, err } + nn, err := this.generateTag(resolved, source) + n += nn; if err != nil { return n, err } + default: + panic(fmt.Errorf("unknown type: %T", typ)) + } + + return n, nil +} + +// generateTN generates the appropriate TN for the given type. The generated +// code is INLINE. The generated tag will have a CN as zero. For types that +// change TN based on their length, the TN capable of supporting more +// information is chosen. +func (this *Generator) generateTN(typ Type) (n int, err error) { + switch typ := typ.(type) { + case TypeInt: + if typ.Bits <= 5 { + nn, err := this.printf("tape.SI") + n += nn; if err != nil { return n, err } + } else if typ.Signed { + nn, err := this.printf("tape.LSI") + n += nn; if err != nil { return n, err } + } else { + nn, err := this.printf("tape.LI") + n += nn; if err != nil { return n, err } + } + case TypeFloat: + nn, err := this.printf("tape.FP",) + n += nn; if err != nil { return n, err } + case TypeString: + nn, err := this.generateTN(TypeBuffer { }) + n += nn; if err != nil { return n, err } + case TypeBuffer: + nn, err := this.printf("tape.LBA") + n += nn; if err != nil { return n, err } + case TypeArray: + nn, err := this.printf("tape.OTA") + n += nn; if err != nil { return n, err } + case TypeTable: + nn, err := this.printf("tape.KTV") + n += nn; if err != nil { return n, err } + case TypeTableDefined: + nn, err := this.printf("tape.KTV") + n += nn; if err != nil { return n, err } + case TypeNamed: + resolved, err := this.resolveTypeName(typ.Name) + if err != nil { return n, err } + nn, err := this.generateTN(resolved) + n += nn; if err != nil { return n, err } + } + + return n, nil +} + +func (this *Generator) generateType(typ Type) (n int, err error) { + switch typ := typ.(type) { + case TypeInt: + if err := this.validateIntBitSize(typ.Bits); err != nil { + return n, err + } + if typ.Bits <= 5 { + nn, err := this.printf("uint8") + n += nn; if err != nil { return n, err } + break + } + if typ.Signed { + nn, err := this.printf("int%d", typ.Bits) + n += nn; if err != nil { return n, err } + } else { + nn, err := this.printf("uint%d", typ.Bits) + n += nn; if err != nil { return n, err } + } + case TypeFloat: + switch typ.Bits { + case 16: + nn, err := this.print("float32") + n += nn; if err != nil { return n, err } + case 32, 64: + nn, err := this.printf("float%d", typ.Bits) + n += nn; if err != nil { return n, err } + default: + return n, fmt.Errorf("floats of size %d are unsupported on this platform", typ.Bits) + } + case TypeString: + nn, err := this.print("string") + n += nn; if err != nil { return n, err } + case TypeBuffer: + nn, err := this.print("[]byte") + n += nn; if err != nil { return n, err } + case TypeArray: + nn, err := this.print("[]") + n += nn; if err != nil { return n, err } + nn, err = this.generateType(typ.Element) + n += nn; if err != nil { return n, err } + case TypeTable: + nn, err := this.print("Table") + n += nn; if err != nil { return n, err } + case TypeTableDefined: + nn, err := this.generateTypeTableDefined(typ) + n += nn; if err != nil { return n, err } + case TypeNamed: + nn, err := this.print(typ.Name) + n += nn; if err != nil { return n, err } + } + return n, nil +} + +func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, err error) { + nn, err := this.print("struct {\n") + n += nn; if err != nil { return n, err } + this.push() + + for _, key := range slices.Sorted(maps.Keys(typ.Fields)) { + field := typ.Fields[key] + nn, err := this.iprintf("%s ", field.Name) + n += nn; if err != nil { return n, err } + nn, err = this.generateType(field.Type) + n += nn; if err != nil { return n, err } + nn, err = this.print("\n") + n += nn; if err != nil { return n, err } + } + + this.pop() + nn, err = this.iprint("}") + n += nn; if err != nil { return n, err } + return n, nil +} + +// generateCanAssign generates an expression which checks if the tag specified +// by tagSource can be assigned to a Go destination generated from typ. The +// generated code is INLINE. +func (this *Generator) generateCanAssign(typ Type, tagSource string) (n int, err error) { + nn, err := this.printf("canAssign(") + n += nn; if err != nil { return n, err } + nn, err = this.generateTN(typ) + n += nn; if err != nil { return n, err } + nn, err = this.printf(", %s)", tagSource) + n += nn; if err != nil { return n, err } + return n, nil +} + +func (this *Generator) validateIntBitSize(size int) error { + switch size { + case 5, 8, 16, 32, 64: return nil + default: return fmt.Errorf("integers of size %d are unsupported on this platform", size) + } +} + +func (this *Generator) validateFloatBitSize(size int) error { + switch size { + case 16, 32, 64: return nil + default: return fmt.Errorf("floats of size %d are unsupported on this platform", size) + } +} + +func (this *Generator) push() { + this.nestingLevel ++ +} + +func (this *Generator) pop() { + if this.nestingLevel < 1 { + panic("cannot pop when nesting level is less than 1") + } + this.nestingLevel -- +} + +func (this *Generator) indent() string { + return strings.Repeat("\t", this.nestingLevel) +} + +func (this *Generator) print(args ...any) (n int, err error) { + return fmt.Fprint(this.Output, args...) +} + +func (this *Generator) println(args ...any) (n int, err error) { + return fmt.Fprintln(this.Output, args...) +} + +func (this *Generator) printf(format string, args ...any) (n int, err error) { + return fmt.Fprintf(this.Output, format, args...) +} + +func (this *Generator) iprint(args ...any) (n int, err error) { + return fmt.Fprint(this.Output, this.indent() + fmt.Sprint(args...)) +} + +func (this *Generator) iprintln(args ...any) (n int, err error) { + return fmt.Fprintln(this.Output, this.indent() + fmt.Sprint(args...)) +} + +func (this *Generator) iprintf(format string, args ...any) (n int, err error) { + return fmt.Fprintf(this.Output, this.indent() + format, args...) +} + +func (this *Generator) resolveMessageName(message string) string { + return "Message" + message +} + +func (this *Generator) resolveTypeName(name string) (Type, error) { + if typ, ok := this.protocol.Types[name]; ok { + if typ, ok := typ.(TypeNamed); ok { + return this.resolveTypeName(typ.Name) + } + + return typ, nil + } + return nil, fmt.Errorf("no type exists called %s", name) +} + +func (this *Generator) newTemporaryVar(base string) string { + this.temporaryVar += 1 + return fmt.Sprintf("%s_%d", base, this.temporaryVar) +} + +func bitsToBytes(bits int) int { + return int(math.Ceil(float64(bits) / 8.0)) +} + +func bitsToCN(bits int) int { + return bitsToBytes(bits) - 1 } diff --git a/generate/generate_test.go b/generate/generate_test.go new file mode 100644 index 0000000..a4d0706 --- /dev/null +++ b/generate/generate_test.go @@ -0,0 +1,381 @@ +package generate + +// import "fmt" +import "testing" + +// TODO: once everything has been ironed out, test that the public API of the +// generator is equal to something specific + +var exampleProtocol = defaultProtocol() + +func init() { + exampleProtocol.Messages[0x0000] = Message { + Name: "Connect", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Name", Type: TypeString { } }, + 0x0001: Field { Name: "Password", Type: TypeString { } }, + }, + }, + } + exampleProtocol.Messages[0x0001] = Message { + Name: "UserList", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Users", Type: TypeArray { Element: TypeNamed { Name: "User" } } }, + }, + }, + } + exampleProtocol.Messages[0x0002] = Message { + Name: "Pulse", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Index", Type: TypeInt { Bits: 5 } }, + 0x0001: Field { Name: "Offset", Type: TypeInt { Bits: 16, Signed: true }}, + 0x0002: Field { Name: "X", Type: TypeFloat { Bits: 16 }}, + 0x0003: Field { Name: "Y", Type: TypeFloat { Bits: 32 }}, + 0x0004: Field { Name: "Z", Type: TypeFloat { Bits: 64 }}, + }, + }, + } + exampleProtocol.Messages[0x0003] = Message { + Name: "NestedArray", + Type: TypeArray { Element: TypeArray { Element: TypeInt { Bits: 8 } } }, + } + exampleProtocol.Messages[0x0004] = Message { + Name: "Integers", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "U5", Type: TypeInt { Bits: 5 } }, + 0x0001: Field { Name: "U8", Type: TypeInt { Bits: 8 } }, + 0x0002: Field { Name: "U16", Type: TypeInt { Bits: 16 } }, + 0x0003: Field { Name: "U32", Type: TypeInt { Bits: 32 } }, + 0x0004: Field { Name: "U64", Type: TypeInt { Bits: 64 } }, + 0x0006: Field { Name: "I8", Type: TypeInt { Bits: 8, Signed: true } }, + 0x0007: Field { Name: "I16", Type: TypeInt { Bits: 16, Signed: true } }, + 0x0008: Field { Name: "I32", Type: TypeInt { Bits: 32, Signed: true } }, + 0x0009: Field { Name: "I64", Type: TypeInt { Bits: 64, Signed: true } }, + 0x000B: Field { Name: "NI8", Type: TypeInt { Bits: 8, Signed: true } }, + 0x000C: Field { Name: "NI16",Type: TypeInt { Bits: 16, Signed: true } }, + 0x000D: Field { Name: "NI32",Type: TypeInt { Bits: 32, Signed: true } }, + 0x000E: Field { Name: "NI64",Type: TypeInt { Bits: 64, Signed: true } }, + }, + }, + } + exampleProtocol.Types["User"] = TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Name", Type: TypeString { } }, + 0x0001: Field { Name: "Bio", Type: TypeString { } }, + 0x0002: Field { Name: "Followers", Type: TypeInt { Bits: 32 } }, + }, + } +} + +func TestGenerateRunEncodeDecode(test *testing.T) { + testGenerateRun(test, &exampleProtocol, "encode-decode", ` + // imports + `, ` + log.Println("MessageConnect") + messageConnect := MessageConnect { + Name: "rarity", + Password: "gems", + } + testEncodeDecode( + &messageConnect, + tu.S(0xE1, 0x02).AddVar( + []byte { 0x00, 0x00, 0x86, 'r', 'a', 'r', 'i', 't', 'y' }, + []byte { 0x00, 0x01, 0x84, 'g', 'e', 'm', 's' }, + )) + log.Println("MessageUserList") + messageUserList := MessageUserList { + Users: []User { + User { + Name: "rarity", + Bio: "asdjads", + Followers: 0x324, + }, + User { + Name: "deez nuts", + Bio: "logy", + Followers: 0x8000, + }, + User { + Name: "creekflow", + Bio: "im creekflow", + Followers: 0x3894, + }, + }, + } + testEncodeDecode( + &messageUserList, + tu.S(0xE1, 0x01, 0x00, 0x00, + 0xC1, 0x03, 0xE1, + ).Add(0x03).AddVar( + []byte { 0x00, 0x00, 0x86, 'r', 'a', 'r', 'i', 't', 'y' }, + []byte { 0x00, 0x01, 0x87, 'a', 's', 'd', 'j', 'a', 'd', 's' }, + []byte { 0x00, 0x02, 0x23, 0x00, 0x00, 0x03, 0x24 }, + ).Add(0x03).AddVar( + []byte { 0x00, 0x00, 0x89, 'd', 'e', 'e', 'z', ' ', 'n', 'u', 't', 's' }, + []byte { 0x00, 0x01, 0x84, 'l', 'o', 'g', 'y' }, + []byte { 0x00, 0x02, 0x23, 0x00, 0x00, 0x80, 0x00 }, + ).Add(0x03).AddVar( + []byte { 0x00, 0x00, 0x89, 'c', 'r', 'e', 'e', 'k', 'f', 'l', 'o', 'w' }, + []byte { 0x00, 0x01, 0x8C, 'i', 'm', ' ', 'c', 'r', 'e', 'e', 'k', 'f', + 'l', 'o', 'w' }, + []byte { 0x00, 0x02, 0x23, 0x00, 0x00, 0x38, 0x94 }, + )) + log.Println("MessagePulse") + messagePulse := MessagePulse { + Index: 9, + Offset: -0x3521, + X: 45.375, + Y: 294.1, + Z: 384729384.234892034, + } + testEncodeDecode( + &messagePulse, + tu.S(0xE1, 0x05).AddVar( + []byte { 0x00, 0x00, 0x09 }, + []byte { 0x00, 0x01, 0x41, 0xCA, 0xDF }, + []byte { 0x00, 0x02, 0x61, 0x51, 0xAC }, + []byte { 0x00, 0x03, 0x63, 0x43, 0x93, 0x0C, 0xCD }, + []byte { 0x00, 0x04, 0x67, 0x41, 0xB6, 0xEE, 0x81, 0x28, 0x3C, 0x21, 0xE2 }, + )) + log.Println("MessageNestedArray") + uint8s := func(n int) []uint8 { + array := make([]uint8, n) + for index := range array { + array[index] = uint8(index + 1) | 0xF0 + } + return array + } + messageNestedArray := MessageNestedArray { + uint8s(6), + uint8s(35), + } + testEncodeDecode( + &messageNestedArray, + tu.S(0xC1, 0x02, 0xC1, + 0x06, 0x20, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, + 35, 0x20, 0xF1, 0xF2, 0xF3, 0xF4, 0xF5, 0xF6, + 0xF7, 0xF8, 0xF9, 0xFA, 0xFB, 0xFC, + 0xFD, 0xFE, 0xFF, 0xF0, 0xF1, 0xF2, + 0xF3, 0xF4, 0xF5, 0xF6, 0xF7, 0xF8, + 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, + 0xFF, 0xF0, 0xF1, 0xF2, 0xF3)) + log.Println("MessageIntegers") + messageIntegers := MessageIntegers { + U5: 0x13, + U8: 0xC9, + U16: 0x34C9, + U32: 0x10E134C9, + U64: 0x639109BC10E134C9, + I8: 0x35, + I16: 0x34C9, + I32: 0x10E134C9, + I64: 0x639109BC10E134C9, + NI8: -0x35, + NI16: -0x34C9, + NI32: -0x10E134C9, + NI64: -0x639109BC10E134C9, + } + testEncodeDecode( + &messageIntegers, + tu.S(0xE1, 13).AddVar( + []byte { 0x00, 0x00, 0x13 }, + []byte { 0x00, 0x01, 0x20, 0xC9 }, + []byte { 0x00, 0x02, 0x21, 0x34, 0xC9 }, + []byte { 0x00, 0x03, 0x23, 0x10, 0xE1, 0x34, 0xC9 }, + []byte { 0x00, 0x04, 0x27, 0x63, 0x91, 0x09, 0xBC, 0x10, 0xE1, 0x34, 0xC9 }, + []byte { 0x00, 0x06, 0x40, 0x35 }, + []byte { 0x00, 0x07, 0x41, 0x34, 0xC9 }, + []byte { 0x00, 0x08, 0x43, 0x10, 0xE1, 0x34, 0xC9 }, + []byte { 0x00, 0x09, 0x47, 0x63, 0x91, 0x09, 0xBC, 0x10, 0xE1, 0x34, 0xC9 }, + []byte { 0x00, 0x0B, 0x40, 0xCB }, + []byte { 0x00, 0x0C, 0x41, 0xCB, 0x37 }, + []byte { 0x00, 0x0D, 0x43, 0xEF, 0x1E, 0xCB, 0x37 }, + []byte { 0x00, 0x0E, 0x47, 0x9C, 0x6E, 0xF6, 0x43, 0xEF, 0x1E, 0xCB, 0x37 }, + )) + `) +} + +func TestGenerateRunDecodeWrongType(test *testing.T) { + protocol := defaultProtocol() + protocol.Messages[0x0000] = Message { + Name: "Uint5", + Type: TypeInt { Bits: 5 }, + } + protocol.Messages[0x0001] = Message { + Name: "Uint8", + Type: TypeInt { Bits: 8 }, + } + protocol.Messages[0x0002] = Message { + Name: "Uint16", + Type: TypeInt { Bits: 16 }, + } + protocol.Messages[0x0003] = Message { + Name: "Uint32", + Type: TypeInt { Bits: 32 }, + } + protocol.Messages[0x0004] = Message { + Name: "Uint64", + Type: TypeInt { Bits: 64 }, + } + protocol.Messages[0x0005] = Message { + Name: "Int8", + Type: TypeInt { Bits: 8 }, + } + protocol.Messages[0x0006] = Message { + Name: "Int16", + Type: TypeInt { Bits: 16 }, + } + protocol.Messages[0x0007] = Message { + Name: "Int32", + Type: TypeInt { Bits: 32 }, + } + protocol.Messages[0x0008] = Message { + Name: "Int64", + Type: TypeInt { Bits: 64 }, + } + protocol.Messages[0x0009] = Message { + Name: "String", + Type: TypeString { }, + } + protocol.Messages[0x000A] = Message { + Name: "Buffer", + Type: TypeBuffer { }, + } + protocol.Messages[0x000B] = Message { + Name: "StringArray", + Type: TypeArray { Element: TypeString { } }, + } + protocol.Messages[0x000C] = Message { + Name: "Table", + Type: TypeTable { }, + } + protocol.Messages[0x000D] = Message { + Name: "TableDefined", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Name", Type: TypeString { } }, + 0x0001: Field { Name: "Password", Type: TypeString { } }, + }, + }, + } + + testGenerateRun(test, &protocol, "decode-wrong-type", ` + // imports + `, ` + datas := [][]byte { + /* int8 */ []byte { byte(tape.LSI.WithCN(0)), 0x45 }, + /* int16 */ []byte { byte(tape.LSI.WithCN(1)), 0x45, 0x67 }, + /* int32 */ []byte { byte(tape.LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB }, + /* int64 */ []byte { byte(tape.LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 }, + /* uint5 */ []byte { byte(tape.SI.WithCN(12)) }, + /* uint8 */ []byte { byte(tape.LI.WithCN(0)), 0x45 }, + /* uint16 */ []byte { byte(tape.LI.WithCN(1)), 0x45, 0x67 }, + /* uint32 */ []byte { byte(tape.LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB }, + /* uint64 */ []byte { byte(tape.LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 }, + /* string */ []byte { byte(tape.SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' }, + /* []byte */ []byte { byte(tape.SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' }, + /* []string */ []byte { + byte(tape.OTA.WithCN(0)), 2, byte(tape.LBA.WithCN(0)), + 0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, + 0x05, 0x11, 0x11, 0x11, 0x11, 0x11, + }, + /* map[uint16] any */ []byte { + byte(tape.KTV.WithCN(0)), 2, + 0x02, 0x23, byte(tape.LSI.WithCN(1)), 0x45, 0x67, + 0x02, 0x23, byte(tape.LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB, + }, + } + + + for index, data := range datas { + log.Printf("data %2d %v [%s]", index, tape.Tag(data[0]), tu.HexBytes(data[1:])) + // integers should only assign to other integers + if index > 8 { + cas := func(destination Message) { + n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data))) + if err != nil { log.Fatalf("error: %v | n: %d", err, n) } + reflectValue := reflect.ValueOf(destination).Elem() + if reflectValue.CanInt() { + if reflectValue.Int() != 0 { + log.Fatalf( + "destination not zero: %v", + reflectValue.Elem().Interface()) + } + } else { + if reflectValue.Uint() != 0 { + log.Fatalf( + "destination not zero: %v", + reflectValue.Elem().Interface()) + } + } + if n != len(data) { + log.Fatalf("n not equal: %d != %d", n, len(data)) + } + } + log.Println("- MessageInt8") + { var dest MessageInt8; cas(&dest) } + log.Println("- MessageInt16") + { var dest MessageInt16; cas(&dest) } + log.Println("- MessageInt32") + { var dest MessageInt32; cas(&dest) } + log.Println("- MessageInt64") + { var dest MessageInt64; cas(&dest) } + log.Println("- MessageUint8") + { var dest MessageUint8; cas(&dest) } + log.Println("- MessageUint16") + { var dest MessageUint16; cas(&dest) } + log.Println("- MessageUint32") + { var dest MessageUint32; cas(&dest) } + log.Println("- MessageUint64") + { var dest MessageUint64; cas(&dest) } + } + arrayCase := func(destination Message) { + n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data)),) + if err != nil { log.Fatalf("error: %v | n: %d", err, n) } + reflectDestination := reflect.ValueOf(destination) + reflectValue := reflectDestination.Elem() + if reflectValue.Len() != 0 { + log.Fatalf("len(destination) not zero: %v", reflectValue.Interface()) + } + if n != len(data) { + log.Fatalf("n not equal: %d != %d", n, len(data)) + } + } + anyCase := func(destination Message) { + n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data)),) + if err != nil { log.Fatalf("error: %v | n: %d", err, n) } + reflectDestination := reflect.ValueOf(destination) + reflectValue := reflectDestination.Elem() + if reflectValue == reflect.Zero(reflectValue.Type()) { + log.Fatalf("len(destination) not zero: %v", reflectValue.Interface()) + } + if n != len(data) { + log.Fatalf("n not equal: %d != %d", n, len(data)) + } + } + // SBA/LBA types should only assign to other SBA/LBA types + if index != 9 && index != 10 { + log.Println("- MessageString") + { var dest MessageString; arrayCase(&dest) } + log.Println("- MessageBuffer") + { var dest MessageBuffer; arrayCase(&dest) } + } + // arrays should only assign to other arrays + if index != 11 { + log.Println("- MessageStringArray") + { var dest MessageStringArray; arrayCase(&dest) } + } + // tables should only assign to other tables + if index != 12 { + log.Println("- MessageTable") + { var dest = make(MessageTable); arrayCase(&dest) } + log.Println("- MessageTableDefined") + { var dest MessageTableDefined; anyCase(&dest) } + } + } + `) +} diff --git a/generate/lex.go b/generate/lex.go new file mode 100644 index 0000000..0d9aaf8 --- /dev/null +++ b/generate/lex.go @@ -0,0 +1,230 @@ +package generate + +import "io" +import "bufio" +import "unicode" +import "unicode/utf8" +import "git.tebibyte.media/sashakoshka/goparse" + +const ( + TokenMethod parse.TokenKind = iota + TokenKey + TokenIdent + TokenComma + TokenLBrace + TokenRBrace + TokenLBracket + TokenRBracket +) + +var tokenNames = map[parse.TokenKind] string { + TokenMethod: "Method", + TokenKey: "Key", + TokenIdent: "Ident", + TokenComma: "Comma", + TokenLBrace: "LBrace", + TokenRBrace: "RBrace", + TokenLBracket: "LBracket", + TokenRBracket: "RBracket", +} + +func Lex(fileName string, reader io.Reader) (parse.Lexer, error) { + lex := &lexer { + fileName: fileName, + lineScanner: bufio.NewScanner(reader), + } + lex.nextRune() + return lex, nil +} + +type lexer struct { + fileName string + lineScanner *bufio.Scanner + rune rune + line string + lineFood string + + offset int + row int + column int + + eof bool + +} + +func (this *lexer) Next() (parse.Token, error) { + token, err := this.nextInternal() + if err == io.EOF { err = this.errUnexpectedEOF() } + return token, err +} + +func (this *lexer) nextInternal() (token parse.Token, err error) { + err = this.skipWhitespace() + token.Position = this.pos() + if this.eof { + token.Kind = parse.EOF + err = nil + return + } + if err != nil { return } + + appendRune := func () { + token.Value += string(this.rune) + err = this.nextRune() + } + + doNumber := func () { + for isDigit(this.rune) { + appendRune() + if this.eof { err = nil; return } + if err != nil { return } + } + } + + defer func () { + newPos := this.pos() + newPos.End -- // TODO figure out why tf we have to do this + token.Position = token.Position.Union(newPos) + } () + + switch { + // Method + case this.rune == 'M': + token.Kind = TokenMethod + err = this.nextRune() + if err != nil { return } + doNumber() + if this.eof { err = nil; return } + // Key + case isDigit(this.rune): + token.Kind = TokenKey + doNumber() + if this.eof { err = nil; return } + // Ident + case unicode.IsUpper(this.rune): + token.Kind = TokenIdent + for unicode.IsLetter(this.rune) || isDigit(this.rune) { + appendRune() + if this.eof { err = nil; return } + if err != nil { return } + } + // Comma + case this.rune == ',': + token.Kind = TokenComma + appendRune() + if this.eof { err = nil; return } + // LBrace + case this.rune == '{': + token.Kind = TokenLBrace + appendRune() + if this.eof { err = nil; return } + // RBrace + case this.rune == '}': + token.Kind = TokenRBrace + appendRune() + if this.eof { err = nil; return } + // LBracket + case this.rune == '[': + token.Kind = TokenLBracket + appendRune() + if this.eof { err = nil; return } + // RBracket + case this.rune == ']': + token.Kind = TokenRBracket + appendRune() + if this.eof { err = nil; return } + case unicode.IsPrint(this.rune): + err = parse.Errorf ( + this.pos(), "unexpected rune '%c'", + this.rune) + default: + err = parse.Errorf ( + this.pos(), "unexpected rune %U", + this.rune) + } + + return +} + +func (this *lexer) nextRune() error { + if this.lineFood == "" { + ok := this.lineScanner.Scan() + if ok { + this.line = this.lineScanner.Text() + this.lineFood = this.line + this.rune = '\n' + this.column = 0 + this.row ++ + } else { + err := this.lineScanner.Err() + if err == nil { + this.eof = true + return io.EOF + } else { + return err + } + } + } else { + var ch rune + var size int + for ch == 0 && this.lineFood != "" { + ch, size = utf8.DecodeRuneInString(this.lineFood) + this.lineFood = this.lineFood[size:] + } + this.rune = ch + this.column ++ + } + + return nil +} + +func (this *lexer) skipWhitespace() error { + err := this.skipComment() + if err != nil { return err } + for isWhitespace(this.rune) { + err := this.nextRune() + if err != nil { return err } + err = this.skipComment() + if err != nil { return err } + } + return nil +} + +func (this *lexer) skipComment() error { + if this.rune == ';' { + for this.rune != '\n' { + err := this.nextRune() + if err != nil { return err } + } + } + return nil +} + +func (this *lexer) pos() parse.Position { + return parse.Position { + File: this.fileName, + Line: this.lineScanner.Text(), + Row: this.row - 1, + Start: this.column - 1, + End: this.column, + } +} + +func (this *lexer) errUnexpectedEOF() error { + return parse.Errorf(this.pos(), "unexpected EOF") +} + +func isWhitespace(char rune) bool { + switch char { + case ' ', '\t', '\r', '\n': return true + default: return false + } +} + +func isDigit(char rune) bool { + return char >= '0' && char <= '9' +} + +func isHexDigit(char rune) bool { + return isDigit(char) || char >= 'a' && char <= 'f' || char >= 'A' && char <= 'F' +} diff --git a/generate/lex_test.go b/generate/lex_test.go new file mode 100644 index 0000000..fc4a967 --- /dev/null +++ b/generate/lex_test.go @@ -0,0 +1,54 @@ +package generate + +import "strings" +import "testing" +import "git.tebibyte.media/sashakoshka/goparse" + +func TestLex(test *testing.T) { + lexer, err := Lex("test.pdl", strings.NewReader(` + M0001 User { + 0000 Name String, + 0001 Users []User, + 0002 Followers U32, + }`)) + if err != nil { test.Fatal(parse.Format(err)) } + + correctTokens := []parse.Token { + tok(TokenMethod, "0001"), + tok(TokenIdent, "User"), + tok(TokenLBrace, "{"), + tok(TokenKey, "0000"), + tok(TokenIdent, "Name"), + tok(TokenIdent, "String"), + tok(TokenComma, ","), + tok(TokenKey, "0001"), + tok(TokenIdent, "Users"), + tok(TokenLBracket, "["), + tok(TokenRBracket, "]"), + tok(TokenIdent, "User"), + tok(TokenComma, ","), + tok(TokenKey, "0002"), + tok(TokenIdent, "Followers"), + tok(TokenIdent, "U32"), + tok(TokenComma, ","), + tok(TokenRBrace, "}"), + tok(parse.EOF, ""), + } + + for index, correct := range correctTokens { + got, err := lexer.Next() + if err != nil { test.Fatal(parse.Format(err)) } + if got.Kind != correct.Kind || got.Value != correct.Value { + test.Logf("token %d mismatch", index) + test.Log("GOT:", tokenNames[got.Kind], got.Value) + test.Fatal("CORRECT:", tokenNames[correct.Kind], correct.Value) + } + } +} + +func tok(kind parse.TokenKind, value string) parse.Token { + return parse.Token { + Kind: kind, + Value: value, + } +} diff --git a/generate/misc_test.go b/generate/misc_test.go new file mode 100644 index 0000000..3e08fc8 --- /dev/null +++ b/generate/misc_test.go @@ -0,0 +1,131 @@ +package generate + +import "os" +import "fmt" +import "os/exec" +import "testing" +import "path/filepath" + +func testGenerateRun(test *testing.T, protocol *Protocol, title, imports, testCase string) { + // reset data directory + dir := filepath.Join("test", title) + err := os.RemoveAll(dir) + if err != nil { test.Fatal(err) } + err = os.MkdirAll(dir, 0750) + if err != nil { test.Fatal(err) } + + // open files + sourceFile, err := os.Create(filepath.Join(dir, "protocol.go")) + if err != nil { test.Fatal(err) } + defer sourceFile.Close() + mainFile, err := os.Create(filepath.Join(dir, "main.go")) + if err != nil { test.Fatal(err) } + defer mainFile.Close() + + // generate protocol + generator := Generator { + Output: sourceFile, + PackageName: "main", + } + _, err = generator.Generate(protocol) + if err != nil { test.Fatal(err) } + + // build static source files + imports = ` + import "log" + import "bytes" + import "reflect" + import "git.tebibyte.media/sashakoshka/hopp/tape" + import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil" + ` + imports + setup := `log.Println("*** BEGIN TEST CASE OUTPUT ***")` + teardown := `log.Println("--- END TEST CASE OUTPUT ---")` + static := ` + func testEncode(message Message, correct tu.Snake) { + buffer := bytes.Buffer { } + encoder := tape.NewEncoder(&buffer) + n, err := message.Encode(encoder) + if err != nil { log.Fatalf("at %d: %v\n", n, err) } + encoder.Flush() + got := buffer.Bytes() + log.Printf("got: [%s]", tu.HexBytes(got)) + log.Println("correct:", correct) + if n != len(got) { + log.Fatalf("n incorrect: %d != %d\n", n, len(got)) + } + if ok, n := correct.Check(got); !ok { + log.Fatalln("not equal at", n) + } + } + + func testDecode(correct Message, data any) { + var flat []byte + switch data := data.(type) { + case []byte: flat = data + case tu.Snake: flat = data.Flatten() + } + message := reflect.New(reflect.ValueOf(correct).Elem().Type()).Interface().(Message) + log.Println("before: ", message) + decoder := tape.NewDecoder(bytes.NewBuffer(flat)) + n, err := message.Decode(decoder) + if err != nil { log.Fatalf("at %d: %v\n", n, err) } + log.Println("got: ", message) + log.Println("correct:", correct) + if n != len(flat) { + log.Fatalf("n incorrect: %d != %d\n", n, len(flat)) + } + if !reflect.DeepEqual(message, correct) { + log.Fatalln("not equal") + } + } + + // TODO: possibly combine the two above functions into this one, + // also take a data parameter here (snake) + func testEncodeDecode(message Message, data tu.Snake) {buffer := bytes.Buffer { } + log.Println("encoding:") + encoder := tape.NewEncoder(&buffer) + n, err := message.Encode(encoder) + if err != nil { log.Fatalf("at %d: %v\n", n, err) } + encoder.Flush() + got := buffer.Bytes() + log.Printf("got: [%s]", tu.HexBytes(got)) + log.Println("correct:", data) + if n != len(got) { + log.Fatalf("n incorrect: %d != %d\n", n, len(got)) + } + if ok, n := data.Check(got); !ok { + log.Fatalln("not equal at", n) + } + + log.Println("decoding:") + destination := reflect.New(reflect.ValueOf(message).Elem().Type()).Interface().(Message) + flat := data.Flatten() + log.Println("before: ", destination) + decoder := tape.NewDecoder(bytes.NewBuffer(flat)) + n, err = destination.Decode(decoder) + if err != nil { log.Fatalf("at %d: %v\n", n, err) } + log.Println("got: ", destination) + log.Println("correct:", message) + if n != len(flat) { + log.Fatalf("n incorrect: %d != %d\n", n, len(flat)) + } + if !reflect.DeepEqual(destination, message) { + log.Fatalln("not equal") + } + + } + ` + fmt.Fprintf( + mainFile, "package main\n%s\nfunc main() {\n%s\n%s\n%s\n}\n%s", + imports, setup, testCase, teardown, static) + + // build and run test + command := exec.Command("go", "run", "./" + filepath.Join("generate", dir)) + workingDirAbs, err := filepath.Abs("..") + if err != nil { test.Fatal(err) } + command.Dir = workingDirAbs + command.Env = os.Environ() + output, err := command.CombinedOutput() + test.Logf("output of %v:\n%s", command, output) + if err != nil { test.Fatal(err) } +} diff --git a/generate/parse.go b/generate/parse.go new file mode 100644 index 0000000..e6e1ca3 --- /dev/null +++ b/generate/parse.go @@ -0,0 +1,207 @@ +package generate + +import "io" +import "strconv" +import "git.tebibyte.media/sashakoshka/goparse" + +func Parse(lx parse.Lexer) (*Protocol, error) { + protocol := defaultProtocol() + par := parser { + Parser: parse.Parser { + Lexer: lx, + TokenNames: tokenNames, + }, + protocol: &protocol, + } + err := par.parse() + if err != nil { return nil, err } + return par.protocol, nil +} + +func defaultProtocol() Protocol { + return Protocol { + Messages: make(map[uint16] Message), + Types: map[string] Type { }, + } +} + +func ParseReader(fileName string, reader io.Reader) (*Protocol, error) { + lx, err := Lex(fileName, reader) + if err != nil { return nil, err } + return Parse(lx) +} + +type parser struct { + parse.Parser + protocol *Protocol +} + +func (this *parser) parse() error { + err := this.Next() + if err != nil { return err } + for this.Token.Kind != parse.EOF { + err = this.parseTopLevel() + if err != nil { return err } + } + return nil +} + +func (this *parser) parseTopLevel() error { + err := this.ExpectDesc("message or typedef", TokenMethod, TokenIdent) + if err != nil { return err } + if this.EOF() { return nil } + + switch this.Kind() { + case TokenMethod: return this.parseMessage() + case TokenIdent: return this.parseTypedef() + } + panic("bug") +} + +func (this *parser) parseMessage() error { + err := this.Expect(TokenMethod) + if err != nil { return err } + method, err := this.parseHexNumber(this.Value(), 0xFFFF) + if err != nil { return err } + err = this.ExpectNext(TokenIdent) + if err != nil { return err } + name := this.Value() + err = this.Next() + if err != nil { return err } + typ, err := this.parseType() + if err != nil { return err } + this.protocol.Messages[uint16(method)] = Message { + Name: name, + Type: typ, + } + return nil +} + +func (this *parser) parseTypedef() error { + err := this.Expect(TokenIdent) + if err != nil { return err } + name := this.Value() + err = this.Next() + if err != nil { return err } + typ, err := this.parseType() + if err != nil { return err } + this.protocol.Types[name] = typ + return nil +} + +func (this *parser) parseType() (Type, error) { + err := this.ExpectDesc("type", TokenIdent, TokenLBracket, TokenLBrace) + if err != nil { return nil, err } + + switch this.Kind() { + case TokenIdent: + switch this.Value() { + case "U8": return TypeInt { Bits: 8 }, this.Next() + case "U16": return TypeInt { Bits: 16 }, this.Next() + case "U32": return TypeInt { Bits: 32 }, this.Next() + case "U64": return TypeInt { Bits: 64 }, this.Next() + case "U128": return TypeInt { Bits: 128 }, this.Next() + case "U256": return TypeInt { Bits: 256 }, this.Next() + case "I8": return TypeInt { Bits: 8, Signed: true }, this.Next() + case "I16": return TypeInt { Bits: 16, Signed: true }, this.Next() + case "I32": return TypeInt { Bits: 32, Signed: true }, this.Next() + case "I64": return TypeInt { Bits: 64, Signed: true }, this.Next() + case "I128": return TypeInt { Bits: 128, Signed: true }, this.Next() + case "I256": return TypeInt { Bits: 256, Signed: true }, this.Next() + case "F16": return TypeFloat { Bits: 16 }, this.Next() + case "F32": return TypeFloat { Bits: 32 }, this.Next() + case "F64": return TypeFloat { Bits: 64 }, this.Next() + case "F128": return TypeFloat { Bits: 128 }, this.Next() + case "F256": return TypeFloat { Bits: 256 }, this.Next() + case "String": return TypeString { }, this.Next() + case "Buffer": return TypeBuffer { }, this.Next() + case "Table": return TypeTable { }, this.Next() + } + return this.parseTypeNamed() + case TokenLBracket: + return this.parseTypeArray() + case TokenLBrace: + return this.parseTypeTable() + } + panic("bug") +} + +func (this *parser) parseTypeNamed() (TypeNamed, error) { + err := this.Expect(TokenIdent) + if err != nil { return TypeNamed { }, err } + name := this.Value() + err = this.Next() + if err != nil { return TypeNamed { }, err } + return TypeNamed { Name: name }, nil +} + +func (this *parser) parseTypeArray() (TypeArray, error) { + err := this.Expect(TokenLBracket) + if err != nil { return TypeArray { }, err } + err = this.ExpectNext(TokenRBracket) + if err != nil { return TypeArray { }, err } + err = this.Next() + if err != nil { return TypeArray { }, err } + typ, err := this.parseType() + if err != nil { return TypeArray { }, err } + return TypeArray { Element: typ }, nil +} + +func (this *parser) parseTypeTable() (TypeTableDefined, error) { + err := this.Expect(TokenLBrace) + if err != nil { return TypeTableDefined { }, err } + err = this.Next() + if err != nil { return TypeTableDefined { }, err } + typ := TypeTableDefined { + Fields: make(map[uint16] Field), + } + for { + err := this.ExpectDesc("table field", TokenKey, TokenRBrace) + if err != nil { return TypeTableDefined { }, err } + if this.Is(TokenRBrace) { + break + } + key, field, err := this.parseField() + if err != nil { return TypeTableDefined { }, err } + typ.Fields[key] = field + err = this.Expect(TokenComma, TokenRBrace) + if err != nil { return TypeTableDefined { }, err } + if this.Is(TokenRBrace) { + break + } + err = this.Next() + if err != nil { return TypeTableDefined { }, err } + } + err = this.Next() + if err != nil { return TypeTableDefined { }, err } + return typ, nil +} + +func (this *parser) parseField() (uint16, Field, error) { + err := this.Expect(TokenKey) + if err != nil { return 0, Field { }, err } + key, err := this.parseHexNumber(this.Value(), 0xFFFF) + if err != nil { return 0, Field { }, err } + err = this.ExpectNext(TokenIdent) + if err != nil { return 0, Field { }, err } + name := this.Value() + err = this.Next() + if err != nil { return 0, Field { }, err } + typ, err := this.parseType() + if err != nil { return 0, Field { }, err } + return uint16(key), Field { + Name: name, + Type: typ, + }, nil +} + +func (this *parser) parseHexNumber(input string, maxValue int64) (int64, error) { + number, err := strconv.ParseInt(input, 16, 64) + if err != nil { + return 0, parse.Errorf(this.Pos(), "%v", err) + } + if maxValue > 0 && number > maxValue { + return 0, parse.Errorf(this.Pos(), "value too large (max %X)", maxValue) + } + return number, nil +} diff --git a/generate/parse_test.go b/generate/parse_test.go new file mode 100644 index 0000000..a447ebb --- /dev/null +++ b/generate/parse_test.go @@ -0,0 +1,68 @@ +package generate + +import "fmt" +import "strings" +import "testing" +import "git.tebibyte.media/sashakoshka/goparse" + +func TestParse(test *testing.T) { + correct := defaultProtocol() + correct.Messages[0x0000] = Message { + Name: "Connect", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Name", Type: TypeString { } }, + 0x0001: Field { Name: "Password", Type: TypeString { } }, + }, + }, + } + correct.Messages[0x0001] = Message { + Name: "UserList", + Type: TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Users", Type: TypeArray { Element: TypeNamed { Name: "User" } } }, + }, + }, + } + correct.Types["User"] = TypeTableDefined { + Fields: map[uint16] Field { + 0x0000: Field { Name: "Name", Type: TypeString { } }, + 0x0001: Field { Name: "Bio", Type: TypeString { } }, + 0x0002: Field { Name: "Followers", Type: TypeInt { Bits: 32 } }, + }, + } + test.Log("CORRECT:", &correct) + + got, err := ParseReader("test.pdl", strings.NewReader(` + M0000 Connect { + 0000 Name String, + 0001 Password String, + } + + M0001 UserList { + 0000 Users []User, + } + + User { + 0000 Name String, + 0001 Bio String, + 0002 Followers U32, + } + `)) + if err != nil { test.Fatal(parse.Format(err)) } + test.Log("GOT: ", got) + + correctStr := fmt.Sprint(&correct) + gotStr := fmt.Sprint(got) + + if correctStr != gotStr { + test.Error("not equal") + for index := range min(len(correctStr), len(gotStr)) { + if correctStr[index] == gotStr[index] { continue } + test.Log("C:", correctStr[max(0, index - 8):min(len(correctStr), index + 8)]) + test.Log("G:", gotStr[max(0, index - 8):min(len(gotStr), index + 8)]) + break + } + test.FailNow() + } +} diff --git a/generate/protocol.go b/generate/protocol.go index 02fd456..1610c86 100644 --- a/generate/protocol.go +++ b/generate/protocol.go @@ -1,244 +1,107 @@ package generate -import "io" import "fmt" -import "errors" -import "strconv" -import "strings" -import "github.com/gomarkdown/markdown" -import "github.com/gomarkdown/markdown/ast" -import "github.com/gomarkdown/markdown/parser" +import "maps" +import "slices" +import "crypto/md5" -// Protocol describes a protocol. type Protocol struct { - Messages []Message + Messages map[uint16] Message + Types map[string] Type } -// Message describes a protocol message. type Message struct { - Doc string - Method uint16 - Name string - Fields []Field + Name string + Type Type +} + +type Type interface { + fmt.Stringer +} + +type TypeInt struct { + Bits int + Signed bool +} + +func (typ TypeInt) String() string { + output := "" + if typ.Signed { + output += "I" + } else { + output += "U" + } + output += fmt.Sprint(typ.Bits) + return output +} + +type TypeFloat struct { + Bits int +} + +func (typ TypeFloat) String() string { + return fmt.Sprintf("F%d", typ.Bits) +} + +type TypeString struct { } + +func (TypeString) String() string { + return "String" +} + +type TypeBuffer struct { } + +func (TypeBuffer) String() string { + return "Buffer" +} + +type TypeArray struct { + Element Type +} + +func (typ TypeArray) String() string { + return fmt.Sprintf("[]%v", typ.Element) +} + +type TypeTable struct { } + +func (TypeTable) String() string { + return "Table" +} + +type TypeTableDefined struct { + Fields map[uint16] Field +} + +func (typ TypeTableDefined) String() string { + output := "{" + for _, key := range slices.Sorted(maps.Keys(typ.Fields)) { + output += fmt.Sprintf("%04X %v", key, typ.Fields[key]) + } + output += "}" + return output } -// Field describes a named value within a message. type Field struct { - Doc string - Tag uint16 - Name string - Optional bool - Type string + Name string + Type Type } -// ParseReader parses a protocol definition from a reader. -func ParseReader(reader io.Reader) (*Protocol, error) { - data, err := io.ReadAll(reader) - if err != nil { return nil, err } - protocol := new(Protocol) - err = protocol.UnmarshalText(data) - if err != nil { return nil, err } - return protocol, nil +func (field Field) String() string { + return fmt.Sprintf("%s %v", field.Name, field.Type) } -// UnmarshalText unmarshals markdown-formatted text data into the protocol. -func (this *Protocol) UnmarshalText(text []byte) error { - var state int; const ( - stateIdle = iota - stateMessage - stateMessageDoc - stateMessageField - ) - - var message *Message - addMessage := func(method uint16, name string) { - this.Messages = append(this.Messages, Message { - Method: method, - Name: name, - }) - message = &this.Messages[len(this.Messages) - 1] - } - - root := markdown.Parse(text, parser.New()) - for _, node := range root.GetChildren() { - if node, ok := node.(*ast.Heading); ok { - if node.Level == 2 { - if removeBreaks(flatten(node)) == "Messages" { - state = stateMessage - continue - } - } - - if node.Level > 3 { - state = stateIdle - continue - } - - if state != stateIdle && node.Level == 3 { - heading := removeBreaks(flatten(node)) - method, name, err := splitMessageHeading(heading) - if err != nil { return err } - addMessage(method, name) - state = stateMessageDoc - } - } - - if state == stateIdle { continue } - if message == nil { continue } - - // TODO when we are adding text content to the doc comment, it - // might be wise to do stuff like indent lists and quotes so - // that go doc renders them correctly - switch node := node.(type) { - case *ast.Paragraph: - if message.Doc != "" { message.Doc += "\n\n" } - message.Doc += removeBreaks(flatten(node)) - case *ast.BlockQuote: - if message.Doc != "" { message.Doc += "\n\n> " } - message.Doc += removeBreaks(flatten(node)) - case *ast.List: - // FIXME format the list - if message.Doc != "" { message.Doc += "\n\n" } - message.Doc += removeBreaks(flatten(node)) - case *ast.Table: - fields, err := processFieldTable(node) - if err != nil { return err} - message.Fields = append(message.Fields, fields...) - } - } - - return nil +type TypeNamed struct { + Name string } -func processFieldTable(node *ast.Table) ([]Field, error) { - fields := []Field { } - children := node.GetChildren() - if len(children) != 2 { - return nil, errors.New("malformed field table") - } - - // get columns - columns := []string { } - if header, ok := children[0].(*ast.TableHeader); ok { - children := header.GetChildren() - if len(children) != 1 { - return nil, errors.New("malformed field table header") - } - if row, ok := header.Children[0].(*ast.TableRow); ok { - for _, cell := range row.GetChildren() { - if cell, ok := cell.(*ast.TableCell); ok { - columns = append(columns, flatten(cell)) - } - } - } else { - return nil, errors.New("malformed field table header") - } - for index, column := range columns { - columns[index] = strings.ToLower(column) - } - } else { - return nil, errors.New("malformed field table: no header") - } - - // get data - if body, ok := children[1].(*ast.TableBody); ok { - for _, node := range body.GetChildren() { - if row, ok := node.(*ast.TableRow); ok { - children := row.GetChildren() - if len(children) != len(columns) { - return nil, errors.New ( - "malformed field table row: wrong " + - "number of columns") - } - - field := Field { } - - for index, node := range children { - if cell, ok := node.(*ast.TableCell); ok { - text := flatten(cell) - switch columns[index] { - case "tag": - tag, err := parseTag(text) - if err != nil { return nil, err } - field.Tag = tag - case "name": - field.Name = text - case "required": - field.Optional = !parseBool(text) - case "optional": - field.Optional = parseBool(text) - case "type": - field.Type = text - case "comment", "purpose", "documentation": - field.Doc = text - } - }} - - fields = append(fields, field) - }} - } else { - return nil, errors.New("malformed field table: no body") - } - return fields, nil +func (typ TypeNamed) String() string { + return typ.Name } -type nodeFlattener struct { - text string -} -func (this *nodeFlattener) String() string { return this.text } -func (this *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus { - if entering { - if node := node.AsLeaf(); node != nil { - this.text += string(node.Literal) - } - } - return ast.GoToNext -} -func flatten(node ast.Node) string { - flattener := new(nodeFlattener) - ast.Walk(node, flattener) - return flattener.text -} - - -func removeBreaks(text string) string { - text = strings.ReplaceAll(text, "\n", " ") - text = strings.ReplaceAll(text, "\r", "") - return text -} - -func parseBool(text string) bool { - switch(strings.ToLower(text)) { - case "yes": return true - case "no": return false - case "true": return true - case "false": return false - } - return false -} - -func parseTag(text string) (uint16, error) { - tag, err := strconv.ParseUint(text, 10, 16) - if err != nil { - return 0, fmt.Errorf("malformed tag '%s': %w", text, err) - } - return uint16(tag), nil -} - -func splitMessageHeading(text string) (uint16, string, error) { - text = strings.TrimSpace(text) - methodText, name, ok := strings.Cut(text, " ") - if !ok { - return 0, "", fmt.Errorf( - "malformed message heading '%s': no message name", - text) - } - method, err := strconv.ParseUint(methodText, 16, 16) - if err != nil { - return 0, "", fmt.Errorf( - "malformed method number '%s': %w", - methodText, err) - } - name = strings.TrimSpace(name) - return uint16(method), name, nil +func HashType(typ Type) [16]byte { + // TODO: if we ever want to make the compiler more efficient, this would + // be a good place to start, complex string concatenation in a hot path + // (sorta) + return md5.Sum([]byte(typ.String())) } diff --git a/go.mod b/go.mod index 1acc120..af2d24c 100644 --- a/go.mod +++ b/go.mod @@ -4,19 +4,5 @@ go 1.23.0 require ( git.tebibyte.media/sashakoshka/go-util v0.9.1 - github.com/gomarkdown/markdown v0.0.0-20241205020045-f7e15b2f3e62 - github.com/quic-go/quic-go v0.48.2 -) - -require ( - github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect - github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect - github.com/onsi/ginkgo/v2 v2.9.5 // indirect - go.uber.org/mock v0.4.0 // indirect - golang.org/x/crypto v0.26.0 // indirect - golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect - golang.org/x/mod v0.17.0 // indirect - golang.org/x/net v0.28.0 // indirect - golang.org/x/sys v0.23.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + git.tebibyte.media/sashakoshka/goparse v0.2.0 ) diff --git a/go.sum b/go.sum index 2f2e05a..e3ad650 100644 --- a/go.sum +++ b/go.sum @@ -1,60 +1,4 @@ git.tebibyte.media/sashakoshka/go-util v0.9.1 h1:eGAbLwYhOlh4aq/0w+YnJcxT83yPhXtxnYMzz6K7xGo= git.tebibyte.media/sashakoshka/go-util v0.9.1/go.mod h1:0Q1t+PePdx6tFYkRuJNcpM1Mru7wE6X+it1kwuOH+6Y= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= -github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/gomarkdown/markdown v0.0.0-20241205020045-f7e15b2f3e62 h1:pbAFUZisjG4s6sxvRJvf2N7vhpCvx2Oxb3PmS6pDO1g= -github.com/gomarkdown/markdown v0.0.0-20241205020045-f7e15b2f3e62/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= -github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= -github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= -github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= -github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= -github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= -golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +git.tebibyte.media/sashakoshka/goparse v0.2.0 h1:uQmKvOCV2AOlCHEDjg9uclZCXQZzq2PxaXfZ1aIMiQI= +git.tebibyte.media/sashakoshka/goparse v0.2.0/go.mod h1:tSQwfuD+EujRoKr6Y1oaRy74ZynatzkRLxjE3sbpCmk= diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..1a3addd --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,173 @@ +package testutil + +import "fmt" +import "slices" +import "strings" +import "reflect" + +// Snake lets you compare blocks of data where the ordering of certain parts may +// be swapped every which way. It is designed for comparing the encoding of +// maps where the ordering of individual elements is inconsistent. +// +// The snake is divided into sectors, which hold a number of variations. For a +// sector to be satisfied by some data, some ordering of it must match the data +// exactly. for the snake to be satisfied by some data, its sectors must match +// the data in order, but the internal ordering of each sector doesn't matter. +type Snake [] [] []byte +// snake sector variation + +// S returns a new snake. +func S(data ...byte) Snake { + return (Snake { }).Add(data...) +} + +// AddVar returns a new snake with the given sector added on to it. Successive +// calls of this method can be chained together to create a big ass snake. +func (sn Snake) AddVar(sector ...[]byte) Snake { + slice := make(Snake, len(sn) + 1) + copy(slice, sn) + slice[len(slice) - 1] = sector + return slice +} + +// Add is like AddVar, but adds a sector with only one variation, which means it +// does not vary, hence why the method is called that. +func (sn Snake) Add(data ...byte) Snake { + return sn.AddVar(data) +} + +// Check determines if the data satisfies the snake. +func (sn Snake) Check(data []byte) (ok bool, n int) { + left := data + variations := map[int] []byte { } + for _, sector := range sn { + clear(variations) + for key, variation := range sector { + variations[key] = variation + } + for len(variations) > 0 { + found := false + for key, variation := range variations { + if len(left) < len(variation) { continue } + if !slices.Equal(left[:len(variation)], variation) { continue } + n += len(variation) + left = data[n:] + delete(variations, key) + found = true + } + if !found { return false, n } + } + } + if n < len(data) { + return false, n + } + return true, n +} + +// Flatten returns the snake flattened to a byte array. The result of this +// function always satisfies the snake. +func (sn Snake) Flatten() []byte { + flat := []byte { } + for _, sector := range sn { + for _, variation := range sector { + flat = append(flat, variation...) + } + } + return flat +} + +func (sn Snake) String() string { + if len(sn) == 0 || len(sn[0]) == 0 || len(sn[0][0]) == 0 { + return "EMPTY" + } + + out := strings.Builder { } + for index, sector := range sn { + if index > 0 { out.WriteString(" : ") } + out.WriteRune('[') + for index, variation := range sector { + if index > 0 { out.WriteString(" / ") } + for _, byt := range variation { + fmt.Fprintf(&out, "%02x", byt) + } + } + out.WriteRune(']') + } + return out.String() +} + +// HexBytes formats bytes into a hexadecimal string. +func HexBytes(data []byte) string { + if len(data) == 0 { return "EMPTY" } + out := strings.Builder { } + for _, byt := range data { + fmt.Fprintf(&out, "%02x", byt) + } + return out.String() +} + +// Describe returns a string representing the type and data of the given value. +func Describe(value any) string { + desc := describer { } + desc.describe(reflect.ValueOf(value)) + return desc.String() +} + +type describer struct { + strings.Builder + indent int +} + +func (this *describer) describe(value reflect.Value) { + value = reflect.ValueOf(value.Interface()) + switch value.Kind() { + case reflect.Array, reflect.Slice: + this.printf("[\n") + this.indent += 1 + for index := 0; index < value.Len(); index ++ { + this.iprintf("") + this.describe(value.Index(index)) + this.iprintf("\n") + } + this.indent -= 1 + this.iprintf("]") + case reflect.Struct: + this.printf("struct {\n") + this.indent += 1 + typ := value.Type() + for index := range typ.NumField() { + indexBuffer := [1]int { index } + this.iprintf("%s: ", typ.Field(index).Name) + this.describe(value.FieldByIndex(indexBuffer[:])) + this.iprintf("\n") + } + this.indent -= 1 + this.iprintf("}\n") + case reflect.Map: + this.printf("map {\n") + this.indent += 1 + iter := value.MapRange() + for iter.Next() { + this.iprintf("") + this.describe(iter.Key()) + this.printf(": ") + this.describe(iter.Value()) + this.iprintf("\n") + } + this.indent -= 1 + this.iprintf("}\n") + case reflect.Pointer: + this.printf("& ") + this.describe(value.Elem()) + default: + this.printf("<%v %v>", value.Type(), value.Interface()) + } +} + +func (this *describer) printf(format string, v ...any) { + fmt.Fprintf(this, format, v...) +} + +func (this *describer) iprintf(format string, v ...any) { + fmt.Fprintf(this, strings.Repeat("\t", this.indent) + format, v...) +} diff --git a/internal/testutil/testutil_test.go b/internal/testutil/testutil_test.go new file mode 100644 index 0000000..663831a --- /dev/null +++ b/internal/testutil/testutil_test.go @@ -0,0 +1,66 @@ +package testutil + +import "testing" + +func TestSnakeA(test *testing.T) { + snake := S(1, 6).AddVar( + []byte { 1 }, + []byte { 2 }, + []byte { 3 }, + []byte { 4 }, + []byte { 5 }, + ).Add(9) + + test.Log(snake) + + ok, n := snake.Check([]byte { 1, 6, 1, 2, 3, 4, 5, 9 }) + if !ok { test.Fatal("false negative:", n) } + ok, n = snake.Check([]byte { 1, 6, 5, 4, 3, 2, 1, 9 }) + if !ok { test.Fatal("false negative:", n) } + ok, n = snake.Check([]byte { 1, 6, 3, 1, 4, 2, 5, 9 }) + if !ok { test.Fatal("false negative:", n) } + + ok, n = snake.Check([]byte { 1, 6, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 1, 2, 3, 4, 5, 6, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 0, 2, 3, 4, 5, 6, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 7, 1, 4, 2, 5, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 7, 3, 1, 4, 2, 5, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 7, 3, 1, 4, 2, 5, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 1, 2, 3, 4, 5, 9, 10}) + if ok { test.Fatal("false positive:", n) } +} + +func TestSnakeB(test *testing.T) { + snake := S(1, 6).AddVar( + []byte { 1 }, + []byte { 2 }, + ).Add(9).AddVar( + []byte { 3, 2 }, + []byte { 0 }, + []byte { 1, 1, 2, 3 }, + ) + + test.Log(snake) + + ok, n := snake.Check([]byte { 1, 6, 1, 2, 9, 3, 2, 0, 1, 1, 2, 3}) + if !ok { test.Fatal("false negative:", n) } + ok, n = snake.Check([]byte { 1, 6, 2, 1, 9, 0, 1, 1, 2, 3, 3, 2}) + if !ok { test.Fatal("false negative:", n) } + + ok, n = snake.Check([]byte { 1, 6, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 1, 2, 9 }) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 9, 3, 2, 0, 1, 1, 2, 3}) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 2, 9, 0, 1, 1, 2, 3, 3, 2}) + if ok { test.Fatal("false positive:", n) } + ok, n = snake.Check([]byte { 1, 6, 1, 2, 9, 3, 2, 1, 1, 2, 3}) + if ok { test.Fatal("false positive:", n) } +} diff --git a/listen.go b/listen.go index 09f1a03..4c0681c 100644 --- a/listen.go +++ b/listen.go @@ -1,9 +1,8 @@ package hopp import "net" -import "context" +import "errors" import "crypto/tls" -import "github.com/quic-go/quic-go" // Listener is an object which listens for incoming HOPP connections. type Listener interface { @@ -17,7 +16,8 @@ type Listener interface { } // Listen listens for incoming HOPP connections. The network must be one of -// "quic", "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". +// "quic", "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, quic is +// not supported. func Listen(network, address string) (Listener, error) { switch network { case "quic", "quic4", "quic6": return ListenQUIC(network, address, nil) @@ -30,19 +30,8 @@ func Listen(network, address string) (Listener, error) { // The network must be one of "quic", "quic4", (IPv4-only) or "quic6" // (IPv6-only). func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error) { - tlsConf = tlsConfig(tlsConf) - quicConf := quicConfig() - udpNetwork, err := quicNetworkToUDPNetwork(network) - if err != nil { return nil, err } - addr, err := net.ResolveUDPAddr(udpNetwork, address) - if err != nil { return nil, err } - udpListener, err := net.ListenUDP(udpNetwork, addr) - if err != nil { return nil, err } - quicListener, err := quic.Listen(udpListener, tlsConf, quicConf) - if err != nil { return nil, err } - return &listenerQUIC { - underlying: quicListener, - }, nil + // tlsConf = tlsConfig(tlsConf) + return nil, errors.New("quic is not yet implemented") } // ListenUnix listens for incoming HOPP connections using a Unix domain socket @@ -58,24 +47,6 @@ func ListenUnix(network, address string) (Listener, error) { }, nil } -type listenerQUIC struct { - underlying *quic.Listener -} - -func (this *listenerQUIC) Accept() (Conn, error) { - conn, err := this.underlying.Accept(context.Background()) - if err != nil { return nil, err } - return AdaptB(quicMultiConn { underlying: conn }), nil -} - -func (this *listenerQUIC) Close() error { - return this.underlying.Close() -} - -func (this *listenerQUIC) Addr() net.Addr { - return this.underlying.Addr() -} - type listenerUnix struct { underlying *net.UnixListener } diff --git a/message.go b/message.go deleted file mode 100644 index 3eaaa6c..0000000 --- a/message.go +++ /dev/null @@ -1,52 +0,0 @@ -package hopp - -import "fmt" -import "encoding" -import "git.tebibyte.media/sashakoshka/hopp/tape" - -// Message is any object that can be sent or received over a HOPP connection. -type Message interface { - // Method returns the method number of the message. This must be unique - // within the protocol, and should not change between calls. - Method() uint16 - encoding.BinaryMarshaler - encoding.BinaryUnmarshaler -} - -var _ Message = new(MessageData) - -// MessageData represents a message that organizes its data into table pairs. It -// can be used to alter a protocol at runtime, transmit data with arbitrary -// keys, etc. Bear in mind that is less performant than generating code because -// it has to make extra memory allocations and such. -type MessageData struct { - // Methd holds the method number. This should only be set once. - Methd uint16 - // Pairs maps tags to values. - Pairs map[uint16] []byte -} - -// Method returns the message's method field. -func (this *MessageData) Method() uint16 { - return this.Methd -} - -// MarshalBinary implements the [encoding.BinaryMarshaler] interface. The -// message is encoded using TAPE (Table Pair Encoding). -func (this *MessageData) MarshalBinary() ([]byte, error) { - buffer, err := tape.EncodePairs(this.Pairs) - if err != nil { return nil, fmt.Errorf("marshaling MessageData: %w", err) } - return buffer, nil -} - -// UnmarshalBinary implements the [encoding.BinaryUnmarshaler] interface. The -// message is decoded using TAPE (Table Pair Encoding). -func (this *MessageData) UnmarshalBinary(buffer []byte) error { - this.Pairs = make(map[uint16] []byte) - pairs, err := tape.DecodePairs(buffer) - if err != nil { return fmt.Errorf("unmarshaling MessageData: %w", err) } - for key, value := range pairs { - this.Pairs[key] = value - } - return nil -} diff --git a/metadapta.go b/metadapta.go index 1879b15..cae8b12 100644 --- a/metadapta.go +++ b/metadapta.go @@ -1,14 +1,20 @@ package hopp import "io" +import "os" import "fmt" import "net" import "sync" -import "git.tebibyte.media/sashakoshka/hopp/tape" +import "time" +import "sync/atomic" import "git.tebibyte.media/sashakoshka/go-util/sync" +// TODO investigate why 30 never reaches the server, causing it to wait for ever +// and never close the connection, causing the client to also wait forever + const closeMethod = 0xFFFF const int64Max = int64((^uint64(0)) >> 1) +const defaultChunkSize = 0x1000 // Party represents a side of a connection. type Party bool; const ( @@ -16,7 +22,16 @@ type Party bool; const ( ClientSide Party = true ) +func (party Party) String() string { + if party == ServerSide { + return "server" + } else { + return "client" + } +} + type a struct { + sizeLimit int64 underlying net.Conn party Party transID int64 @@ -32,6 +47,7 @@ type a struct { // oriented transport such as TCP or UNIX domain stream sockets. func AdaptA(underlying net.Conn, party Party) Conn { conn := &a { + sizeLimit: defaultSizeLimit, underlying: underlying, party: party, transMap: make(map[int64] *transA), @@ -49,7 +65,7 @@ func AdaptA(underlying net.Conn, party Party) Conn { func (this *a) Close() error { close(this.done) - return this.underlying.Close() + return nil } func (this *a) LocalAddr() net.Addr { @@ -63,30 +79,45 @@ func (this *a) RemoteAddr() net.Addr { func (this *a) OpenTrans() (Trans, error) { this.transLock.Lock() defer this.transLock.Unlock() + if this.transID == int64Max { + return nil, fmt.Errorf("could not open transaction: %w", ErrIntegerOverflow) + } id := this.transID - this.transID ++ trans := &transA { parent: this, id: id, incoming: usync.NewGate[incomingMessage](), } this.transMap[id] = trans - if this.transID == int64Max { - return nil, fmt.Errorf("could not open transaction: %w", ErrIntegerOverflow) + if this.party == ClientSide { + this.transID ++ + } else { + this.transID -- } - this.transID ++ return trans, nil } func (this *a) AcceptTrans() (Trans, error) { + eof := fmt.Errorf("could not accept transaction: %w", io.EOF) select { case trans := <- this.transChan: + if trans == nil { + return nil, eof + } return trans, nil case <- this.done: - return nil, fmt.Errorf("could not accept transaction: %w", io.EOF) + return nil, eof } } +func (this *a) SetDeadline(t time.Time) error { + return this.underlying.SetDeadline(t) +} + +func (this *a) SetSizeLimit(limit int64) { + this.sizeLimit = limit +} + func (this *a) unlistTransactionSafe(id int64) { this.transLock.Lock() defer this.transLock.Unlock() @@ -96,27 +127,32 @@ func (this *a) unlistTransactionSafe(id int64) { func (this *a) sendMessageSafe(trans int64, method uint16, data []byte) error { this.sendLock.Lock() defer this.sendLock.Unlock() - return encodeMessageA(this.underlying, trans, method, data) + return encodeMessageA(this.underlying, this.sizeLimit, trans, method, data) } func (this *a) receive() { defer func() { this.underlying.Close() + close(this.transChan) this.transLock.Lock() defer this.transLock.Unlock() for _, trans := range this.transMap { trans.closeDontUnlist() } clear(this.transMap) + this.underlying.Close() }() + + // receive MMBs in a loop and forward them to transactions until shit + // starts closing for { - transID, method, payload, err := decodeMessageA(this.underlying) + transID, method, chunked, payload, err := decodeMessageA(this.underlying, this.sizeLimit) if err != nil { this.err = fmt.Errorf("could not receive message: %w", err) return } - err = this.receiveMultiplex(transID, method, payload) + err = this.multiplexMMB(transID, method, chunked, payload) if err != nil { this.err = fmt.Errorf("could not receive message: %w", err) return @@ -124,7 +160,7 @@ func (this *a) receive() { } } -func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) error { +func (this *a) multiplexMMB(transID int64, method uint16, chunked bool, payload []byte) error { if transID == 0 { return ErrMessageMalformed } trans, err := func() (*transA, error) { @@ -133,6 +169,12 @@ func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) er trans, ok := this.transMap[transID] if !ok { + // check if this is a superfluous close message and just + // do nothing if so + if method == closeMethod { + return nil, nil + } + // it is forbidden for the other party to initiate a transaction // with an ID from this party if this.party == partyFromTransID(transID) { @@ -150,28 +192,58 @@ func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) er }() if err != nil { return err } - trans.incoming.Send(incomingMessage { - method: method, - payload: payload, - }) + if trans == nil { + return nil + } + + if method == closeMethod { + return trans.Close() + } else { + trans.incoming.Send(incomingMessage { + method: method, + chunked: chunked, + payload: payload, + }) + } return nil } +// most methods in transA don't need to be goroutine safe except those marked +// as such type transA struct { - parent *a - id int64 - incoming usync.Gate[incomingMessage] + parent *a + id int64 + incoming usync.Gate[incomingMessage] + currentReader io.Reader + currentWriter io.Closer + writeBuffer []byte + closed atomic.Bool + closeErr error + + deadline *time.Timer + deadlineLock sync.Mutex } func (this *transA) Close() error { + // MUST be goroutine safe err := this.closeDontUnlist() this.parent.unlistTransactionSafe(this.ID()) return err } -func (this *transA) closeDontUnlist() error { - this.Send(closeMethod, nil) - return this.incoming.Close() +func (this *transA) closeWithError(err error) error { + this.closeErr = err + return this.Close() +} + +func (this *transA) closeDontUnlist() (err error) { + // MUST be goroutine safe + this.incoming.Close() + if !this.closed.Load() { + err = this.Send(closeMethod, nil) + } + this.closed.Store(true) + return err } func (this *transA) ID() int64 { @@ -182,58 +254,257 @@ func (this *transA) Send(method uint16, data []byte) error { return this.parent.sendMessageSafe(this.id, method, data) } +func (this *transA) SendWriter(method uint16) (io.WriteCloser, error) { + // close previous writer if necessary + if this.currentWriter != nil { + this.currentWriter.Close() + this.currentWriter = nil + } + + // create new writer + writer := &writerA { + parent: this, + // there is only ever one writer at a time, so they can all + // share a buffer + buffer: this.writeBuffer[:0], + method: method, + chunkSize: defaultChunkSize, + open: true, + } + this.currentWriter = writer + return writer, nil +} + func (this *transA) Receive() (method uint16, data []byte, err error) { - receive := this.incoming.Receive() + method, reader, err := this.ReceiveReader() + if err != nil { return 0, nil, err } + data, err = io.ReadAll(reader) + if err != nil { return 0, nil, err } + return method, data, nil +} + +func (this *transA) ReceiveReader() (uint16, io.Reader, error) { + // if the transaction has been closed, return an appropriate error. + if err := this.errIfClosed(); err != nil { + return 0, nil, err + } + + // drain previous reader if necessary + if this.currentReader != nil { + io.Copy(io.Discard, this.currentReader) + } + + // create new reader + reader := &readerA { + parent: this, + } + method, err := reader.pull() + if err != nil { return 0, nil, err} + this.currentReader = reader + return method, reader, nil +} + +func (this *transA) SetDeadline(t time.Time) error { + this.deadlineLock.Lock() + defer this.deadlineLock.Unlock() + + if t == (time.Time { }) { + if this.deadline != nil { + this.deadline.Stop() + } + return nil + } + + until := time.Until(t) + if this.deadline == nil { + this.deadline.Reset(until) + return nil + } + this.deadline = time.AfterFunc(until, func () { + this.closeWithError(os.ErrDeadlineExceeded) + }) + return nil +} + +// TODO +// func (this *transA) SetReadDeadline(t time.Time) error { +// // TODO +// } +// +// func (this *transA) SetWriteDeadline(t time.Time) error { +// // TODO +// } + +func (this *transA) errIfClosed() error { + if !this.closed.Load() { + return nil + } + return this.bestErr() +} + +func (this *transA) bestErr() error { + if this.parent.err != nil { + return this.parent.err + } + if this.closeErr != nil { + return this.closeErr + } + return io.EOF +} + +type readerA struct { + parent *transA + leftover []byte + eof bool +} + +// pull pulls the next MMB in this message from the transaction. +func (this *readerA) pull() (uint16, error) { + // if the previous message ended the chain, return an io.EOF + if this.eof { + return 0, io.EOF + } + + // get an MMB from the transaction we are a part of + receive := this.parent.incoming.Receive() if receive != nil { if message, ok := <- receive; ok { if message.method != closeMethod { - return message.method, message.payload, nil + this.leftover = append(this.leftover, message.payload...) + if !message.chunked { + this.eof = true + } + return message.method, nil } } } // close and return error on failure - this.Close() - if this.parent.err == nil { - return 0, nil, fmt.Errorf("could not receive message: %w", io.EOF) - } else { - return 0, nil, this.parent.err + this.eof = true + this.parent.Close() + return 0, fmt.Errorf("could not receive message: %w", this.parent.bestErr()) +} + +func (this *readerA) Read(buffer []byte) (int, error) { + if len(this.leftover) == 0 { + if this.eof { return 0, io.EOF } + this.pull() } + + copied := copy(buffer, this.leftover) + this.leftover = this.leftover[copied:] + return copied, nil +} + +type writerA struct { + parent *transA + buffer []byte + method uint16 + chunkSize int64 + open bool +} + +func (this *writerA) Write(data []byte) (n int, err error) { + if !this.open { return 0, io.EOF } + toSend := data + for len(toSend) > 0 { + nn, err := this.writeOne(toSend) + n += nn + toSend = toSend[nn:] + if err != nil { return n, err } + } + return n, nil +} + +func (this *writerA) Close() error { + this.open = false + return nil +} + +func (this *writerA) writeOne(data []byte) (n int, err error) { + data = data[:min(len(data), int(this.chunkSize))] + + // if there is more room, append to the buffer and exit + if int64(len(this.buffer) + len(data)) <= this.chunkSize { + this.buffer = append(this.buffer, data...) + n = len(data) + // if have a full chunk, flush + if int64(len(this.buffer)) == this.chunkSize { + err = this.flush() + if err != nil { return n, err } + } + return n, nil + } + + // if not, flush and store as much as we can in the buffer + err = this.flush() + if err != nil { return n, err } + this.buffer = append(this.buffer, data...) + return n, nil +} + +func (this *writerA) flush() error { + return this.parent.parent.sendMessageSafe(this.parent.id, this.method, this.buffer) } type incomingMessage struct { method uint16 + chunked bool payload []byte } -func encodeMessageA(writer io.Writer, trans int64, method uint16, data []byte) error { - buffer := make([]byte, 12 + len(data)) - tape.EncodeI64(buffer[:8], trans) - tape.EncodeI16(buffer[8:10], method) - length, ok := tape.U16CastSafe(len(data)) - if !ok { return ErrPayloadTooLarge } - tape.EncodeI16(buffer[10:12], length) - copy(buffer[12:], data) +func encodeMessageA( + writer io.Writer, + sizeLimit int64, + trans int64, + method uint16, + data []byte, +) error { + if int64(len(data)) > sizeLimit { + return ErrPayloadTooLarge + } + buffer := make([]byte, 18 + len(data)) + encodeI64(buffer[:8], trans) + encodeI16(buffer[8:10], method) + encodeI64(buffer[10:18], uint64(len(data))) + copy(buffer[18:], data) _, err := writer.Write(buffer) return err } -func decodeMessageA(reader io.Reader) (int64, uint16, []byte, error) { - headerBuffer := [12]byte { } - _, err := io.ReadFull(reader, headerBuffer[:]) - if err != nil { return 0, 0, nil, err } - transID, err := tape.DecodeI64[int64](headerBuffer[:8]) - if err != nil { return 0, 0, nil, err } - method, err := tape.DecodeI16[uint16](headerBuffer[8:10]) - if err != nil { return 0, 0, nil, err } - length, err := tape.DecodeI16[uint16](headerBuffer[10:12]) - if err != nil { return 0, 0, nil, err } - payloadBuffer := make([]byte, int(length)) +func decodeMessageA( + reader io.Reader, + sizeLimit int64, +) ( + transID int64, + method uint16, + chunked bool, + payloadBuffer []byte, + err error, +) { + headerBuffer := [18]byte { } + _, err = io.ReadFull(reader, headerBuffer[:]) + if err != nil { return 0, 0, false, nil, err } + transID, err = decodeI64[int64](headerBuffer[:8]) + if err != nil { return 0, 0, false, nil, err } + method, err = decodeI16[uint16](headerBuffer[8:10]) + if err != nil { return 0, 0, false, nil, err } + size, err := decodeI64[uint64](headerBuffer[10:18]) + if err != nil { return 0, 0, false, nil, err } + chunked, size = splitCCBSize(size) + if size > uint64(sizeLimit) { + return 0, 0, false, nil, ErrPayloadTooLarge + } + payloadBuffer = make([]byte, int(size)) _, err = io.ReadFull(reader, payloadBuffer) - if err != nil { return 0, 0, nil, err } - return transID, method, payloadBuffer, nil + if err != nil { return 0, 0, false, nil, err } + return transID, method, chunked, payloadBuffer, nil } func partyFromTransID(id int64) Party { return id > 0 } + +func splitCCBSize(size uint64) (bool, uint64) { + return size >> 63 > 1, size & 0x7FFFFFFFFFFFFFFF +} diff --git a/metadapta_test.go b/metadapta_test.go index 1b2bc44..62dfdd9 100644 --- a/metadapta_test.go +++ b/metadapta_test.go @@ -24,71 +24,132 @@ func TestConnA(test *testing.T) { "world", "When the impostor is sus!", } - - network := "tcp" - addr := "localhost:7959" - // server - listener, err := net.Listen(network, addr) - if err != nil { test.Fatal(err) } - defer listener.Close() - go func() { - test.Log("SERVER listening") - conn, err := listener.Accept() - if err != nil { test.Error("SERVER", err); return } - defer conn.Close() - a := AdaptA(conn, ServerSide) + clientFunc := func(a Conn) { + test.Log("CLIENT accepting transaction") + trans, err := a.AcceptTrans() + if err != nil { test.Fatal("CLIENT", err) } + test.Log("CLIENT accepted transaction") + test.Cleanup(func() { trans.Close() }) + for method, payload := range payloads { + test.Log("CLIENT waiting...") + gotMethod, gotPayloadBytes, err := trans.Receive() + if err != nil { test.Fatal("CLIENT", err) } + gotPayload := string(gotPayloadBytes) + test.Log("CLIENT m:", gotMethod, "p:", gotPayload) + if int(gotMethod) != method { + test.Errorf("CLIENT method not equal") + } + if gotPayload != payload { + test.Errorf("CLIENT payload not equal") + } + } + test.Log("CLIENT waiting for transaction close...") + gotMethod, gotPayload, err := trans.Receive() + if !errors.Is(err, io.EOF) { + test.Error("CLIENT wrong error:", err) + test.Error("CLIENT method:", gotMethod) + test.Error("CLIENT payload:", gotPayload) + test.Fatal("CLIENT ok byeeeeeeeeeeeee") + } + } + + serverFunc := func(a Conn) { trans, err := a.OpenTrans() if err != nil { test.Error("SERVER", err); return } - defer trans.Close() + test.Cleanup(func() { trans.Close() }) for method, payload := range payloads { - test.Log("SERVER", method, payload) + test.Log("SERVER m:", method, "p:", payload) err := trans.Send(uint16(method), []byte(payload)) if err != nil { test.Error("SERVER", err); return } } - }() + test.Log("SERVER closing connection") + } - // client - test.Log("CLIENT dialing") - conn, err := net.Dial(network, addr) - if err != nil { test.Fatal("CLIENT", err) } - test.Log("CLIENT dialed") - a := AdaptA(conn, ClientSide) - defer a.Close() - test.Log("CLIENT accepting transaction") - trans, err := a.AcceptTrans() - if err != nil { test.Fatal("CLIENT", err) } - test.Log("CLIENT accepted transaction") - defer trans.Close() - for method, payload := range payloads { - test.Log("CLIENT waiting...") - gotMethod, gotPayloadBytes, err := trans.Receive() - if err != nil { test.Fatal("CLIENT", err) } - gotPayload := string(gotPayloadBytes) - test.Log("CLIENT", gotMethod, gotPayload) - if int(gotMethod) != method { - test.Errorf("CLIENT method not equal") - } - if gotPayload != payload { - test.Errorf("CLIENT payload not equal") - } + clientServerEnvironment(test, clientFunc, serverFunc) +} + +func TestTransOpenCloseA(test *testing.T) { + // currently: + // + // | data sent | data recvd | close sent | close recvd + // 10 | X | X | X | server hangs + // 20 | X | X | X | client hangs + // 30 | X | | X | + // + // when a close message is recvd, it tries to push to the trans and + // hangs on trans.incoming.Send, which hangs on sending the value to the + // underlying channel. why is this? + // + // check if we are really getting values from the channel when pulling + // from the trans channel when we are expecting a close. + + clientFunc := func(conn Conn) { + // 10 + trans, err := conn.OpenTrans() + if err != nil { test.Error("CLIENT", err); return } + test.Log("CLIENT sending 10") + trans.Send(10, []byte("hi")) + trans.Close() + + // 20 + test.Log("CLIENT awaiting 20") + trans, err = conn.AcceptTrans() + if err != nil { test.Error("CLIENT", err); return } + test.Cleanup(func() { trans.Close() }) + gotMethod, gotPayload, err := trans.Receive() + if err != nil { test.Error("CLIENT", err); return } + test.Logf("CLIENT m: %d p: %s", gotMethod, gotPayload) + if gotMethod != 20 { test.Error("CLIENT wrong method")} + + // 30 + trans, err = conn.OpenTrans() + if err != nil { test.Error("CLIENT", err); return } + test.Log("CLIENT sending 30") + trans.Send(30, []byte("good")) + trans.Close() } - _, _, err = trans.Receive() - if !errors.Is(err, io.EOF) { - test.Fatal("CLIENT wrong error:", err) + + serverFunc := func(conn Conn) { + // 10 + test.Log("SERVER awaiting 10") + trans, err := conn.AcceptTrans() + if err != nil { test.Error("SERVER", err); return } + test.Cleanup(func() { trans.Close() }) + gotMethod, gotPayload, err := trans.Receive() + if err != nil { test.Error("SERVER", err); return } + test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload) + if gotMethod != 10 { test.Error("SERVER wrong method")} + + // 20 + trans, err = conn.OpenTrans() + if err != nil { test.Error("SERVER", err); return } + test.Log("SERVER sending 20") + trans.Send(20, []byte("hi how r u")) + trans.Close() + + // 30 + test.Log("SERVER awaiting 30") + trans, err = conn.AcceptTrans() + if err != nil { test.Error("SERVER", err); return } + test.Cleanup(func() { trans.Close() }) + gotMethod, gotPayload, err = trans.Receive() + if err != nil { test.Error("SERVER", err); return } + test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload) + if gotMethod != 30 { test.Error("SERVER wrong method")} } - test.Log("CLIENT done") - // TODO test error from trans/connection closed by other side + + clientServerEnvironment(test, clientFunc, serverFunc) } func TestEncodeMessageA(test *testing.T) { buffer := new(bytes.Buffer) payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } - err := encodeMessageA(buffer, 0x5800FEABC3104F04, 0x6B12, payload) + err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, payload) correct := []byte { 0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04, 0x6B, 0x12, - 0x00, 0x06, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, } if err != nil { @@ -102,19 +163,19 @@ func TestEncodeMessageA(test *testing.T) { func TestEncodeMessageAErr(test *testing.T) { buffer := new(bytes.Buffer) payload := make([]byte, 0x10000) - err := encodeMessageA(buffer, 0x5800FEABC3104F04, 0x6B12, payload) + err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, payload) if !errors.Is(err, ErrPayloadTooLarge) { test.Fatalf("wrong error: %v", err) } } func TestDecodeMessageA(test *testing.T) { - transID, method, payload, err := decodeMessageA(bytes.NewReader([]byte { + transID, method, _, payload, err := decodeMessageA(bytes.NewReader([]byte { 0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04, 0x6B, 0x12, - 0x00, 0x06, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, - })) + }), defaultSizeLimit) if err != nil { test.Fatal(err) } @@ -131,13 +192,76 @@ func TestDecodeMessageA(test *testing.T) { } func TestDecodeMessageAErr(test *testing.T) { - _, _, _, err := decodeMessageA(bytes.NewReader([]byte { + _, _, _, _, err := decodeMessageA(bytes.NewReader([]byte { 0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04, 0x6B, 0x12, - 0x01, 0x06, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x06, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, - })) + }), defaultSizeLimit) if !errors.Is(err, io.ErrUnexpectedEOF) { test.Fatalf("wrong error: %v", err) } } + +func TestEncodeDecodeMessageA(test *testing.T) { + correctTransID := int64(2) + correctMethod := uint16(30) + correctPayload := []byte("good") + buffer := bytes.Buffer { } + err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, correctPayload) + if err != nil { test.Fatal(err) } + transID, method, chunked, payload, err := decodeMessageA(&buffer, defaultSizeLimit) + if got, correct := transID, int64(2); got != correct { + test.Fatalf("not equal: %v %v", got, correct) + } + if got, correct := method, uint16(30); got != correct { + test.Fatalf("not equal: %v %v", got, correct) + } + if chunked { + test.Fatalf("message should not be chunked") + } + if got, correct := payload, correctPayload; !slices.Equal(got, correct) { + test.Fatalf("not equal: %v %v", got, correct) + } +} + +func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), serverFunc func(conn Conn)) { + network := "tcp" + addr := "localhost:7959" + + // server + listener, err := net.Listen(network, addr) + if err != nil { test.Fatal(err) } + test.Cleanup(func() { listener.Close() }) + go func() { + test.Log("SERVER listening") + conn, err := listener.Accept() + if err != nil { test.Error("SERVER", err); return } + defer conn.Close() + test.Cleanup(func() { conn.Close() }) + a := AdaptA(conn, ServerSide) + test.Cleanup(func() { a.Close() }) + + serverFunc(a) + test.Log("SERVER closing") + }() + + // client + test.Log("CLIENT dialing") + conn, err := net.Dial(network, addr) + if err != nil { test.Fatal("CLIENT", err) } + test.Log("CLIENT dialed") + a := AdaptA(conn, ClientSide) + test.Cleanup(func() { a.Close() }) + + clientFunc(a) + + test.Log("CLIENT waiting for connection close...") + trans, err := a.AcceptTrans() + if !errors.Is(err, io.EOF) { + test.Error("CLIENT wrong error:", err) + test.Fatal("CLIENT trans:", trans) + } + test.Log("CLIENT DONE") + conn.Close() +} diff --git a/metadaptb.go b/metadaptb.go index f8bed04..e1118fe 100644 --- a/metadaptb.go +++ b/metadaptb.go @@ -2,19 +2,23 @@ package hopp import "io" import "net" +import "time" +import "bytes" +import "errors" import "context" -import "git.tebibyte.media/sashakoshka/hopp/tape" // B implements METADAPT-B over a multiplexed stream-oriented transport such as // QUIC. type b struct { + sizeLimit int64 underlying MultiConn } -// AdaptB returns a connection implementing METADAPT-B over a singular stream- -// oriented transport such as TCP or UNIX domain stream sockets. +// AdaptB returns a connection implementing METADAPT-B over a multiplexed +// stream-oriented transport such as QUIC. func AdaptB(underlying MultiConn) Conn { return &b { + sizeLimit: defaultSizeLimit, underlying: underlying, } } @@ -34,33 +38,113 @@ func (this *b) RemoteAddr() net.Addr { func (this *b) OpenTrans() (Trans, error) { stream, err := this.underlying.OpenStream() if err != nil { return nil, err } - return transB { underlying: stream }, nil + return this.newTrans(stream), nil } func (this *b) AcceptTrans() (Trans, error) { stream, err := this.underlying.AcceptStream(context.Background()) if err != nil { return nil, err } - return transB { underlying: stream }, nil + return this.newTrans(stream), nil +} + +func (this *b) SetSizeLimit(limit int64) { + this.sizeLimit = limit +} + +func (this *b) SetDeadline(t time.Time) error { + return this.underlying.SetDeadline(t) +} + +func (this *b) newTrans(underlying Stream) *transB { + return &transB { + sizeLimit: this.sizeLimit, + underlying: underlying, + } } type transB struct { - underlying Stream + sizeLimit int64 + underlying Stream + currentData io.Reader + currentWriter *writerB } -func (trans transB) Close() error { - return trans.underlying.Close() +func (this *transB) Close() error { + return this.underlying.Close() } -func (trans transB) ID() int64 { - return trans.underlying.ID() +func (this *transB) ID() int64 { + return this.underlying.ID() } -func (trans transB) Send(method uint16, data []byte) error { - return encodeMessageB(trans.underlying, method, data) +func (this *transB) Send(method uint16, data []byte) error { + return encodeMessageB(this.underlying, this.sizeLimit, method, data) } -func (trans transB) Receive() (uint16, []byte, error) { - return decodeMessageB(trans.underlying) +func (this *transB) SendWriter(method uint16) (io.WriteCloser, error) { + if this.currentWriter != nil { + this.currentWriter.Close() + } + // TODO: come up with a fix that allows us to pipe data through the + // writer. as of now, it just reads whatever is written into a buffer + // and sends the message on close. we should probably introduce chunked + // encoding to METADAPT-B to fix this. the implementation would be + // simpler than on METADAPT-A, but most of the code could just be + // copied over. + writer := &writerB { + parent: this, + method: method, + } + this.currentWriter = writer + return writer, nil +} + +func (this *transB) Receive() (uint16, []byte, error) { + // get a reader for the next message + method, size, data, err := this.receiveReader() + if err != nil { return 0, nil, err } + // read the entire thing + payloadBuffer := make([]byte, int(size)) + _, err = io.ReadFull(data, payloadBuffer) + if err != nil { return 0, nil, err } + // we have used up the reader by now so we can forget it exists + this.currentData = nil + return method, payloadBuffer, nil +} + +func (this *transB) ReceiveReader() (uint16, io.Reader, error) { + method, _, data, err := this.receiveReader() + return method, data, err +} + +func (this *transB) receiveReader() (uint16, int64, io.Reader, error) { + // decode the message + method, size, data, err := decodeMessageB(this.underlying, this.sizeLimit) + if err != nil { return 0, 0, nil, err } + // discard current reader if there is one + if this.currentData == nil { + io.Copy(io.Discard, this.currentData) + } + this.currentData = data + return method, size, data, nil +} + +func (this *transB) SetDeadline(t time.Time) error { + return this.underlying.SetDeadline(t) +} + +type writerB struct { + parent *transB + buffer bytes.Buffer + method uint16 +} + +func (this *writerB) Write(data []byte) (int, error) { + return this.buffer.Write(data) +} + +func (this *writerB) Close() error { + return this.parent.Send(this.method, this.buffer.Bytes()) } // MultiConn represens a multiplexed stream-oriented transport for use in @@ -74,37 +158,56 @@ type MultiConn interface { AcceptStream(context.Context) (Stream, error) // OpenStream opens a new stream. OpenStream() (Stream, error) + // See the documentation for [net.Conn.SetDeadline]. + SetDeadline(time.Time) error } // Stream represents a single stream returned by a [MultiConn]. type Stream interface { // See documentation for [net.Conn]. io.ReadWriteCloser + // See the documentation for [net.Conn.SetDeadline]. + SetDeadline(time.Time) error // ID returns the stream ID ID() int64 } -func encodeMessageB(writer io.Writer, method uint16, data []byte) error { - buffer := make([]byte, 4 + len(data)) - tape.EncodeI16(buffer[:2], method) - length, ok := tape.U16CastSafe(len(data)) - if !ok { return ErrPayloadTooLarge } - tape.EncodeI16(buffer[2:4], length) - copy(buffer[4:], data) +func encodeMessageB(writer io.Writer, sizeLimit int64, method uint16, data []byte) error { + if int64(len(data)) > sizeLimit { + return ErrPayloadTooLarge + } + buffer := make([]byte, 10 + len(data)) + encodeI16(buffer[:2], method) + encodeI64(buffer[2:10], uint64(len(data))) + copy(buffer[10:], data) _, err := writer.Write(buffer) return err } -func decodeMessageB(reader io.Reader) (uint16, []byte, error) { - headerBuffer := [4]byte { } - _, err := io.ReadFull(reader, headerBuffer[:]) - if err != nil { return 0, nil, err } - method, err := tape.DecodeI16[uint16](headerBuffer[:2]) - if err != nil { return 0, nil, err } - length, err := tape.DecodeI16[uint16](headerBuffer[2:4]) - if err != nil { return 0, nil, err } - payloadBuffer := make([]byte, int(length)) - _, err = io.ReadFull(reader, payloadBuffer) - if err != nil { return 0, nil, err } - return method, payloadBuffer, nil +func decodeMessageB( + reader io.Reader, + sizeLimit int64, +) ( + method uint16, + size int64, + data io.Reader, + err error, +) { + headerBuffer := [10]byte { } + _, err = io.ReadFull(reader, headerBuffer[:]) + if err != nil { + if errors.Is(err, io.EOF) { return 0, 0, nil, io.ErrUnexpectedEOF } + return 0, 0, nil, err + } + method, err = decodeI16[uint16](headerBuffer[:2]) + if err != nil { return 0, 0, nil, err } + length, err := decodeI64[uint64](headerBuffer[2:10]) + if err != nil { return 0, 0, nil, err } + if length > uint64(sizeLimit) { + return 0, 0, nil, ErrPayloadTooLarge + } + return method, int64(length), &io.LimitedReader { + R: reader, + N: int64(length), + }, nil } diff --git a/metadaptb_test.go b/metadaptb_test.go index 416f5fd..cad341f 100644 --- a/metadaptb_test.go +++ b/metadaptb_test.go @@ -9,9 +9,9 @@ import "testing" func TestEncodeMessageB(test *testing.T) { buffer := new(bytes.Buffer) payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } - err := encodeMessageB(buffer, 0x6B12, payload) + err := encodeMessageB(buffer, defaultSizeLimit, 0x6B12, payload) correct := []byte { - 0x6B, 0x12, + 0x6B, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, } @@ -26,24 +26,25 @@ func TestEncodeMessageB(test *testing.T) { func TestEncodeMessageBErr(test *testing.T) { buffer := new(bytes.Buffer) payload := make([]byte, 0x10000) - err := encodeMessageB(buffer, 0x6B12, payload) + err := encodeMessageB(buffer, 255, 0x6B12, payload) if !errors.Is(err, ErrPayloadTooLarge) { test.Fatalf("wrong error: %v", err) } } func TestDecodeMessageB(test *testing.T) { - method, payload, err := decodeMessageB(bytes.NewReader([]byte { - 0x6B, 0x12, + method, _, data, err := decodeMessageB(bytes.NewReader([]byte { + 0x6B, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, - })) + }), defaultSizeLimit) if err != nil { test.Fatal(err) } if got, correct := method, uint16(0x6B12); got != correct { test.Fatalf("not equal: %v %v", got, correct) } + payload, _ := io.ReadAll(data) correctPayload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } if got, correct := payload, correctPayload; !slices.Equal(got, correct) { test.Fatalf("not equal: %v %v", got, correct) @@ -51,11 +52,9 @@ func TestDecodeMessageB(test *testing.T) { } func TestDecodeMessageBErr(test *testing.T) { - _, _, err := decodeMessageB(bytes.NewReader([]byte { - 0x6B, 0x12, - 0x01, 0x06, - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, - })) + _, _, _, err := decodeMessageB(bytes.NewReader([]byte { + 0x6B, 0x12, 0x00, 0x00, 0x00, 0x00, + }), defaultSizeLimit) if !errors.Is(err, io.ErrUnexpectedEOF) { test.Fatalf("wrong error: %v", err) } diff --git a/quicwrap.go b/quicwrap.go deleted file mode 100644 index 45b00b3..0000000 --- a/quicwrap.go +++ /dev/null @@ -1,54 +0,0 @@ -package hopp - -import "net" -import "context" -import "github.com/quic-go/quic-go" - -var _ MultiConn = quicMultiConn { } -type quicMultiConn struct { - underlying quic.Connection -} - -func (conn quicMultiConn) Close() error { - return conn.underlying.CloseWithError(0, "good bye") -} - -func (conn quicMultiConn) LocalAddr() net.Addr { - return conn.underlying.LocalAddr() -} - -func (conn quicMultiConn) RemoteAddr() net.Addr { - return conn.underlying.RemoteAddr() -} - -func (conn quicMultiConn) AcceptStream(ctx context.Context) (Stream, error) { - strea, err := conn.underlying.AcceptStream(ctx) - if err != nil { return nil, err } - return quicStream { underlying: strea }, nil -} - -func (conn quicMultiConn) OpenStream() (Stream, error) { - strea, err := conn.underlying.OpenStream() - if err != nil { return nil, err } - return quicStream { underlying: strea }, nil -} - -type quicStream struct { - underlying quic.Stream -} - -func (strea quicStream) Read(buffer []byte) (n int, err error) { - return strea.underlying.Read(buffer) -} - -func (strea quicStream) Write(buffer []byte) (n int, err error) { - return strea.underlying.Read(buffer) -} - -func (strea quicStream) Close() error { - return strea.underlying.Close() -} - -func (strea quicStream) ID() int64 { - return int64(strea.underlying.StreamID()) -} diff --git a/tape/decode.go b/tape/decode.go new file mode 100644 index 0000000..d754e85 --- /dev/null +++ b/tape/decode.go @@ -0,0 +1,192 @@ +package tape + +import "io" +import "math" +import "bufio" + +// Decodable is any type that can decode itself from a decoder. +type Decodable interface { + // Decode reads data from decoder, replacing the data of the object. It + // returns the amount of bytes written, and an error if the write + // stopped early. + Decode(decoder *Decoder) (n int, err error) +} + +// Decoder decodes data from an [io.Reader]. +type Decoder struct { + bufio.Reader +} + +// NewDecoder creates a new decoder that reads from reader. +func NewDecoder(reader io.Reader) *Decoder { + decoder := &Decoder { } + decoder.Reader.Reset(reader) + return decoder +} + +// ReadFull calls [io.ReadFull] on the reader. +func (this *Decoder) ReadFull(buffer []byte) (n int, err error) { + return io.ReadFull(this, buffer) +} + +// ReadInt8 decodes an 8-bit signed integer from the input reader. +func (this *Decoder) ReadInt8() (value int8, n int, err error) { + uncasted, n, err := this.ReadUint8() + return int8(uncasted), n, err +} + +// ReadUint8 decodes an 8-bit unsigned integer from the input reader. +func (this *Decoder) ReadUint8() (value uint8, n int, err error) { + buffer := [1]byte { } + n, err = this.ReadFull(buffer[:]) + return uint8(buffer[0]), n, err +} + +// ReadInt16 decodes an 16-bit signed integer from the input reader. +func (this *Decoder) ReadInt16() (value int16, n int, err error) { + uncasted, n, err := this.ReadUint16() + return int16(uncasted), n, err +} + +// ReadUint16 decodes an 16-bit unsigned integer from the input reader. +func (this *Decoder) ReadUint16() (value uint16, n int, err error) { + buffer := [2]byte { } + n, err = this.ReadFull(buffer[:]) + return uint16(buffer[0]) << 8 | + uint16(buffer[1]), n, err +} + +// ReadInt32 decodes an 32-bit signed integer from the input reader. +func (this *Decoder) ReadInt32() (value int32, n int, err error) { + uncasted, n, err := this.ReadUint32() + return int32(uncasted), n, err +} + +// ReadUint32 decodes an 32-bit unsigned integer from the input reader. +func (this *Decoder) ReadUint32() (value uint32, n int, err error) { + buffer := [4]byte { } + n, err = this.ReadFull(buffer[:]) + return uint32(buffer[0]) << 24 | + uint32(buffer[1]) << 16 | + uint32(buffer[2]) << 8 | + uint32(buffer[3]), n, err +} + +// ReadInt64 decodes an 64-bit signed integer from the input reader. +func (this *Decoder) ReadInt64() (value int64, n int, err error) { + uncasted, n, err := this.ReadUint64() + return int64(uncasted), n, err +} + +// ReadUint64 decodes an 64-bit unsigned integer from the input reader. +func (this *Decoder) ReadUint64() (value uint64, n int, err error) { + buffer := [8]byte { } + n, err = this.ReadFull(buffer[:]) + return uint64(buffer[0]) << 56 | + uint64(buffer[1]) << 48 | + uint64(buffer[2]) << 40 | + uint64(buffer[3]) << 32 | + uint64(buffer[4]) << 24 | + uint64(buffer[5]) << 16 | + uint64(buffer[6]) << 8 | + uint64(buffer[7]), n, err +} + +// ReadIntN decodes an N-byte signed integer from the input reader. +func (this *Decoder) ReadIntN(bytes int) (value int64, n int, err error) { + uncasted, n, err := this.ReadUintN(bytes) + return int64(uncasted), n, err +} + +// ReadUintN decodes an N-byte unsigned integer from the input reader. +func (this *Decoder) ReadUintN(bytes int) (value uint64, n int, err error) { + // TODO: don't make multiple read calls (without allocating) + buffer := [1]byte { } + for bytesLeft := bytes; bytesLeft > 0; bytesLeft -- { + nn, err := this.ReadFull(buffer[:]) + n += nn; if err != nil { return 0, n, err } + value |= uint64(buffer[0]) << ((bytesLeft - 1) * 8) + } + // *read* integers too big, but don't return them. + if bytes > 8 { value = 0 } + return value, n, nil +} + +// ReadFloat16 decodes a 16-bit floating point value from the input reader. +func (this *Decoder) ReadFloat16() (value float32, n int, err error) { + bits, nn, err := this.ReadUint16() + n += nn; if err != nil { return 0, n, err } + return math.Float32frombits(f16bitsToF32bits(bits)), n, nil +} + +// ReadFloat32 decldes a 32-bit floating point value from the input reader. +func (this *Decoder) ReadFloat32() (value float32, n int, err error) { + bits, nn, err := this.ReadUint32() + n += nn; if err != nil { return 0, n, err } + return math.Float32frombits(bits), n, nil +} + +// ReadFloat64 decldes a 64-bit floating point value from the input reader. +func (this *Decoder) ReadFloat64() (value float64, n int, err error) { + bits, nn, err := this.ReadUint64() + n += nn; if err != nil { return 0, n, err } + return math.Float64frombits(bits), n, nil +} + +// ReadTag decodes a [Tag] from the input reader. +func (this *Decoder) ReadTag() (value Tag, n int, err error) { + uncasted, nn, err := this.ReadUint8() + n += nn; if err != nil { return 0, n, err } + return Tag(uncasted), n, nil +} + +// f16bitsToF32bits returns uint32 (float32 bits) converted from specified uint16. +// Taken from https://github.com/x448/float16/blob/v0.8.4/float16 +// +// MIT License +// +// Copyright (c) 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +func f16bitsToF32bits(in uint16) uint32 { + // All 65536 conversions with this were confirmed to be correct + // by Montgomery Edwards⁴⁴⁸ (github.com/x448). + + sign := uint32(in&0x8000) << 16 // sign for 32-bit + exp := uint32(in&0x7c00) >> 10 // exponenent for 16-bit + coef := uint32(in&0x03ff) << 13 // significand for 32-bit + + if exp == 0x1f { + if coef == 0 { + // infinity + return sign | 0x7f800000 | coef + } + // NaN + return sign | 0x7fc00000 | coef + } + + if exp == 0 { + if coef == 0 { + // zero + return sign + } + + // normalize subnormal numbers + exp++ + for coef&0x7f800000 == 0 { + coef <<= 1 + exp-- + } + coef &= 0x007fffff + } + + return sign | ((exp + (0x7f - 0xf)) << 23) | coef +} diff --git a/tape/dynamic.go b/tape/dynamic.go new file mode 100644 index 0000000..501ed63 --- /dev/null +++ b/tape/dynamic.go @@ -0,0 +1,505 @@ +package tape + +// dont smoke reflection, kids!!!!!!!!! +// totally reflectric, reflectrified, etc. this is probably souper slow but +// certainly no slower than the built in json encoder i'd imagine. +// TODO: add support for struct tags: `tape:"0000"`, tape:"0001"` so they can get +// transformed into tables with a defined schema + +// TODO: test all of these smaller functions individually + +import "fmt" +import "reflect" + +var dummyMap map[uint16] any +var dummyBuffer []byte + +type errCantAssign string +func (err errCantAssign) Error() string { + return string(err) +} +func errCantAssignf(format string, v ...any) errCantAssign { + return errCantAssign(fmt.Sprintf(format, v...)) +} + +// EncodeAny encodes an "any" value. Returns an error if the underlying type is +// unsupported. Supported types are: +// +// - int +// - int +// - uint +// - uint +// - string +// - [] +// - map[uint16] +func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) { + // primitives + reflectValue := reflect.ValueOf(value) + switch reflectValue.Kind() { + case reflect.Int: return encoder.WriteInt32(int32(reflectValue.Int())) + case reflect.Uint: return encoder.WriteUint32(uint32(reflectValue.Uint())) + case reflect.Int8: return encoder.WriteInt8(int8(reflectValue.Int())) + case reflect.Uint8: return encoder.WriteUint8(uint8(reflectValue.Uint())) + case reflect.Int16: return encoder.WriteInt16(int16(reflectValue.Int())) + case reflect.Uint16: return encoder.WriteUint16(uint16(reflectValue.Uint())) + case reflect.Int32: return encoder.WriteInt32(int32(reflectValue.Int())) + case reflect.Uint32: return encoder.WriteUint32(uint32(reflectValue.Uint())) + case reflect.Int64: return encoder.WriteInt64(int64(reflectValue.Int())) + case reflect.Uint64: return encoder.WriteUint64(uint64(reflectValue.Uint())) + case reflect.String: + if reflectValue.Len() > MaxStructureLength { + return 0, ErrTooLong + } + return EncodeAny(encoder, []byte(reflectValue.String()), tag) + } + if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) { + if reflectValue.Len() > MaxStructureLength { + return 0, ErrTooLong + } + if tag.Is(LBA) { + nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1) + n += nn; if err != nil { return n, err } + } + nn, err := encoder.Write(reflectValue.Bytes()) + n += nn; if err != nil { return n, err } + return n, nil + } + + // aggregates + reflectType := reflect.TypeOf(value) + switch reflectType.Kind() { + case reflect.Slice: + return encodeAnySlice(encoder, value, tag) + // case reflect.Array: + // TODO: we can encode arrays. but can we decode into them? + // that's the fucken question. maybe we just do the first + // return encodeAnySlice(encoder, reflect.ValueOf(value).Slice(0, reflectType.Len()).Interface(), tag) + case reflect.Map: + if reflectValue.Len() > MaxStructureLength { + return 0, ErrTooLong + } + if reflectType.Key() == reflect.TypeOf(uint16(0)) { + return encodeAnyMap(encoder, value, tag) + } + return n, fmt.Errorf("cannot encode map key %T, key must be uint16", value) + } + return n, fmt.Errorf("cannot encode type %T", value) +} + +// DecodeAny decodes data and places it into destination, which must be a +// pointer to a supported type. See [EncodeAny] for a list of supported types. +func DecodeAny(decoder *Decoder, destination any, tag Tag) (n int, err error) { + reflectDestination := reflect.ValueOf(destination) + if reflectDestination.Kind() != reflect.Pointer { + return n, fmt.Errorf("expected pointer destination, not %v", destination) + } + return decodeAny(decoder, reflectDestination.Elem(), tag) +} + +// unknownSlicePlaceholder is inserted by skeletonValue and informs the program +// that the destination for the slice needs to be generated based on the item +// tag in the OTA. +type unknownSlicePlaceholder struct { } +var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { }) + +// decodeAny is internal to [DecodeAny]. It takes in an addressable +// [reflect.Value] as the destination. If the decoded value cannot fit in the +// destination, it skims over the payload, leaves the destination empty, and +// returns without an error. +func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { + n, err = decodeAnyOrError(decoder, destination, tag) + if _, ok := err.(errCantAssign); ok { + if n > 0 { panic(fmt.Sprintf("decodeAnyOrError decoded more than it should: %d", n)) } + nn, err := Skim(decoder, tag) + n += nn; if err != nil { return n, err } + return n, nil + } + return n, err +} + +// decodeAnyOrError is internal to [decodeAny]. It takes in an addressable +// [reflect.Value] as the destination. If the decoded value cannot fit in the +// destination, it decodes nothing and returns an error of type errCantAssign, +// except for the case of a mismatched OTA element tag, wherein it will skim +// over the rest of the payload, leave the destination empty, and return without +// an error. +func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) { + err = canSet(destination.Type(), tag) + if err != nil { return n, err } + + switch tag.WithoutCN() { + case SI: + // SI: (none) + setInt(destination, uint64(tag.CN())) + case LI: + // LI: + nn, err := decodeAndSetUint(decoder, destination, tag.CN() + 1) + n += nn; if err != nil { return n, err } + case LSI: + // LSI: + nn, err := decodeAndSetInt(decoder, destination, tag.CN() + 1) + n += nn; if err != nil { return n, err } + case FP: + // FP: + nn, err := decodeAndSetFloat(decoder, destination, tag.CN() + 1) + n += nn; if err != nil { return n, err } + case SBA: + // SBA: * + length := tag.CN() + if length > MaxStructureLength { + return 0, ErrTooLong + } + buffer := make([]byte, length) + nn, err := decoder.Read(buffer) + n += nn; if err != nil { return n, err } + setByteArray(destination, buffer) + case LBA: + // LBA: * + length, nn, err := decoder.ReadUintN(tag.CN() + 1) + n += nn; if err != nil { return n, err } + if length > uint64(MaxStructureLength) { + return 0, ErrTooLong + } + buffer := make([]byte, length) + nn, err = decoder.Read(buffer) + n += nn; if err != nil { return n, err } + setByteArray(destination, buffer) + case OTA: + // OTA: * + length, nn, err := decoder.ReadUintN(tag.CN() + 1) + n += nn; if err != nil { return n, err } + if length > uint64(MaxStructureLength) { + return 0, ErrTooLong + } + lengthCast, err := Uint64ToIntSafe(length) + if err != nil { return n, err } + oneTag, nn, err := decoder.ReadTag() + n += nn; if err != nil { return n, err } + if destination.Cap() < lengthCast { + destination.Grow(lengthCast - destination.Cap()) + } + // skip the rest of the array if the one tag doesn't + // match up with the destination + err = canSet(destination.Type().Elem(), oneTag) + if _, ok := err.(errCantAssign); ok { + for _ = range length { + nn, err := Skim(decoder, oneTag) + n += nn; if err != nil { return n, err } + } + break + } + if err != nil { return n, err } + destination.SetLen(lengthCast) + for index := range length { + nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag) + n += nn + if _, ok := err.(errCantAssign); ok { + continue + } else if err != nil { + return n, err + } + } + case KTV: + // KTV: ( )* + length, nn, err := decoder.ReadUintN(tag.CN() + 1) + n += nn; if err != nil { return n, err } + if length > uint64(MaxStructureLength) { + return 0, ErrTooLong + } + destination.Clear() + for _ = range length { + key, nn, err := decoder.ReadUint16() + n += nn; if err != nil { return n, err } + itemTag, nn, err := decoder.ReadTag() + n += nn; if err != nil { return n, err } + value, err := skeletonValue(decoder, itemTag) + if err != nil { return n, err } + nn, err = decodeAny(decoder, value.Elem(), itemTag) + n += nn; if err != nil { return n, err } + destination.SetMapIndex(reflect.ValueOf(key), value.Elem()) + } + default: + return n, fmt.Errorf("unknown TN %d", tag.TN()) + } + return n, nil +} + +// TagAny returns the correct tag for an "any" value. Returns an error if the +// underlying type is unsupported. See [EncodeAny] for a list of supported +// types. +func TagAny(value any) (Tag, error) { + return tagAny(reflect.ValueOf(value)) +} + +func tagAny(reflectValue reflect.Value) (Tag, error) { + // primitives + switch reflectValue.Kind() { + case reflect.Int: return LSI.WithCN(3), nil + case reflect.Int8: return LSI.WithCN(0), nil + case reflect.Int16: return LSI.WithCN(1), nil + case reflect.Int32: return LSI.WithCN(3), nil + case reflect.Int64: return LSI.WithCN(7), nil + case reflect.Uint: return LI.WithCN(3), nil + case reflect.Uint8: return LI.WithCN(0), nil + case reflect.Uint16: return LI.WithCN(1), nil + case reflect.Uint32: return LI.WithCN(3), nil + case reflect.Uint64: return LI.WithCN(7), nil + case reflect.String: return bufferLenTag(reflectValue.Len()), nil + } + if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) { + return bufferLenTag(reflectValue.Len()), nil + } + + // aggregates + reflectType := reflectValue.Type() + switch reflectType.Kind() { + case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil + case reflect.Array: return OTA.WithCN(reflectType.Len()), nil + case reflect.Map: + if reflectType.Key() == reflect.TypeOf(uint16(0)) { + return KTV.WithCN(IntBytes(uint64(reflectValue.Len())) - 1), nil + } + return 0, fmt.Errorf("cannot encode map key %v, key must be uint16", reflectType.Key()) + } + return 0, fmt.Errorf("cannot get tag of type %v", reflectType) +} + +func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) { + // OTA: * + reflectValue := reflect.ValueOf(value) + nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1) + n += nn; if err != nil { return n, err } + reflectType := reflect.TypeOf(value) + oneTag, err := tagAny(reflect.Zero(reflectType.Elem())) + if err != nil { return n, err } + for index := 0; index < reflectValue.Len(); index += 1 { + itemTag, err := tagAny(reflectValue.Index(index)) + if err != nil { return n, err } + if itemTag.CN() > oneTag.CN() { oneTag = itemTag } + } + if oneTag.Is(SBA) { oneTag += 1 << 5 } + nn, err = encoder.WriteUint8(uint8(oneTag)) + n += nn; if err != nil { return n, err } + for index := 0; index < reflectValue.Len(); index += 1 { + item := reflectValue.Index(index).Interface() + nn, err = EncodeAny(encoder, item, oneTag) + n += nn; if err != nil { return n, err } + } + return n, err +} + +func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) { + // KTV: ( )* + reflectValue := reflect.ValueOf(value) + nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1) + n += nn; if err != nil { return n, err } + iter := reflectValue.MapRange() + for iter.Next() { + reflectValue := iter.Value().Elem() + key := iter.Key().Interface().(uint16) + value := reflectValue.Interface() + nn, err = encoder.WriteUint16(key) + n += nn; if err != nil { return n, err } + itemTag, err := tagAny(reflectValue) + if err != nil { return n, err } + nn, err = encoder.WriteUint8(uint8(itemTag)) + n += nn; if err != nil { return n, err } + nn, err = EncodeAny(encoder, value, itemTag) + n += nn; if err != nil { return n, err } + } + return n, nil +} + +func canSet(destination reflect.Type, tag Tag) error { + switch tag.WithoutCN() { + case SI, LI, LSI: + switch destination.Kind() { + case + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + default: + return errCantAssignf("cannot assign integer to %v", destination) + } + case FP: + switch destination.Kind() { + case reflect.Float32, reflect.Float64: + default: + return errCantAssignf("cannot assign float to %v", destination) + } + case SBA, LBA: + if destination.Kind() != reflect.Slice { + return errCantAssignf("cannot assign byte array to %v", destination) + } + if destination.Elem() != reflect.TypeOf(byte(0)) { + return errCantAssignf("cannot convert %v to *[]byte", destination) + } + case OTA: + if destination.Kind() != reflect.Slice { + return errCantAssignf("cannot assign array to %v", destination) + } + case KTV: + if destination != reflect.TypeOf(dummyMap) { + return errCantAssignf("cannot assign table to %v", destination) + } + default: + return fmt.Errorf("unknown TN %d", tag.TN()) + } + return nil +} + +// setInt expects a settable destination. +func setInt[T int64 | uint64](destination reflect.Value, value T) { + switch { + case destination.CanInt(): + destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type())) + case destination.CanUint(): + destination.Set(reflect.ValueOf(value).Convert(destination.Type())) + default: + panic("setInt called on an unsupported type") + } +} + +// setFloat expects a settable destination. +func setFloat(destination reflect.Value, value float64) { + destination.Set(reflect.ValueOf(value).Convert(destination.Type())) +} + +// setByteArrayexpects a settable destination. +func setByteArray(destination reflect.Value, value []byte) { + destination.Set(reflect.ValueOf(value)) +} + +// decodeAndSetInt expects a settable destination. +func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { + value, nn, err := decoder.ReadIntN(bytes) + n += nn; if err != nil { return n, err } + setInt(destination, value) + return n, nil +} + +// decodeAndSetUint expects a settable destination. +func decodeAndSetUint(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { + value, nn, err := decoder.ReadUintN(bytes) + n += nn; if err != nil { return n, err } + setInt(destination, value) + return n, nil +} + +// decodeAndSetInt expects a settable destination. +func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) { + switch bytes { + case 8: + value, nn, err := decoder.ReadFloat64() + n += nn; if err != nil { return n, err } + setFloat(destination, float64(value)) + return n, nil + case 4: + value, nn, err := decoder.ReadFloat32() + n += nn; if err != nil { return n, err } + setFloat(destination, float64(value)) + return n, nil + } + return n, errCantAssignf("unsupported bit width float%d", bytes * 8) +} + +// skeletonValue returns a pointer value. In order for it to be set, it must be +// dereferenced using Elem(). +func skeletonValue(decoder *Decoder, tag Tag) (reflect.Value, error) { + typ, err := typeOf(decoder, tag) + if err != nil { return reflect.Value { }, err } + return reflect.New(typ), nil +} + +// typeOf returns the type of the current tag being decoded. It does not use up +// the decoder, it only peeks. +func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) { + switch tag.WithoutCN() { + case SI: + return reflect.TypeOf(uint8(0)), nil + case LI: + switch tag.CN() { + case 0: return reflect.TypeOf(uint8(0)), nil + case 1: return reflect.TypeOf(uint16(0)), nil + case 3: return reflect.TypeOf(uint32(0)), nil + case 7: return reflect.TypeOf(uint64(0)), nil + } + return nil, fmt.Errorf("unknown CN %d for LI", tag.CN()) + case LSI: + switch tag.CN() { + case 0: return reflect.TypeOf(int8(0)), nil + case 1: return reflect.TypeOf(int16(0)), nil + case 3: return reflect.TypeOf(int32(0)), nil + case 7: return reflect.TypeOf(int64(0)), nil + } + return nil, fmt.Errorf("unknown CN %d for LSI", tag.CN()) + case FP: + switch tag.CN() { + case 3: return reflect.TypeOf(float32(0)), nil + case 7: return reflect.TypeOf(float64(0)), nil + } + return nil, fmt.Errorf("unknown CN %d for FP", tag.CN()) + case SBA: return reflect.SliceOf(reflect.TypeOf(byte(0))), nil + case LBA: return reflect.SliceOf(reflect.TypeOf(byte(0))), nil + case OTA: + elemTag, dimension, err := peekSlice(decoder, tag) + if err != nil { return nil, err } + if elemTag.Is(OTA) { panic("peekSlice cannot return OTA") } + typ, err := typeOf(decoder, elemTag) + if err != nil { return nil, err } + for _ = range dimension { + typ = reflect.SliceOf(typ) + } + return typ, nil + case KTV: return reflect.TypeOf(dummyMap), nil + } + return nil, fmt.Errorf("unknown TN %d", tag.TN()) +} + +// peekSlice returns the element tag and dimension count of the OTA currently +// being decoded. It does not use up the decoder, it only peeks. +func peekSlice(decoder *Decoder, tag Tag) (Tag, int, error) { + offset := 0 + dimension := 0 + currentTag := tag + for { + elem, populated, n, err := peekSliceOnce(decoder, currentTag, offset) + if err != nil { return 0, 0, err } + currentTag = elem + offset = n + dimension += 1 + if elem.Is(OTA) { + if !populated { + // default to a large byte array, will be + // interpreted as a string. + return LBA, dimension + 1, nil + } + } else { + return elem, dimension, nil + } + } +} + +// peekSliceOnce returns the element tag of the OTA located offset bytes ahead +// of the current position. It does not use up the decoder, it only peeks. The n +// return value denotes how far away from 0 it peeked. If the OTA has more than +// zero items, populated will be set to true. +func peekSliceOnce(decoder *Decoder, tag Tag, offset int) (elem Tag, populated bool, n int, err error) { + lengthStart := offset + lengthEnd := lengthStart + tag.CN() + 1 + elemTagStart := lengthEnd + elemTagEnd := elemTagStart + 1 + + headerBytes, err := decoder.Peek(elemTagEnd) + if err != nil { return 0, false, 0, err } + + elem = Tag(headerBytes[len(headerBytes) - 1]) + for index := lengthStart; index < lengthEnd; index += 1 { + if headerBytes[index] > 0 { + populated = true + break + } + } + n = elemTagEnd + + return +} diff --git a/tape/dynamic_test.go b/tape/dynamic_test.go new file mode 100644 index 0000000..7e7bab7 --- /dev/null +++ b/tape/dynamic_test.go @@ -0,0 +1,310 @@ +package tape + +import "fmt" +import "bytes" +import "testing" +import "reflect" +import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil" + +type userDefinedInteger int16 + +func TestEncodeAnyInt(test *testing.T) { + err := testEncodeAny(test, uint8(0xCA), LI.WithCN(0), tu.S(0xCA)) + if err != nil { test.Fatal(err) } + err = testEncodeAny(test, 400, LSI.WithCN(3), tu.S( + 0, 0, 0x1, 0x90, + )) + if err != nil { test.Fatal(err) } +} + +func TestEncodeAnyTable(test *testing.T) { + err := testEncodeAny(test, map[uint16] any { + 0xF3B9: 1, + 0x0102: 2, + 0x0000: "hi!", + 0xFFFF: []uint16 { 0xBEE5, 0x7777 }, + 0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} }, + 0x2345: [][]int16 { []int16 { 0x5 }, []int16 { 0x17, -0xAAA } }, + 0x3456: userDefinedInteger(0x3921), + }, KTV.WithCN(0), tu.S(7).AddVar( + []byte { + 0xF3, 0xB9, + byte(LSI.WithCN(3)), + 0, 0, 0, 1, + }, + []byte { + 0x01, 0x02, + byte(LSI.WithCN(3)), + 0, 0, 0, 2, + }, + []byte { + 0, 0, + byte(SBA.WithCN(3)), + 'h', 'i', '!', + }, + []byte { + 0xFF, 0xFF, + byte(OTA.WithCN(0)), 2, byte(LI.WithCN(1)), + 0xBE, 0xE5, 0x77, 0x77, + }, + []byte { + 0x12, 0x34, + byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)), + 1, byte(LI.WithCN(1)), + 0, 0x5, + 2, byte(LI.WithCN(1)), + 0, 0x17, + 0xAA, 0xAA, + }, + []byte { + 0x23, 0x45, + byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)), + 1, byte(LSI.WithCN(1)), + 0, 0x5, + 2, byte(LSI.WithCN(1)), + 0, 0x17, + 0xF5, 0x56, + }, + []byte { + 0x34, 0x56, + byte(LSI.WithCN(1)), + 0x39, 0x21, + }, + )) + if err != nil { test.Fatal(err) } +} + +func TestDecodeWrongType(test *testing.T) { + datas := [][]byte { + /* int8 */ []byte { byte(LSI.WithCN(0)), 0x45 }, + /* int16 */ []byte { byte(LSI.WithCN(1)), 0x45, 0x67 }, + /* int32 */ []byte { byte(LSI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB }, + /* int64 */ []byte { byte(LSI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 }, + /* uint5 */ []byte { byte(SI.WithCN(12)) }, + /* uint8 */ []byte { byte(LI.WithCN(0)), 0x45 }, + /* uint16 */ []byte { byte(LI.WithCN(1)), 0x45, 0x67 }, + /* uint32 */ []byte { byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB }, + /* uint64 */ []byte { byte(LI.WithCN(7)), 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23 }, + /* string */ []byte { byte(SBA.WithCN(7)), 'p', 'u', 'p', 'e', 'v', 'e', 'r' }, + /* []byte */ []byte { byte(SBA.WithCN(5)), 'b', 'l', 'a', 'r', 'g' }, + /* []string */ []byte { + byte(OTA.WithCN(0)), 2, byte(LBA.WithCN(0)), + 0x08, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, + 0x05, 0x11, 0x11, 0x11, 0x11, 0x11, + }, + /* map[uint16] any */ []byte { + byte(KTV.WithCN(0)), 2, + 0x02, 0x23, byte(LSI.WithCN(1)), 0x45, 0x67, + 0x02, 0x23, byte(LI.WithCN(3)), 0x45, 0x67, 0x89, 0xAB, + }, + } + + for index, data := range datas { + test.Logf("data %2d %v [%s]", index, Tag(data[0]), tu.HexBytes(data[1:])) + // integers should only assign to other integers + if index > 8 { + cas := func(destination any) { + n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0])) + if err != nil { test.Fatalf("error: %v | n: %d", err, n) } + reflectValue := reflect.ValueOf(destination).Elem() + if reflectValue.CanInt() { + if reflectValue.Int() != 0 { + test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface()) + } + } else { + if reflectValue.Uint() != 0 { + test.Fatalf("destination not zero: %v", reflectValue.Elem().Interface()) + } + } + if n != len(data) - 1 { + test.Fatalf("n not equal: %d != %d", n, len(data) - 1) + } + } + test.Log("- int8") + { var dest int8; cas(&dest) } + test.Log("- int16") + { var dest int16; cas(&dest) } + test.Log("- int32") + { var dest int32; cas(&dest) } + test.Log("- int64") + { var dest int64; cas(&dest) } + test.Log("- uint8") + { var dest uint8; cas(&dest) } + test.Log("- uint16") + { var dest uint16; cas(&dest) } + test.Log("- uint32") + { var dest uint32; cas(&dest) } + test.Log("- uint64") + { var dest uint64; cas(&dest) } + } + arrayCase := func(destination any) { + n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data[1:])), destination, Tag(data[0])) + if err != nil { test.Fatalf("error: %v | n: %d", err, n) } + reflectDestination := reflect.ValueOf(destination) + reflectValue := reflectDestination.Elem() + if reflectValue.Len() != 0 { + test.Fatalf("len(destination) not zero: %v", reflectValue.Interface()) + } + if n != len(data) - 1 { + test.Fatalf("n not equal: %d != %d", n, len(data) - 1) + } + } + // SBA/LBA types should only assign to other SBA/LBA types + if index != 9 && index != 10 { + test.Log("- string") + { var dest string; arrayCase(&dest) } + test.Log("- []byte") + { var dest []byte; arrayCase(&dest) } + } + // arrays should only assign to other arrays + if index != 11 { + test.Log("- []string") + { var dest []string; arrayCase(&dest) } + } + // tables should only assign to other tables + if index != 12 { + test.Log("- map[uint16] any") + { var dest = map[uint16] any { }; arrayCase(&dest) } + } + } +} + +func TestEncodeDecodeAnyTable(test *testing.T) { + err := testEncodeDecodeAny(test, map[uint16] any { + 0xF3B9: uint32(1), + 0x0102: uint32(2), + 0x0103: int64(23432), + 0x0104: int64(-88777), + 0x0000: []byte("hi!"), + 0xFFFF: []uint16 { 0xBEE5, 0x7777 }, + 0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} }, + }, nil) + if err != nil { test.Fatal(err) } +} + +func TestPeekSlice(test *testing.T) { + buffer := bytes.NewBuffer([]byte { + 2, byte(OTA.WithCN(3)), + 0, 0, 0, 1, byte(LI.WithCN(1)), + 0, 0x5, + 2, byte(LI.WithCN(1)), + 0, 0x17, + 0xAA, 0xAA, + }) + decoder := NewDecoder(buffer) + + elem, dimension, err := peekSlice(decoder, OTA.WithCN(0)) + if err != nil { test.Fatal(err) } + if elem != LI.WithCN(1) { + test.Fatalf("wrong element tag: %v %02X", elem, byte(elem)) + } + if got, correct := dimension, 2; got != correct { + test.Fatalf("wrong dimension: %d != %d", got, correct) + } +} + +func TestPeekSliceOnce(test *testing.T) { + buffer := bytes.NewBuffer([]byte { + 2, byte(OTA.WithCN(3)), + 0, 0, 0, 1, byte(LI.WithCN(1)), + 0, 0x5, + 2, byte(LI.WithCN(1)), + 0, 0x17, + 0xAA, 0xAA, + }) + decoder := NewDecoder(buffer) + + test.Log("--- stage 1") + elem, populated, n, err := peekSliceOnce(decoder, OTA.WithCN(0), 0) + if err != nil { test.Fatal(err) } + if elem != OTA.WithCN(3) { + test.Fatal("wrong element tag:", elem) + } + if !populated { + test.Fatal("wrong populated:", populated) + } + if got, correct := n, 2; got != correct { + test.Fatalf("wrong n: %d != %d", got, correct) + } + + test.Log("--- stage 2") + elem, populated, n, err = peekSliceOnce(decoder, elem, n) + if err != nil { test.Fatal(err) } + if elem != LI.WithCN(1) { + test.Fatal("wrong element tag:", elem) + } + if !populated { + test.Fatal("wrong populated:", populated) + } + if got, correct := n, 7; got != correct { + test.Fatalf("wrong n: %d != %d", got, correct) + } +} + +func encAny(value any) ([]byte, Tag, int, error) { + tag, err := TagAny(value) + if err != nil { return nil, 0, 0, err } + buffer := bytes.Buffer { } + encoder := NewEncoder(&buffer) + n, err := EncodeAny(encoder, value, tag) + if err != nil { return nil, 0, n, err } + encoder.Flush() + return buffer.Bytes(), tag, n, nil +} + +func decAny(data []byte) (Tag, any, int, error) { + destination := map[uint16] any { } + tag, err := TagAny(destination) + if err != nil { return 0, nil, 0, err } + n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data)), &destination, tag) + if err != nil { return 0, nil, n, err } + return tag, destination, n, nil +} + +func testEncodeAny(test *testing.T, value any, correctTag Tag, correctBytes tu.Snake) error { + bytes, tag, n, err := encAny(value) + if err != nil { return err } + test.Log("n: ", n) + test.Log("tag: ", tag) + test.Log("got: ", tu.HexBytes(bytes)) + test.Log("correct:", correctBytes) + if tag != correctTag { + return fmt.Errorf("tag not equal: %v != %v", tag, correctTag) + } + if ok, n := correctBytes.Check(bytes); !ok { + return fmt.Errorf("bytes not equal at index %d", n) + } + if n != len(bytes) { + return fmt.Errorf("n not equal: %d != %d", n, len(bytes)) + } + return nil +} + +func testEncodeDecodeAny(test *testing.T, value, correctValue any) error { + if correctValue == nil { + correctValue = value + } + + test.Log("encoding...") + bytes, tag, n, err := encAny(value) + if err != nil { return err } + test.Log("n: ", n) + test.Log("tag:", tag) + test.Log("got:", tu.HexBytes(bytes)) + test.Log("decoding...", tag) + if n != len(bytes) { + return fmt.Errorf("n not equal: %d != %d", n, len(bytes)) + } + + _, decoded, n, err := decAny(bytes) + if err != nil { return err } + test.Log("got: ", tu.Describe(decoded)) + test.Log("correct:", tu.Describe(correctValue)) + if !reflect.DeepEqual(decoded, correctValue) { + return fmt.Errorf("values not equal") + } + if n != len(bytes) { + return fmt.Errorf("n not equal: %d != %d", n, len(bytes)) + } + return nil +} diff --git a/tape/encode.go b/tape/encode.go new file mode 100644 index 0000000..efce7a0 --- /dev/null +++ b/tape/encode.go @@ -0,0 +1,189 @@ +package tape + +import "io" +import "math" +import "bufio" + +// Encodable is any type that can write itself to an encoder. +type Encodable interface { + // Encode sends data to encoder. It returns the amount of bytes written, + // and an error if the write stopped early. + Encode(encoder *Encoder) (n int, err error) +} + +// Encoder encodes data to an io.Writer. +type Encoder struct { + bufio.Writer +} + +// NewEncoder creates a new encoder that writes to writer. +func NewEncoder(writer io.Writer) *Encoder { + encoder := &Encoder { } + encoder.Reset(writer) + return encoder +} + +// WriteInt8 encodes an 8-bit signed integer to the output writer. +func (this *Encoder) WriteInt8(value int8) (n int, err error) { + return this.WriteUint8(uint8(value)) +} + +// WriteUint8 encodes an 8-bit unsigned integer to the output writer. +func (this *Encoder) WriteUint8(value uint8) (n int, err error) { + return this.Write([]byte { byte(value) }) +} + +// WriteInt16 encodes an 16-bit signed integer to the output writer. +func (this *Encoder) WriteInt16(value int16) (n int, err error) { + return this.WriteUint16(uint16(value)) +} + +// WriteUint16 encodes an 16-bit unsigned integer to the output writer. +func (this *Encoder) WriteUint16(value uint16) (n int, err error) { + return this.Write([]byte { + byte(value >> 8), + byte(value), + }) +} + +// WriteInt32 encodes an 32-bit signed integer to the output writer. +func (this *Encoder) WriteInt32(value int32) (n int, err error) { + return this.WriteUint32(uint32(value)) +} + +// WriteUint32 encodes an 32-bit unsigned integer to the output writer. +func (this *Encoder) WriteUint32(value uint32) (n int, err error) { + return this.Write([]byte { + byte(value >> 24), + byte(value >> 16), + byte(value >> 8), + byte(value), + }) +} + +// WriteInt64 encodes an 64-bit signed integer to the output writer. +func (this *Encoder) WriteInt64(value int64) (n int, err error) { + return this.WriteUint64(uint64(value)) +} + +// WriteUint64 encodes an 64-bit unsigned integer to the output writer. +func (this *Encoder) WriteUint64(value uint64) (n int, err error) { + return this.Write([]byte { + byte(value >> 56), + byte(value >> 48), + byte(value >> 40), + byte(value >> 32), + byte(value >> 24), + byte(value >> 16), + byte(value >> 8), + byte(value), + }) +} + +// WriteIntN encodes an N-byte signed integer to the output writer. +func (this *Encoder) WriteIntN(value int64, bytes int) (n int, err error) { + return this.WriteUintN(uint64(value), bytes) +} + +// for Write/ReadUintN, increase buffers if go somehow gets support for over 64 +// bit integers. we could also make an expanding int type in goutil to use here, +// or maybe there is one in the stdlib. keep the int64 versions as well though +// because its ergonomic. + +// WriteUintN encodes an N-byte unsigned integer to the output writer. +func (this *Encoder) WriteUintN(value uint64, bytes int) (n int, err error) { + // TODO: don't make multiple write calls (without allocating) + buffer := [1]byte { } + for bytesLeft := bytes; bytesLeft > 0; bytesLeft -- { + buffer[0] = byte(value) >> ((bytesLeft - 1) * 8) + nn, err := this.Write(buffer[:]) + n += nn; if err != nil { return n, err } + } + return n, nil +} + +// WriteFloat16 encodes a 16-bit floating point value to the output writer. +func (this *Encoder) WriteFloat16(value float32) (n int, err error) { + return this.WriteUint16(f32bitsToF16bits(math.Float32bits(value))) +} + +// WriteFloat32 encodes a 32-bit floating point value to the output writer. +func (this *Encoder) WriteFloat32(value float32) (n int, err error) { + return this.WriteUint32(math.Float32bits(value)) +} + +// WriteFloat64 encodes a 64-bit floating point value to the output writer. +func (this *Encoder) WriteFloat64(value float64) (n int, err error) { + return this.WriteUint64(math.Float64bits(value)) +} + +// WriteTag encodes a [Tag] to the output writer. +func (this *Encoder) WriteTag(value Tag) (n int, err error) { + return this.WriteUint8(uint8(value)) +} + +// f32bitsToF16bits returns uint16 (Float16 bits) converted from the specified float32. +// Conversion rounds to nearest integer with ties to even. +// Taken from https://github.com/x448/float16/blob/v0.8.4/float16 +// +// MIT License +// +// Copyright (c) 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +func f32bitsToF16bits(u32 uint32) uint16 { + // Translated from Rust to Go by Montgomery Edwards⁴⁴⁸ (github.com/x448). + // All 4294967296 conversions with this were confirmed to be correct by x448. + // Original Rust implementation is by Kathryn Long (github.com/starkat99) with MIT license. + + sign := u32 & 0x80000000 + exp := u32 & 0x7f800000 + coef := u32 & 0x007fffff + + if exp == 0x7f800000 { + // NaN or Infinity + nanBit := uint32(0) + if coef != 0 { + nanBit = uint32(0x0200) + } + return uint16((sign >> 16) | uint32(0x7c00) | nanBit | (coef >> 13)) + } + + halfSign := sign >> 16 + + unbiasedExp := int32(exp>>23) - 127 + halfExp := unbiasedExp + 15 + + if halfExp >= 0x1f { + return uint16(halfSign | uint32(0x7c00)) + } + + if halfExp <= 0 { + if 14-halfExp > 24 { + return uint16(halfSign) + } + coef := coef | uint32(0x00800000) + halfCoef := coef >> uint32(14-halfExp) + roundBit := uint32(1) << uint32(13-halfExp) + if (coef&roundBit) != 0 && (coef&(3*roundBit-1)) != 0 { + halfCoef++ + } + return uint16(halfSign | halfCoef) + } + + uHalfExp := uint32(halfExp) << 10 + halfCoef := coef >> 13 + roundBit := uint32(0x00001000) + if (coef&roundBit) != 0 && (coef&(3*roundBit-1)) != 0 { + return uint16((halfSign | uHalfExp | halfCoef) + 1) + } + return uint16(halfSign | uHalfExp | halfCoef) +} diff --git a/tape/error.go b/tape/error.go new file mode 100644 index 0000000..18a58a2 --- /dev/null +++ b/tape/error.go @@ -0,0 +1,12 @@ +package tape + +// Error enumerates common errors in this package. +type Error string; const ( + ErrTooLong Error = "data structure too long" + ErrTooLarge Error = "number too large" +) + +// Error implements the error interface. +func (err Error) Error() string { + return string(err) +} diff --git a/tape/limits.go b/tape/limits.go new file mode 100644 index 0000000..38c8c1a --- /dev/null +++ b/tape/limits.go @@ -0,0 +1,26 @@ +package tape + +// MaxStructureLength determines how long a TAPE data structure can be. This +// applies to: +// +// - OTA +// - SBA/LBA +// - KTV +// +// By default it is set at 2^20 (about a million). +// You shouldn't need to change this. If you do, it should only be set once at +// the start of the program. +var MaxStructureLength = 1024 * 1024 + +// MaxInt is the maximum value an int can hold. This varies depending on the +// system. +const MaxInt int = int(^uint(0) >> 1) + +// Uint64ToIntSafe casts the input to an int if it can be done without overflow, +// or returns an error otherwise. +func Uint64ToIntSafe(input uint64) (int, error) { + if input > uint64(MaxInt) { + return 0, ErrTooLarge + } + return int(input), nil +} diff --git a/tape/measure.go b/tape/measure.go new file mode 100644 index 0000000..b57fd9d --- /dev/null +++ b/tape/measure.go @@ -0,0 +1,12 @@ +package tape + +// IntBytes returns the number of bytes required to hold a given unsigned +// integer. +func IntBytes(value uint64) int { + bytes := 0 + for value > 0 || bytes == 0 { + value >>= 8; + bytes ++ + } + return bytes +} diff --git a/tape/measure_test.go b/tape/measure_test.go new file mode 100644 index 0000000..4cce9dd --- /dev/null +++ b/tape/measure_test.go @@ -0,0 +1,21 @@ +package tape + +import "testing" + +func TestIntBytes(test *testing.T) { + if correct, got := 1, IntBytes(0); correct != got { + test.Fatal("wrong:", got) + } + if correct, got := 1, IntBytes(1); correct != got { + test.Fatal("wrong:", got) + } + if correct, got := 1, IntBytes(16); correct != got { + test.Fatal("wrong:", got) + } + if correct, got := 1, IntBytes(255); correct != got { + test.Fatal("wrong:", got) + } + if correct, got := 2, IntBytes(256); correct != got { + test.Fatal("wrong:", got) + } +} diff --git a/tape/pairs.go b/tape/pairs.go deleted file mode 100644 index d51e593..0000000 --- a/tape/pairs.go +++ /dev/null @@ -1,83 +0,0 @@ -package tape - -import "iter" - -// DecodePairs decodes message tag/value pairs from a byte slice. It returns an -// iterator over all pairs, where the first value is the tag and the second is -// the value. If data yielded by the iterator is retained, it must be copied -// first. -func DecodePairs(data []byte) (iter.Seq2[uint16, []byte], error) { - // determine section bounds - if len(data) < 2 { return nil, ErrDataTooLarge } - length16, _ := DecodeI16[uint16](data[0:2]) - data = data[2:] - length := int(length16) - headerSize := length * 4 - if len(data) < headerSize { return nil, ErrDataTooLarge } - valuesData := data[headerSize:] - - // ensure the value buffer is big enough - var valuesSize int - for index := range length { - offset := index * 4 - end, _ := DecodeI16[uint16](data[offset + 2:offset + 4]) - valuesSize = int(end) - } - if valuesSize > len(valuesData) { - return nil, ErrDataTooLarge - } - - // return iterator - return func(yield func(uint16, []byte) bool) { - start := uint16(0) - for index := range length { - offset := index * 4 - key , _ := DecodeI16[uint16](data[offset + 0:offset + 2]) - end, _ := DecodeI16[uint16](data[offset + 2:offset + 4]) - // if nextValuesOffset < len(valuesData) { - if !yield(key, valuesData[start:end]) { - return - } - // } else { - // if !yield(key, nil) { - // return - // } - // } - start = end - } - }, nil -} - -// EncodePairs encodes message tag/value pairs into a byte slice. -func EncodePairs(pairs map[uint16] []byte) ([]byte, error) { - // determine section bounds - headerSize := 2 + len(pairs) * 4 - valuesSize := 0 - for _, value := range pairs { - valuesSize += len(value) - } - - // generate data - buffer := make([]byte, headerSize + valuesSize) - length16, ok := U16CastSafe(len(pairs)) - if !ok { return nil, ErrDataTooLarge } - EncodeI16[uint16](buffer[0:2], length16) - index := 0 - end := headerSize - for key, value := range pairs { - start := end - end += len(value) - tagOffset := 2 + index * 4 - end16, ok := U16CastSafe(end - headerSize) - if !ok { return nil, ErrDataTooLarge } - - // write tag and length - EncodeI16[uint16](buffer[tagOffset + 0:tagOffset + 2], key) - EncodeI16[uint16](buffer[tagOffset + 2:tagOffset + 4], end16) - - // write value - copy(buffer[start:end], value) - index ++ - } - return buffer, nil -} diff --git a/tape/pairs_test.go b/tape/pairs_test.go deleted file mode 100644 index 31bb7d6..0000000 --- a/tape/pairs_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package tape - -import "slices" -import "testing" - -func TestDecodePairs(test *testing.T) { - pairs := map[uint16] []byte { - 3894: []byte("foo"), - 7: []byte("br"), - } - got, err := DecodePairs([]byte { - 0, 2, - 0, 7, 0, 2, - 15, 54, 0, 5, - 98, 114, - 102, 111, 111}) - if err != nil { test.Fatal(err) } - length := 0 - for key, value := range got { - test.Log(key, value) - if !slices.Equal(pairs[key], value) { test.Fatal("not equal") } - length ++ - } - test.Log("length") - if length != len(pairs) { test.Fatal("wrong length") } -} - -func TestEncodePairs(test *testing.T) { - pairs := map[uint16] []byte { - 3894: []byte("foo"), - 7: []byte("br"), - } - got, err := EncodePairs(pairs) - if err != nil { test.Fatal(err) } - test.Log(got) - valid := slices.Equal(got, []byte { - 0, 2, - 15, 54, 0, 3, - 0, 7, 0, 5, - 102, 111, 111, - 98, 114}) || - slices.Equal(got, []byte { - 0, 2, - 0, 7, 0, 2, - 15, 54, 0, 5, - 98, 114, - 102, 111, 111}) - if !valid { test.Fatal("not equal") } -} - -func FuzzDecodePairs(fuzz *testing.F) { - fuzz.Add([]byte { - 0, 2, - 0, 7, 0, 2, - 15, 54, 0, 5, - 98, 114, - 102, 111, 111}) - fuzz.Fuzz(func(t *testing.T, buffer []byte) { - // ensure it does not panic :P - DecodePairs(buffer) - }) -} diff --git a/tape/skim.go b/tape/skim.go new file mode 100644 index 0000000..b4d2ab6 --- /dev/null +++ b/tape/skim.go @@ -0,0 +1,54 @@ +package tape + +import "fmt" + +// Skim uses up data from a decoder to "skim" over one value (and all else +// contained within it) without actually putting the data anywhere. +func Skim(decoder *Decoder, tag Tag) (n int, err error) { + switch tag.WithoutCN() { + case SI: + // SI: (none) + return n, nil + case LI, LSI, FP: + // LI: + // LSI: + // FP: + nn, err := decoder.Discard(tag.CN() + 1) + n += nn; if err != nil { return n, err } + case SBA: + // SBA: * + nn, err := decoder.Discard(tag.CN()) + n += nn; if err != nil { return n, err } + case LBA: + // LBA: * + length, nn, err := decoder.ReadUintN(tag.CN() + 1) + n += nn; if err != nil { return n, err } + nn, err = decoder.Discard(int(length)) + n += nn; if err != nil { return n, err } + case OTA: + // OTA: * + length, nn, err := decoder.ReadUintN(tag.CN() + 1) + n += nn; if err != nil { return n, err } + oneTag, nn, err := decoder.ReadTag() + n += nn; if err != nil { return n, err } + for _ = range length { + nn, err := Skim(decoder, oneTag) + n += nn; if err != nil { return n, err } + } + case KTV: + // KTV: ( )* + length, nn, err := decoder.ReadUintN(tag.CN() + 1) + n += nn; if err != nil { return n, err } + for _ = range length { + nn, err := decoder.Discard(2) + n += nn; if err != nil { return n, err } + itemTag, nn, err := decoder.ReadTag() + n += nn; if err != nil { return n, err } + nn, err = Skim(decoder, itemTag) + n += nn; if err != nil { return n, err } + } + default: + return n, fmt.Errorf("unknown TN %d", tag.TN()) + } + return n, nil +} diff --git a/tape/skim_test.go b/tape/skim_test.go new file mode 100644 index 0000000..411f5eb --- /dev/null +++ b/tape/skim_test.go @@ -0,0 +1,137 @@ +package tape + +import "bytes" +import "testing" + +func TestSkimInteger(test *testing.T) { + data := []byte { + 0x12, 0x45, 0x23, 0xF9, + } + mainDataLen := len(data) + // extra junk + data = append(data, 0x00, 0x01, 0x02, 0x03,) + + n, err := Skim(NewDecoder(bytes.NewBuffer(data)), LI.WithCN(3)) + if err != nil { + test.Fatal(err) + } + if got, correct := n, mainDataLen; got != correct { + test.Fatalf("n not equal: %d != %d", got, correct) + } +} + +func TestSkimArray(test *testing.T) { + data := []byte { + 2, byte(LI.WithCN(1)), + 0xBE, 0xE5, 0x77, 0x77, + } + mainDataLen := len(data) + // extra junk + data = append(data, 0x00, 0x01, 0x02, 0x03,) + + n, err := Skim(NewDecoder(bytes.NewBuffer(data)), OTA.WithCN(0)) + if err != nil { + test.Fatal(err) + } + if got, correct := n, mainDataLen; got != correct { + test.Fatalf("n not equal: %d != %d", got, correct) + } +} + +func TestSkimNestedArray(test *testing.T) { + data := []byte { + 2, byte(OTA.WithCN(0)), + 1, byte(LSI.WithCN(1)), + 0, 0x5, + 2, byte(LSI.WithCN(1)), + 0, 0x17, + 0xF5, 0x56, + } + mainDataLen := len(data) + // extra junk + data = append(data, 0x00, 0x01, 0x02, 0x03,) + + n, err := Skim(NewDecoder(bytes.NewBuffer(data)), OTA.WithCN(0)) + if err != nil { + test.Fatal(err) + } + if got, correct := n, mainDataLen; got != correct { + test.Fatalf("n not equal: %d != %d", got, correct) + } +} + +func TestSkimTable(test *testing.T) { + data := []byte { + 2, + 0xF3, 0xB9, + byte(LSI.WithCN(3)), + 0, 0, 0, 1, + + 0x01, 0x02, + byte(LSI.WithCN(3)), + 0, 0, 0, 2, + } + mainDataLen := len(data) + // extra junk + data = append(data, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03) + + n, err := Skim(NewDecoder(bytes.NewBuffer(data)), KTV.WithCN(0)) + if got, correct := n, mainDataLen; got != correct { + test.Fatalf("n not equal: %d != %d ... (%d)", got, correct, len(data)) + } + if err != nil { + test.Fatal(err) + } +} + +func TestSkimTableComplex(test *testing.T) { + data := []byte { + 7, + 0xF3, 0xB9, + byte(LSI.WithCN(3)), + 0, 0, 0, 1, + + 0x01, 0x02, + byte(LSI.WithCN(3)), + 0, 0, 0, 2, + + 0, 0, + byte(SBA.WithCN(3)), + 'h', 'i', '!', + + 0xFF, 0xFF, + byte(OTA.WithCN(0)), 2, byte(LI.WithCN(1)), + 0xBE, 0xE5, 0x77, 0x77, + + 0x12, 0x34, + byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)), + 1, byte(LI.WithCN(1)), + 0, 0x5, + 2, byte(LI.WithCN(1)), + 0, 0x17, + 0xAA, 0xAA, + + 0x23, 0x45, + byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)), + 1, byte(LSI.WithCN(1)), + 0, 0x5, + 2, byte(LSI.WithCN(1)), + 0, 0x17, + 0xF5, 0x56, + + 0x34, 0x56, + byte(LSI.WithCN(1)), + 0x39, 0x21, + } + mainDataLen := len(data) + // extra junk + data = append(data, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03, 0x00, 0x01, 0x02, 0x03) + + n, err := Skim(NewDecoder(bytes.NewBuffer(data)), KTV.WithCN(0)) + if got, correct := n, mainDataLen; got != correct { + test.Fatalf("n not equal: %d != %d ... (%d)", got, correct, len(data)) + } + if err != nil { + test.Fatal(err) + } +} diff --git a/tape/tag.go b/tape/tag.go new file mode 100644 index 0000000..f4001f2 --- /dev/null +++ b/tape/tag.go @@ -0,0 +1,72 @@ +package tape + +import "fmt" + +// TODO: fix #7 + +type Tag byte; const ( + SI Tag = 0 << 5 // Small integer + LI Tag = 1 << 5 // Large unsigned integer + LSI Tag = 2 << 5 // Large signed integer + FP Tag = 3 << 5 // Floating point + SBA Tag = 4 << 5 // Small byte array + LBA Tag = 5 << 5 // Large byte array + OTA Tag = 6 << 5 // One-tag array + KTV Tag = 7 << 5 // Key-tag-value table + TNMask Tag = 0xE0 // The entire TN bitfield + CNMask Tag = 0x1F // The entire CN bitfield + CNLimit Tag = 32 // All valid CNs are < CNLimit +) + +func (tag Tag) TN() int { + return int(tag >> 5) +} + +func (tag Tag) CN() int { + return int(tag & CNMask) +} + +func (tag Tag) WithCN(cn int) Tag { + return (tag & TNMask) | Tag(cn % 32) +} + +func (tag Tag) WithoutCN() Tag { + return tag.WithCN(0) +} + +func (tag Tag) Is(other Tag) bool { + return tag.TN() == other.TN() +} + +func (tag Tag) String() string { + tn := fmt.Sprint(tag.TN()) + switch tag.WithoutCN() { + case SI: tn = "SI" + case LI: tn = "LI" + case LSI: tn = "LSI" + case FP: tn = "FP" + case SBA: tn = "SBA" + case LBA: tn = "LBA" + case OTA: tn = "OTA" + case KTV: tn = "KTV" + } + return fmt.Sprintf("%s:%d", tn, tag.CN()) +} + +// BufferTag returns the appropriate tag for a buffer. +func BufferTag(value []byte) Tag { + return bufferLenTag(len(value)) +} + +// StringTag returns the appropriate tag for a string. +func StringTag(value string) Tag { + return bufferLenTag(len(value)) +} + +func bufferLenTag(length int) Tag { + if length < int(CNLimit) { + return SBA.WithCN(length) + } else { + return LBA.WithCN(IntBytes(uint64(length))) + } +} diff --git a/tape/types.go b/tape/types.go deleted file mode 100644 index 32a2e59..0000000 --- a/tape/types.go +++ /dev/null @@ -1,311 +0,0 @@ -// Package tape implements Table Pair Encoding. -package tape - -import "fmt" - -const dataMaxSize = 0xFFFF -const uint16Max = 0xFFFF - -// Error enumerates common errors in this package. -type Error string; const ( - ErrWrongBufferLength Error = "wrong buffer length" - ErrDataTooLarge Error = "data too large" -) - -// Error implements the error interface. -func (err Error) Error() string { - return string(err) -} - -// Int8 is any 8-bit integer. -type Int8 interface { ~uint8 | ~int8 } -// Int16 is any 16-bit integer. -type Int16 interface { ~uint16 | ~int16 } -// Int32 is any 32-bit integer. -type Int32 interface { ~uint32 | ~int32 } -// Int64 is any 64-bit integer. -type Int64 interface { ~uint64 | ~int64 } -// String is any string. -type String interface { ~string } - -// DecodeI8 decodes an 8 bit integer from the given data. -func DecodeI8[T Int8](data []byte) (T, error) { - if len(data) != 1 { return 0, fmt.Errorf("decoding int8: %w", ErrWrongBufferLength) } - return T(data[0]), nil -} - -// EncodeI8 encodes an 8 bit integer into the given buffer. -func EncodeI8[T Int8](buffer []byte, value T) error { - if len(buffer) != 1 { return fmt.Errorf("encoding int8: %w", ErrWrongBufferLength) } - buffer[0] = byte(value) - return nil -} - -// DecodeI16 decodes a 16 bit integer from the given data. -func DecodeI16[T Int16](data []byte) (T, error) { - if len(data) != 2 { return 0, fmt.Errorf("decoding int16: %w", ErrWrongBufferLength) } - return T(data[0]) << 8 | T(data[1]), nil -} - -// EncodeI16 encodes a 16 bit integer into the given buffer. -func EncodeI16[T Int16](buffer []byte, value T) error { - if len(buffer) != 2 { return fmt.Errorf("encoding int16: %w", ErrWrongBufferLength) } - buffer[0] = byte(value >> 8) - buffer[1] = byte(value) - return nil -} - -// DecodeI32 decodes a 32 bit integer from the given data. -func DecodeI32[T Int32](data []byte) (T, error) { - if len(data) != 4 { return 0, fmt.Errorf("decoding int32: %w", ErrWrongBufferLength) } - return T(data[0]) << 24 | - T(data[1]) << 16 | - T(data[2]) << 8 | - T(data[3]), nil -} - -// EncodeI32 encodes a 32 bit integer into the given buffer. -func EncodeI32[T Int32](buffer []byte, value T) error { - if len(buffer) != 4 { return fmt.Errorf("encoding int32: %w", ErrWrongBufferLength) } - buffer[0] = byte(value >> 24) - buffer[1] = byte(value >> 16) - buffer[2] = byte(value >> 8) - buffer[3] = byte(value) - return nil -} - -// DecodeI64 decodes a 64 bit integer from the given data. -func DecodeI64[T Int64](data []byte) (T, error) { - if len(data) != 8 { return 0, fmt.Errorf("decoding int64: %w", ErrWrongBufferLength) } - return T(data[0]) << 56 | - T(data[1]) << 48 | - T(data[2]) << 40 | - T(data[3]) << 32 | - T(data[4]) << 24 | - T(data[5]) << 16 | - T(data[6]) << 8 | - T(data[7]), nil -} - -// EncodeI64 encodes a 64 bit integer into the given buffer. -func EncodeI64[T Int64](buffer []byte, value T) error { - if len(buffer) != 8 { return fmt.Errorf("encoding int64: %w", ErrWrongBufferLength) } - buffer[0] = byte(value >> 56) - buffer[1] = byte(value >> 48) - buffer[2] = byte(value >> 40) - buffer[3] = byte(value >> 32) - buffer[4] = byte(value >> 24) - buffer[5] = byte(value >> 16) - buffer[6] = byte(value >> 8) - buffer[7] = byte(value) - return nil -} - -// DecodeString decodes a string from the given data. -func DecodeString[T String](data []byte) (T, error) { - return T(data), nil -} - -// EncodeString encodes a string into the given buffer. -func EncodeString[T String](data []byte, value T) error { - if len(data) != len(value) { return fmt.Errorf("encoding string: %w", ErrWrongBufferLength) } - copy(data, value) - return nil -} - -// StringSize returns the size of a string. Returns 0 and an error if the size -// is too large. -func StringSize[T String](value T) (int, error) { - if len(value) > dataMaxSize { return 0, ErrDataTooLarge } - return len(value), nil -} - -// DecodeStringArray decodes a packed string array from the given data. -func DecodeStringArray[T String](data []byte) ([]T, error) { - result := []T { } - for len(data) > 0 { - if len(data) < 2 { return nil, fmt.Errorf("decoding []string: %w", ErrWrongBufferLength) } - itemSize16, _ := DecodeI16[uint16](data[:2]) - itemSize := int(itemSize16) - data = data[2:] - if len(data) < itemSize { return nil, fmt.Errorf("decoding []string: %w", ErrWrongBufferLength) } - result = append(result, T(data[:itemSize])) - data = data[itemSize:] - } - return result, nil -} - -// EncodeStringArray encodes a packed string array into the given buffer. -func EncodeStringArray[T String](buffer []byte, value []T) error { - for _, item := range value { - length, err := StringSize(item) - if err != nil { return err } - if len(buffer) < 2 + length { return fmt.Errorf("encoding []string: %w", ErrWrongBufferLength) } - EncodeI16(buffer[:2], uint16(length)) - buffer = buffer[2:] - copy(buffer, item) - buffer = buffer[length:] - } - if len(buffer) > 0 { return fmt.Errorf("encoding []string: %w", ErrWrongBufferLength) } - return nil -} - -// StringArraySize returns the size of a packed string array. Returns 0 and an -// error if the size is too large. -func StringArraySize[T String](value []T) (int, error) { - total := 0 - for _, item := range value { - total += 2 + len(item) - } - if total > dataMaxSize { return 0, ErrDataTooLarge } - return total, nil -} - -// DecodeI8Array decodes a packed array of 8 bit integers from the given data. -func DecodeI8Array[T Int8](data []byte) ([]T, error) { - result := make([]T, len(data)) - for index, item := range data { - result[index] = T(item) - } - return result, nil -} - -// EncodeI8Array encodes a packed array of 8 bit integers into the given buffer. -func EncodeI8Array[T Int8](buffer []byte, value []T) error { - if len(buffer) != len(value) { return fmt.Errorf("encoding []int8: %w", ErrWrongBufferLength) } - for index, item := range value { - buffer[index] = byte(item) - } - return nil -} - -// I8ArraySize returns the size of a packed 8 bit integer array. Returns 0 and -// an error if the size is too large. -func I8ArraySize[T Int8](value []T) (int, error) { - total := len(value) - if total > dataMaxSize { return 0, ErrDataTooLarge } - return total, nil -} - -// DecodeI16Array decodes a packed array of 16 bit integers from the given data. -func DecodeI16Array[T Int16](data []byte) ([]T, error) { - if len(data) % 2 != 0 { return nil, fmt.Errorf("decoding []int16: %w", ErrWrongBufferLength) } - length := len(data) / 2 - result := make([]T, length) - for index := range length { - offset := index * 2 - result[index] = T(data[offset]) << 8 | T(data[offset + 1]) - } - return result, nil -} - -// EncodeI16Array encodes a packed array of 16 bit integers into the given buffer. -func EncodeI16Array[T Int16](buffer []byte, value []T) error { - if len(buffer) != len(value) * 2 { return fmt.Errorf("encoding []int16: %w", ErrWrongBufferLength) } - for _, item := range value { - buffer[0] = byte(item >> 8) - buffer[1] = byte(item) - buffer = buffer[2:] - } - return nil -} - -// I16ArraySize returns the size of a packed 16 bit integer array. Returns 0 and -// an error if the size is too large. -func I16ArraySize[T Int16](value []T) (int, error) { - total := len(value) * 2 - if total > dataMaxSize { return 0, ErrDataTooLarge } - return total, nil -} - -// DecodeI32Array decodes a packed array of 32 bit integers from the given data. -func DecodeI32Array[T Int32](data []byte) ([]T, error) { - if len(data) % 4 != 0 { return nil, fmt.Errorf("decoding []int32: %w", ErrWrongBufferLength) } - length := len(data) / 4 - result := make([]T, length) - for index := range length { - offset := index * 4 - result[index] = - T(data[offset + 0]) << 24 | - T(data[offset + 1]) << 16 | - T(data[offset + 2]) << 8 | - T(data[offset + 3]) - } - return result, nil -} - -// EncodeI32Array encodes a packed array of 32 bit integers into the given buffer. -func EncodeI32Array[T Int32](buffer []byte, value []T) error { - if len(buffer) != len(value) * 4 { return fmt.Errorf("encoding []int32: %w", ErrWrongBufferLength) } - for _, item := range value { - buffer[0] = byte(item >> 24) - buffer[1] = byte(item >> 16) - buffer[2] = byte(item >> 8) - buffer[3] = byte(item) - buffer = buffer[4:] - } - return nil -} - -// I32ArraySize returns the size of a packed 32 bit integer array. Returns 0 and -// an error if the size is too large. -func I32ArraySize[T Int32](value []T) (int, error) { - total := len(value) * 4 - if total > dataMaxSize { return 0, ErrDataTooLarge } - return total, nil -} - -// DecodeI64Array decodes a packed array of 32 bit integers from the given data. -func DecodeI64Array[T Int64](data []byte) ([]T, error) { - if len(data) % 8 != 0 { return nil, fmt.Errorf("decoding []int64: %w", ErrWrongBufferLength) } - length := len(data) / 8 - result := make([]T, length) - for index := range length { - offset := index * 8 - result[index] = - T(data[offset + 0]) << 56 | - T(data[offset + 1]) << 48 | - T(data[offset + 2]) << 40 | - T(data[offset + 3]) << 32 | - T(data[offset + 4]) << 24 | - T(data[offset + 5]) << 16 | - T(data[offset + 6]) << 8 | - T(data[offset + 7]) - } - return result, nil -} - -// EncodeI64Array encodes a packed array of 64 bit integers into the given buffer. -func EncodeI64Array[T Int64](buffer []byte, value []T) error { - if len(buffer) != len(value) * 8 { return fmt.Errorf("encoding []int64: %w", ErrWrongBufferLength) } - for _, item := range value { - buffer[0] = byte(item >> 56) - buffer[1] = byte(item >> 48) - buffer[2] = byte(item >> 40) - buffer[3] = byte(item >> 32) - buffer[4] = byte(item >> 24) - buffer[5] = byte(item >> 16) - buffer[6] = byte(item >> 8) - buffer[7] = byte(item) - buffer = buffer[8:] - } - return nil -} - -// I64ArraySize returns the size of a packed 64 bit integer array. Returns 0 and -// an error if the size is too large. -func I64ArraySize[T Int64](value []T) (int, error) { - total := len(value) * 8 - if total > dataMaxSize { return 0, ErrDataTooLarge } - return total, nil -} - -// U16CastSafe safely casts an integer to a uint16. If an overflow or underflow -// occurs, it will return (0, false). -func U16CastSafe(n int) (uint16, bool) { - if n < uint16Max && n >= 0 { - return uint16(n), true - } else { - return 0, false - } -} diff --git a/tape/types_test.go b/tape/types_test.go deleted file mode 100644 index 994a586..0000000 --- a/tape/types_test.go +++ /dev/null @@ -1,292 +0,0 @@ -package tape - -import "slices" -import "errors" -import "testing" -import "math/rand" - -const largeNumberNTestRounds = 2048 -const randStringBytes = "-abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func TestI8(test *testing.T) { - var buffer [16]byte - err := EncodeI8[uint8](buffer[:], 5) - if err.Error() != "encoding int8: wrong buffer length" { test.Fatal(err) } - err = EncodeI8[uint8](buffer[:0], 5) - if err.Error() != "encoding int8: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI8[uint8](buffer[:]) - if err.Error() != "decoding int8: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI8[uint8](buffer[:0]) - if err.Error() != "decoding int8: wrong buffer length" { test.Fatal(err) } - - for number := range uint8(255) { - err := EncodeI8[uint8](buffer[:1], number) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI8[uint8](buffer[:1]) - if err != nil { test.Fatal(err) } - if decoded != number { - test.Fatalf("%d != %d", decoded, number) - } - } -} - -func TestI16(test *testing.T) { - var buffer [16]byte - err := EncodeI16[uint16](buffer[:], 5) - if err.Error() != "encoding int16: wrong buffer length" { test.Fatal(err) } - err = EncodeI16[uint16](buffer[:0], 5) - if err.Error() != "encoding int16: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI16[uint16](buffer[:]) - if err.Error() != "decoding int16: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI16[uint16](buffer[:0]) - if err.Error() != "decoding int16: wrong buffer length" { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - number := uint16(rand.Int()) - err := EncodeI16[uint16](buffer[:2], number) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI16[uint16](buffer[:2]) - if err != nil { test.Fatal(err) } - if decoded != number { - test.Fatalf("%d != %d", decoded, number) - } - } -} - -func TestI32(test *testing.T) { - var buffer [16]byte - err := EncodeI32[uint32](buffer[:], 5) - if err.Error() != "encoding int32: wrong buffer length" { test.Fatal(err) } - err = EncodeI32[uint32](buffer[:0], 5) - if err.Error() != "encoding int32: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI32[uint32](buffer[:]) - if err.Error() != "decoding int32: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI32[uint32](buffer[:0]) - if err.Error() != "decoding int32: wrong buffer length" { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - number := uint32(rand.Int()) - err := EncodeI32[uint32](buffer[:4], number) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI32[uint32](buffer[:4]) - if err != nil { test.Fatal(err) } - if decoded != number { - test.Fatalf("%d != %d", decoded, number) - } - } -} - -func TestI64(test *testing.T) { - var buffer [16]byte - err := EncodeI64[uint64](buffer[:], 5) - if err.Error() != "encoding int64: wrong buffer length" { test.Fatal(err) } - err = EncodeI64[uint64](buffer[:0], 5) - if err.Error() != "encoding int64: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI64[uint64](buffer[:]) - if err.Error() != "decoding int64: wrong buffer length" { test.Fatal(err) } - _, err = DecodeI64[uint64](buffer[:0]) - if err.Error() != "decoding int64: wrong buffer length" { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - number := uint64(rand.Int()) - err := EncodeI64[uint64](buffer[:8], number) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI64[uint64](buffer[:8]) - if err != nil { test.Fatal(err) } - if decoded != number { - test.Fatalf("%d != %d", decoded, number) - } - } -} - -func TestString(test *testing.T) { - var buffer [16]byte - err := EncodeString[string](buffer[:], "hello") - if !errIs(err, ErrWrongBufferLength, "encoding string: wrong buffer length") { test.Fatal(err) } - err = EncodeString[string](buffer[:0], "hello") - if !errIs(err, ErrWrongBufferLength, "encoding string: wrong buffer length") { test.Fatal(err) } - _, err = DecodeString[string](buffer[:]) - if err != nil { test.Fatal(err) } - _, err = DecodeString[string](buffer[:0]) - if err != nil { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - length := rand.Intn(16) - str := randString(length) - err := EncodeString[string](buffer[:length], str) - if err != nil { test.Fatal(err) } - decoded, err := DecodeString[string](buffer[:length]) - if err != nil { test.Fatal(err) } - if decoded != str { - test.Fatalf("%s != %s", decoded, str) - } - } -} - -func TestI8Array(test *testing.T) { - var buffer [64]byte - err := EncodeI8Array[uint8](buffer[:], []uint8 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int8: wrong buffer length") { test.Fatal(err) } - err = EncodeI8Array[uint8](buffer[:0], []uint8 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int8: wrong buffer length") { test.Fatal(err) } - _, err = DecodeI8Array[uint8](buffer[:]) - if err != nil { test.Fatal(err) } - _, err = DecodeI8Array[uint8](buffer[:0]) - if err != nil { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - array := randInts[uint8](rand.Intn(16)) - length, _ := I8ArraySize(array) - if length != len(array) { test.Fatalf("%d != %d", length, len(array)) } - err := EncodeI8Array[uint8](buffer[:length], array) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI8Array[uint8](buffer[:length]) - if err != nil { test.Fatal(err) } - if !slices.Equal(decoded, array) { - test.Fatalf("%v != %v", decoded, array) - } - } -} - -func TestI16Array(test *testing.T) { - var buffer [128]byte - err := EncodeI16Array[uint16](buffer[:], []uint16 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int16: wrong buffer length") { test.Fatal(err) } - err = EncodeI16Array[uint16](buffer[:0], []uint16 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int16: wrong buffer length") { test.Fatal(err) } - _, err = DecodeI16Array[uint16](buffer[:]) - if err != nil { test.Fatal(err) } - _, err = DecodeI16Array[uint16](buffer[:0]) - if err != nil { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - array := randInts[uint16](rand.Intn(16)) - length, _ := I16ArraySize(array) - if length != 2 * len(array) { test.Fatalf("%d != %d", length, 2 * len(array)) } - err := EncodeI16Array[uint16](buffer[:length], array) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI16Array[uint16](buffer[:length]) - if err != nil { test.Fatal(err) } - if !slices.Equal(decoded, array) { - test.Fatalf("%v != %v", decoded, array) - } - } -} - -func TestI32Array(test *testing.T) { - var buffer [256]byte - err := EncodeI32Array[uint32](buffer[:], []uint32 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int32: wrong buffer length") { test.Fatal(err) } - err = EncodeI32Array[uint32](buffer[:0], []uint32 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int32: wrong buffer length") { test.Fatal(err) } - _, err = DecodeI32Array[uint32](buffer[:]) - if err != nil { test.Fatal(err) } - _, err = DecodeI32Array[uint32](buffer[:0]) - if err != nil { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - array := randInts[uint32](rand.Intn(16)) - length, _ := I32ArraySize(array) - if length != 4 * len(array) { test.Fatalf("%d != %d", length, 4 * len(array)) } - err := EncodeI32Array[uint32](buffer[:length], array) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI32Array[uint32](buffer[:length]) - if err != nil { test.Fatal(err) } - if !slices.Equal(decoded, array) { - test.Fatalf("%v != %v", decoded, array) - } - } -} - -func TestI64Array(test *testing.T) { - var buffer [512]byte - err := EncodeI64Array[uint64](buffer[:], []uint64 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int64: wrong buffer length") { test.Fatal(err) } - err = EncodeI64Array[uint64](buffer[:0], []uint64 { 0, 4, 50, 19 }) - if !errIs(err, ErrWrongBufferLength, "encoding []int64: wrong buffer length") { test.Fatal(err) } - _, err = DecodeI64Array[uint64](buffer[:]) - if err != nil { test.Fatal(err) } - _, err = DecodeI64Array[uint64](buffer[:0]) - if err != nil { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - array := randInts[uint64](rand.Intn(16)) - length, _ := I64ArraySize(array) - if length != 8 * len(array) { test.Fatalf("%d != %d", length, 8 * len(array)) } - err := EncodeI64Array[uint64](buffer[:length], array) - if err != nil { test.Fatal(err) } - decoded, err := DecodeI64Array[uint64](buffer[:length]) - if err != nil { test.Fatal(err) } - if !slices.Equal(decoded, array) { - test.Fatalf("%v != %v", decoded, array) - } - } -} - -func TestStringArray(test *testing.T) { - var buffer [8192]byte - err := EncodeStringArray[string](buffer[:], []string { "0", "4", "50", "19" }) - if !errIs(err, ErrWrongBufferLength, "encoding []string: wrong buffer length") { test.Fatal(err) } - err = EncodeStringArray[string](buffer[:0], []string { "0", "4", "50", "19" }) - if !errIs(err, ErrWrongBufferLength, "encoding []string: wrong buffer length") { test.Fatal(err) } - _, err = DecodeStringArray[string](buffer[:0]) - if err != nil { test.Fatal(err) } - - for _ = range largeNumberNTestRounds { - array := randStrings[string](rand.Intn(16), 16) - length, _ := StringArraySize(array) - // TODO test length - err := EncodeStringArray[string](buffer[:length], array) - if err != nil { test.Fatal(err) } - decoded, err := DecodeStringArray[string](buffer[:length]) - if err != nil { test.Fatal(err) } - if !slices.Equal(decoded, array) { - test.Fatalf("%v != %v", decoded, array) - } - } -} - -func TestU16CastSafe(test *testing.T) { - number, ok := U16CastSafe(90_000) - if ok { test.Fatalf("false positive: %v, %v", number, ok) } - number, ok = U16CastSafe(-478) - if ok { test.Fatalf("false positive: %v, %v", number, ok) } - number, ok = U16CastSafe(3870) - if !ok { test.Fatalf("false negative: %v, %v", number, ok) } - if got, correct := number, uint16(3870); got != correct { - test.Fatalf("not equal: %v %v", got, correct) - } - number, ok = U16CastSafe(0) - if !ok { test.Fatalf("false negative: %v, %v", number, ok) } - if got, correct := number, uint16(0); got != correct { - test.Fatalf("not equal: %v %v", got, correct) - } -} - -func randString(length int) string { - buffer := make([]byte, length) - for index := range buffer { - buffer[index] = randStringBytes[rand.Intn(len(randStringBytes))] - } - return string(buffer) -} - -func randInts[T interface { ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 }] (length int) []T { - buffer := make([]T, length) - for index := range buffer { - buffer[index] = T(rand.Int()) - } - return buffer -} - -func randStrings[T interface { ~string }] (length, maxItemLength int) []T { - buffer := make([]T, length) - for index := range buffer { - buffer[index] = T(randString(rand.Intn(maxItemLength))) - } - return buffer -} - -func errIs(err error, wraps error, description string) bool { - return err != nil && (wraps == nil || errors.Is(err, wraps)) && err.Error() == description -}