message-size-increase #3

Open
sashakoshka wants to merge 109 commits from message-size-increase into main
32 changed files with 3094 additions and 1654 deletions

View File

@ -1,8 +1,11 @@
package hopp package hopp
import "io"
import "net" import "net"
// import "time" // import "time"
const defaultSizeLimit int64 = 1024 * 1024 // 1 megabyte
// Conn is a HOPP connection. // Conn is a HOPP connection.
type Conn interface { type Conn interface {
// Close closes the connection. Any blocked operations on the connection // Close closes the connection. Any blocked operations on the connection
@ -19,22 +22,39 @@ type Conn interface {
// AcceptTrans accepts a transaction from the other party. This must // AcceptTrans accepts a transaction from the other party. This must
// be called in a loop to avoid the connection locking up. // be called in a loop to avoid the connection locking up.
AcceptTrans() (Trans, error) AcceptTrans() (Trans, error)
// SetSizeLimit sets a limit (in bytes) for how large messages can be.
// By default, this limit is 1 megabyte.
SetSizeLimit(limit int64)
} }
// Trans is a HOPP transaction. // Trans is a HOPP transaction. Methods of this interface are not safe for
// concurrent use with the exception of the Close and ID methods. The
// recommended use case is one goroutine per transaction.
type Trans interface { type Trans interface {
// Close closes the transaction. Any blocked operations will be // Close closes the transaction. Any blocked operations will be
// unblocked and return errors. // unblocked and return errors. This method is safe for concurrent use.
Close() error Close() error
// ID returns the transaction ID. This must not change, and it must be // ID returns the transaction ID. This must not change, and it must be
// unique within the connection. // unique within the connection. This method is safe for concurrent use.
ID() int64 ID() int64
// TODO: add methods for setting send and receive deadlines // TODO: add methods for setting send and receive deadlines
// Send sends a message. // Send sends a message. This method is not safe for concurrent use.
Send(method uint16, data []byte) error Send(method uint16, data []byte) error
// Receive receives a message. // SendWriter sends data written to an [io.Writer]. The writer must be
// closed after use. Closing the writer flushes any data that hasn't
// been written yet. Any writer previously opened through this function
// will be discarded. This method is not safe for concurrent use, and
// neither is its result.
SendWriter(method uint16) (io.WriteCloser, error)
// Receive receives a message. This method is not safe for concurrent
// use.
Receive() (method uint16, data []byte, err error) Receive() (method uint16, data []byte, err error)
// ReceiveReader receives a message as an [io.Reader]. Any reader
// previously opened through this function will be discarded. This
// method is not safe for concurrent use, and neither is its result.
ReceiveReader() (method uint16, data io.Reader, err error)
} }

109
design/pdl-compiler.md Normal file
View File

@ -0,0 +1,109 @@
# PDL Compiler Specification
Given one or more PDL files representing a protocol, the compiler shall generate
a Go package named "protocol", which shall contain types for message and type
definitions, as well as encoding and decoding methods.
## Static Section
The compiler shall write a static section alongside the generated code. It
shall contain this text:
```go
// Table is a KTV table with an undefined schema.
type Table map[uint16] any
// Message is any message that can be sent along this protocol.
type Message interface {
codec.Encodable
codec.Decodable
// Method returns the method code of the message.
Method() uint16
}
```
## Preamble
At the start of each file but after the package name, the compiler shall emit
this text:
```go
/* # Do not edit this package by hand!
*
* This file was automatically generated by the Holanet PDL compiler. The
* source file is located at <path>
* Please edit that file instead, and re-compile it to this location.
*
* HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz
*/
```
Where `<path>` is the path of the protocol definition file relative to the
generated file.
## Message Definitions
For each defined message, the compiler shall generate a Go type named
`MessageName`, where `Name` is the name of the message as written in its
definition. The message shall be encodable, and shall have `Encode` and `Decode`
methods as described below.
All messages shall satisfy a `Message` interface, which is defined in the
static section.
## Type Definitions
For each defined type, the compiler shall generate a Go type with the same name
as written in its definition. The Go type shall be encodable, and shall have
`EncodeValue`, `DecodeValue`, and `Tag` methods as described below.
## Encoding and Decoding Methods
Each message shall be given an `Encode` method and a `Decode` method,
which shall take in a `codec.Encoder` and a `codec.Decoder` respectively. Both
return an `(int, error)` pair describing the amount of bytes written and an
error if the write stopped early. `Encode` shall encode the data within the
message to the given encoder, and `Decode` shall decode data from the given
decoder and place it in the type's value. The methods shall not retain or close
any encoders or decoders they are given. Both methods shall have pointer
receivers. In effect, these methods shall satisfy `codec.Encodable` and
`codec.Decodable`.
Each defined type shall be given an `EncodeValue` method and a `DecodeValue`
method, which shall both take in a `tape.Tag`, then a `codec.Encoder` and a
`codec.Decoder` respectively. These methods shall encode and decode the value
according to the CN given by the tag. The TN shall be ignored. The message shall
also have a method `Tag` that takes no arguments and returns the preferred tag
of the type including the TN and CN.
## Connection
The compiler shall generate a `Conn` struct which embeds a `hopp.Conn`, which
is the real "porcelain" of the generated code. It shall provide methods to
create and accept transactions. Each transaction shall be a struct which embeds
a `hopp.Trans`, and shall have methods for sending and receiving messages.
### Sending
To send a message along a transaction, the program shall:
1. Obtain the method code from the message
3. Obtain a writer from the connection using the method code
4. Wrap the writer in a `codec.Encoder`
5. Use the encoder to encode the message
6. Close the writer
### Receiving
To receiving a message from a transaction, the program shall:
1. Obtain a method code and reader from the connection
2. Wrap the reader in a `codec.Decoder`
3. Switch on the method code, and decode the correct message using the decoder
4. Return the message to the caller as a value
The recieve function must return the message as a value instead of a pointer in
order to avoid making an allocation. Because of this, the return value must be
`any` instead of `Message`. The caller must then use a type switch to figure out
what message was sent.

104
design/pdl-language.md Normal file
View File

@ -0,0 +1,104 @@
# PDL Language Definition
PDL allows defining a protocol using HOPP and TAPE.
## Data Types
| Syntax | TN | CN | Description
| ---------- | ------- | -: | -----------
| I5 | SI | |
| I8 | LI | 0 |
| I16 | LI | 1 |
| I32 | LI | 3 |
| I64 | LI | 7 |
| I128[^2] | LI | 15 |
| I256[^2] | LI | 31 |
| U5 | SI | |
| U8 | LI | 0 |
| U16 | LI | 1 |
| U32 | LI | 3 |
| U64 | LI | 7 |
| U128[^2] | LI | 15 |
| U256[^2] | LI | 31 |
| F16 | FP | 1 |
| F32 | FP | 3 |
| F64 | FP | 7 |
| F128[^2] | FP | 15 |
| F256[^2] | FP | 31 |
| String | SBA/LBA | * | UTF-8 string
| Buffer | SBA/LBA | * | Byte array
| []\<TYPE\> | OTA | * | Array of any type[^1]
| Table | KTV | * | Table with undefined schema
| {...} | KTV | * | Table with defined schema
[^1]: Excluding SI and SBA. I5 and U5 cannot be used in an array, but String and
Buffer are simply forced to use their "long" variant.
[^2]: Some systems may lack support for this.
## Tokens
PDL files are divided into tokens, which assemble together into larger language
structures. They are separated by whitespace.
| Name | Syntax | Description
| -------- | ------------------ | -----------
| Method | `M[0-9A-Fa-f]{4}` | A 16-bit hexadecimal method code.
| Key | `[0-9A-Fa-f]{4}` | A 16-bit hexadecimal table key.
| Ident | `[A-Z][A-Za-z0-9]` | An identifier.
| Comma | `,` | A comma separator.
| LBrace | `{` | A left curly brace.
| RBrace | `}` | A right curly brace.
| LBracket | `[` | A left square bracket.
| RBracket | `]` | A right square bracket.
## Syntax
Types are expressed with an Ident. A table can be used by either writing the
name of the type (Table), or by defining a schema with curly braces. Arrays must
be expressed using two matching square brackets before their element type.
A table schema contains comma-separated fields in-between its braces. Each field
has three parts: the key number (Key), the field name (Ident), and the field
type. Tables, Arrays, etc. can be nested.
Files directly contain messages and types, which start with a Method token and
an Ident token respectively. A message consists of the method code (Method), the
message name (Ident), and the message's root type. This is usually a table, but
can be anything.
Here is an example of all that:
```
M0000 Connect {
0000 Name String,
0001 Password String,
}
M0001 UserList {
0000 Users []User,
}
User {
0000 Name String,
0001 Bio String,
0002 Followers U32,
}
```
## EBNF Description
Below is an EBNF description of the language.
```
<file> -> (<message> | <typedef)*
<method> -> /M[0-9A-Fa-f]{4}/
<key> -> /[0-9A-Fa-f]{4}/
<ident> -> /[A-Z][A-Za-z0-9]/
<field> -> <key> <ident> <type>
<type> -> <ident>
| "[" "]" <type>
| "{" (<field> ",")* [<field>] "}"
<message> -> <method> <ident> <type>
<typedef> -> <ident> <type>
```

View File

@ -18,12 +18,10 @@ dependant on which transport is being used.
A message refers to a block of octets sent within a transaction, paired with an A message refers to a block of octets sent within a transaction, paired with an
unsigned 16-bit method code. The order of messages within a given transaction is unsigned 16-bit method code. The order of messages within a given transaction is
preserved, but the order of messages accross the entire connection is not preserved, but the order of messages accross the entire connection is not
guaranteed. guaranteed. There is no functional limit on the size of a message payload, but
there may be one depending on which
The message payload must be 65,535 (unsigned 16-bit integer limit) octets or [METADAPT sub-protocol](#message-and-transaction-demarcation-protocol-metadapt)
smaller in length. This does not include the method code. Applications are free is in use.
to send whatever data they wish as the payload, but TAPE is recommended for
encoding it.
Method codes should be written in upper-case base 16 with the prefix "M" in Method codes should be written in upper-case base 16 with the prefix "M" in
logs, error messages, documentation, etc. For example, the method code 62,206 in logs, error messages, documentation, etc. For example, the method code 62,206 in
@ -37,100 +35,92 @@ fucking with you.
## Table Pair Encoding (TAPE) ## Table Pair Encoding (TAPE)
The Table Pair Encoding (TAPE) scheme is a method for encoding structured data The Table Pair Encoding (TAPE) scheme is a method for encoding structured data
within HOPP messages. It defines standard binary encoding methods for common within HOPP messages. It defines standard binary encoding methods for common
data types, as well as a corruption-resistant table structure that maps numeric data types, as well as aggregate data types such as tables and arrays. It is
IDs to values. It is designed to allow applications to be presented with data designed to allow applications to be presented with data they are not equipped
they are not equipped to handle while continuing to function normally. This to handle while continuing to function normally. This enables backwards
enables backwards compatibile application protocol changes. compatibile application protocol changes.
### Table Structure TAPE expresses types using tags. A tag is 8 bits in size, and is divided into
A table is divided into two sections: the header, and the values. The header two parts: the Type Number (TN), and the Configuration Number (CN). The TN is 3
begins with the number (U16) of pairs in the table, which is then followed by bits, and the CN is 5 bits. Both are interpreted as unsigned integers. Both
that many tag-offset pairs. A tag-offset pair consists of a numerical (U16) tag, sides of the connection must agree on the semantic meaning of the values and
followed the position (U16) of the value relative to the start of the values their arrangement.
section. The values section contains the value data for each pair, where the
start of each value is determined by its offset, and the end is determined by
the offset of the next value, or the end of the message if there is no value
after it.
Both sections must be in the same order, and because of this, each value offset A TAPE structure begins with one root, which consists of a tag followed by a
must be greater than or equal to the last. If a message has erratic structure payload. This is usually an aggregate data structure such as KTV to allow for
(such as unordered or out-of-bounds offsets), implementations may opt to discard several different values.
only the erratic pairs, as well as the pairs directly before those.
TAPE is based on an encoding method previously developed by silt.
### Data Value Types ### Data Value Types
The table below lists all data value types supported by TAPE. The table below lists all data value types supported by TAPE. They are discussed
in detail in the following sections.
| Name | Size | Description | Encoding Method | TN | Bits | Name | Description
| ----------- | --------------: | --------------------------- | --------------- | -: | ---: | ---- | -----------
| I8 | 1 | A signed 8-bit integer | BETC | 0 | 000 | SI | Small integer
| I16 | 2 | A signed 16-bit integer | BETC | 1 | 001 | LI | Large integer
| I32 | 4 | A signed 32-bit integer | BETC | 2 | 010 | FP | Floating point
| I64 | 8 | A signed 64-bit integer | BETC | 3 | 011 | SBA | Small byte array
| U8 | 1 | An unsigned 8-bit integer | BEU | 4 | 100 | LBA | Large byte array
| U16 | 2 | An unsigned 16-bit integer | BEU | 5 | 101 | OTA | One-tag array
| U32 | 4 | An unsigned 32-bit integer | BEU | 6 | 110 | KTV | Key-tag-value table
| U64 | 8 | An unsigned 64-bit integer | BEU | 7 | 111 | N/A | Reserved
| Array[^1] | SOP[^2] | An array of any above type | PASTA
| String | N/A | A UTF-8 string | UTF-8
| StringArray | n * 2 + SOP[^2] | An array the String type | VILA
[^1]: Array types are written as <E>Array, where <E> is the element type. For #### Small Integer (SI)
example, an array of I32 would be written as I32Array. StringArray still follows SI encodes an integer of up to 5 bits, which are stored in the CN. It has no
this rule, even though it is encoded differently from other arrays. Nesting payload. Whether the bits are interpreted as unsigned or as signed two's
arrays inside of arrays is prohibited. This problem can be avoided in most cases complement is semantic information and must be agreed upon by both sides of the
by effectively utilizing the table structure, or by improving the design of connection. Thus, the value may range from 0 to 31 if unsigned, and from -16 to
your protocol. 17 if signed.
[^2]: SOP (sum of parts) refers to the sum of the size of every item in a data #### Large Integer (LI)
structure. LI encodes an integer of up to 256 bits, which are stored in the payload. The CN
determine the length of the payload in bytes. The integer is big-endian. Whether
the payload is interpreted as unsigned or as signed two's complement is semantic
information and must be agreed upon by both sides of the connection. Thus, the
value may range from 0 to 31 if unsigned, and from -16 to 17 if signed.
### Encoding Methods #### Floating Point (FP)
Below are all encoding methods supported by TAPE. FP encodes an IEEE 754 floating point number of up to 256 bits, which are stored
in the payload. The CN determines the length of the payload in bytes, and it may
only be one of these values: 16, 32, 64, 128, or 256.
#### BETC #### Small Byte Array (SBA)
Big-Endian, Two's Complement signed integer. The size is defined as the least SBA encodes an array of up to 32 bytes, which are stored in the paylod. The
amount of whole octets which can fit all bits in the integer, regardless if the CN determines the length of the payload in bytes.
bits are on or off. Therefore, the size cannot change at runtime.
#### BEU #### Large Byte Array (LBA)
Big-Endian, Unsigned integer. The size is defined as the least amount of whole LBA encodes an array of up to 2^256 bytes, which are stored in the second part
octets which can fit all bits in the integer, regardless if the bits are on or of the payload, directly after the length. The length of the data length field
off. Therefore, the size cannot change at runtime. in bytes is determined by the CN.
#### PASTA #### One-Tag Array (OTA)
Packed Single-Type Array. The size is defined as the size of an individual item OTA encodes an array of up to 2^256 items, which are stored in the payload after
times the number of items. Items are placed one after the other with no gaps the length field and the item tag, where the length field comes first. Each item
in-between them, except as required to align the start of each item to the must be the same length, as they all share the same tag. The length of the data
nearest whole octet. Items should be of the same type and must be of the same length field in bytes is determined by the CN.
size.
#### UTF-8 #### Key-Tag-Value Table (KTV)
UTF-8 string. The size is defined as the least amount of whole octets which can KTV encodes a table of up to 2^256 key/value pairs, which are stored in the
fit all bits in the string, regardless if the bits are on or off. The size of payload after the length field. The pairs themselves consist of a 16-bit
this type is not fixed and may change at runtime, so this needs to be accounted unsigned big-endian key followed by a tag and then the payload. Pair values can
for during use. be of different types and sizes. The order of the pairs is not significant and
should never be treated as such.
#### VILA
Variable Item Length Array. The size is defined as the least amount of whole
octets which can fit each item plus one U16 per item. The size of this type is
not fixed and may change at runtime, so this needs to be accounted for during
use. The amount of items must be greater than zero. Items are each prefixed by
their size (in octets) encoded as a U16, and they are placed one after the other
with no gaps in-between them, except as required to align the start of each item
to the nearest whole octet. Items should be of the same type but do not need to
be of the same size.
## Transports ## Transports
A transport is a protocol that HOPP connections can run on top of. HOPP A transport is a protocol that HOPP connections can run on top of. HOPP
currently supports the QUIC transport protocol for communicating between currently supports the QUIC transport protocol for communicating between
machines, and UNIX domain sockets for quicker communication among applications machines, TCP/TLS for legacy systems that do not support QUIC, and UNIX domain
on the same machine. Both protocols are supported through METADAPT. sockets for faster communication among applications on the same machine. Both
protocols are supported through METADAPT.
## Message and Transaction Demarcation Protocol (METADAPT) ## Message and Transaction Demarcation Protocol (METADAPT)
The Message and Transaction Demarcation Protocol is used to break one or more The Message and Transaction Demarcation Protocol is used to break one or more
reliable data streams into transactions, which are broken down further into reliable data streams into transactions, which are broken down further into
messages. A message, as well as its associated metadata (length, transaction, messages. The representation of a message (or a part thereof) on the protocol,
method, etc.) together is referred to as METADAPT Message Block (MMB). including its associated metadata (length, transaction, method, etc.) is
referred to as METADAPT Message Block (MMB).
For transports that offer multiple multiplexed data streams that can be created For transports that offer multiple multiplexed data streams that can be created
and destroyed on-demand (such as QUIC) each stream is used as a transaction. If and destroyed on-demand (such as QUIC) each stream is used as a transaction. If
@ -145,8 +135,12 @@ METADAPT-A requires a transport which offers a single full-duplex data stream
that persists for the duration of the connection. All transactions are that persists for the duration of the connection. All transactions are
multiplexed onto this single stream. Each MMB contains a 12-octet long header, multiplexed onto this single stream. Each MMB contains a 12-octet long header,
with the transaction ID, then the method, and then the payload size (in octets). with the transaction ID, then the method, and then the payload size (in octets).
The transaction ID is encoded as an I64, and the method and payload size are The transaction ID is encoded as an I64, the method is encoded as a U16 and the
both encoded as U16s. The remainder of the message is the payload. Since each and payload size is encoded as a U64. Only the 63 least significant bits of the
payload size describe the actual size, the most significant bit controlling
chunking. See the section on chunking for more information.
The remainder of the message is the payload. Since each
MMB is self-describing, they are sent sequentially with no gaps in-between them. MMB is self-describing, they are sent sequentially with no gaps in-between them.
Transactions "open" when the first message with a given transaction ID is sent. Transactions "open" when the first message with a given transaction ID is sent.
@ -162,13 +156,24 @@ used up, the connection must fail. Don't worry about this though, because the
sun will have expanded to swallow earth by then. Your connection will not last sun will have expanded to swallow earth by then. Your connection will not last
that long. that long.
#### Message Chunking
The most significant bit of the payload size field of an MMB is called the Chunk
Control Bit (CCB). If the CCB of a given MMB is zero, the represented message is
interpreted as being self-contained and the data is processed immediately. If
the CCB is one, the message is interpreted as being chunked, with the data of
the current MMB being the first chunk. The data of further MMBs sent along the
transaction will be appended to the message until an MMB is read with a zero
CCB, in which case the MMB will be the last chunk and any more MMBs will be
interpreted as normal.
### METADAPT-B ### METADAPT-B
METADAPT-B requires a transport which offers multiple multiplexed full-duplex METADAPT-B requires a transport which offers multiple multiplexed full-duplex
data streams per connection that can be created and destroyed on-demand. Each data streams per connection that can be created and destroyed on-demand. Each
data stream is used as an individual transaction. Each MMB contains a 4-octet data stream is used as an individual transaction. Each MMB contains a 4-octet
long header with the method and then the payload size (in octets) both encoded long header with the method and then the payload size (in octets) encoded as a
as U16s. The remainder of the message is the payload. Since each MMB is U16 and U64 respectively. The remainder of the message is the payload. Since
self-describing, they are sent sequentially with no gaps in-between them. each MMB is self-describing, they are sent sequentially with no gaps in-between
them.
The ID of any transaction will reflect the ID of its corresponding stream. The The ID of any transaction will reflect the ID of its corresponding stream. The
lifetime of the transaction is tied to the lifetime of the stream, that is to lifetime of the transaction is tied to the lifetime of the stream, that is to

23
dial.go
View File

@ -1,9 +1,9 @@
package hopp package hopp
import "net" import "net"
import "errors"
import "context" import "context"
import "crypto/tls" import "crypto/tls"
import "github.com/quic-go/quic-go"
// Dial opens a connection to a server. The network must be one of "quic", // Dial opens a connection to a server. The network must be one of "quic",
// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and // "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and
@ -19,9 +19,8 @@ type Dialer struct {
} }
// Dial opens a connection to a server. The network must be one of "quic", // Dial opens a connection to a server. The network must be one of "quic",
// "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, "quic4" and // "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, quic is not
// "quic6" don't do anything as the quic-go package doesn't seem to support this // supported.
// behavior.
func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) { func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, error) {
switch network { switch network {
case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address) case "quic", "quic4", "quic6": return diale.dialQUIC(ctx, network, address)
@ -31,12 +30,7 @@ func (diale Dialer) Dial(ctx context.Context, network, address string) (Conn, er
} }
func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) { func (diale Dialer) dialQUIC(ctx context.Context, network, address string) (Conn, error) {
// sorry i fucking lied to you about the network parameter. for all return nil, errors.New("quic is not yet implemented")
// quic-go's bullshit bloat, it doesnt even support that. not even when
// instantiating a transport. go figure :/
conn, err := quic.DialAddr(ctx, address, tlsConfig(diale.TLSConfig), quicConfig())
if err != nil { return nil, err }
return AdaptB(quicMultiConn { underlying: conn }), nil
} }
func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) { func (diale Dialer) dialUnix(ctx context.Context, network, address string) (Conn, error) {
@ -60,15 +54,6 @@ func tlsConfig(conf *tls.Config) *tls.Config {
return conf return conf
} }
func quicConfig() *quic.Config {
return &quic.Config {
// TODO: perhaps we might want to put something here
// the quic config shouldn't be exported, just set up
// automatically. we can't have that strangely built quic-go
// package be part of the API, or any third-party packages for
// that matter. it must all be abstracted away.
}
}
func quicNetworkToUDPNetwork(network string) (string, error) { func quicNetworkToUDPNetwork(network string) (string, error) {
switch network { switch network {

View File

@ -2,301 +2,624 @@ package generate
import "io" import "io"
import "fmt" import "fmt"
import "bufio" import "maps"
import "math"
import "slices"
import "strings" import "strings"
import "git.tebibyte.media/sashakoshka/hopp/tape"
const send = const imports =
`// Send sends one message along a transaction. `
func Send(trans hopp.Trans, message hopp.Message) error { import "git.teibibyte.media/sashakoshka/hopp/tape"
buffer, err := message.MarshalBinary()
if err != nil { return err }
return trans.Send(message.Method(), buffer)
}
` `
// ResolveType resolves a HOPP type name to a Go type. For now, it supports all const preamble = `
// data types defined in TAPE. /* # Do not edit this package by hand!
func (this *Protocol) ResolveType(hopp string) (string, error) { *
switch hopp { * This file was automatically generated by the Holanet PDL compiler. The
case "I8": return "int8", nil * source file is located at <path>
case "I16": return "int16", nil * Please edit that file instead, and re-compile it to this location.
case "I32": return "int32", nil *
case "I64": return "int64", nil * HOPP, TAPE, METADAPT, PDL/0 (c) 2025 holanet.xyz
case "U8": return "uint8", nil */
case "U16": return "uint16", nil `
case "U32": return "uint32", nil
case "U64": return "uint64", nil const static = `
case "I8Array": return "[]int8", nil // Table is a KTV table with an undefined schema.
case "I16Array": return "[]int16", nil type Table map[uint16] any
case "I32Array": return "[]int32", nil
case "I64Array": return "[]int64", nil // Message is any message that can be sent along this protocol.
case "U8Array": return "[]uint8", nil type Message interface {
case "U16Array": return "[]uint16", nil tape.Encodable
case "U32Array": return "[]uint32", nil tape.Decodable
case "U64Array": return "[]uint64", nil
case "String": return "string", nil // Method returns the method code of the message.
case "StringArray": return "[]string", nil Method() uint16
default: return "", fmt.Errorf("unknown type: %s", hopp) }
} `
// Generator converts protocols into Go code.
type Generator struct {
// Output is where the generated code will be sent.
Output io.Writer
// PackageName is the package name that will be used in the file. If
// left empty, the default is "protocol".
PackageName string
nestingLevel int
protocol *Protocol
} }
// Generate turns this protocol into code. The package name for the generated func (this *Generator) Generate(protocol *Protocol) (n int, err error) {
// code must be specified. this.nestingLevel = 0
func (this *Protocol) Generate(writer io.Writer, packag string) error { this.protocol = protocol
out := bufio.NewWriter(writer) defer func() { this.protocol = nil }()
defer out.Flush()
fmt.Fprintf(out, "package %s\n\n", packag) // preamble and static section
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp\"\n") packageName := "protocol"
fmt.Fprintf(out, "import \"git.tebibyte.media/sashakoshka/hopp/tape\"\n\n") if this.PackageName != "" {
packageName = this.PackageName
}
nn, err := this.iprintf("package %s\n", packageName)
n += nn; if err != nil { return n, err }
nn, err = this.print(preamble)
n += nn; if err != nil { return n, err }
nn, err = this.print(imports)
n += nn; if err != nil { return n, err }
nn, err = this.print(static)
n += nn; if err != nil { return n, err }
fmt.Fprintf(out, send) // type definitions
this.receive(out) for _, name := range slices.Sorted(maps.Keys(protocol.Types)) {
nn, err := this.generateTypedef(name, protocol.Types[name])
for _, message := range this.Messages { n += nn; if err != nil { return n, err }
err := this.defineMessage(out, message)
if err != nil { return err }
err = this.marshalMessage(out, message)
if err != nil { return err }
err = this.unmarshalMessage(out, message)
if err != nil { return err }
} }
return nil // messages
for _, method := range slices.Sorted(maps.Keys(protocol.Messages)) {
nn, err := this.generateMessage(method, protocol.Messages[method])
n += nn; if err != nil { return n, err }
}
return n, nil
} }
func (this *Protocol) receive(out io.Writer) error { func (this *Generator) generateTypedef(name string, typ Type) (n int, err error) {
fmt.Fprintf(out, "// Receive receives one message from a transaction.\n") // type definition
fmt.Fprintf(out, "func Receive(trans hopp.Trans) (hopp.Message, error) {\n") nn, err := this.iprintf(
fmt.Fprintf(out, "\tmethod, data, err := trans.Receive()\n") "\n// %s represents the protocol data type %s.\n",
fmt.Fprintf(out, "\tif err != nil { return nil, err }\n") name, name)
fmt.Fprintf(out, "\tswitch method {\n") n += nn; if err != nil { return n, err }
for _, message := range this.Messages { nn, err = this.iprintf("type %s ", name)
fmt.Fprintf(out, "\tcase 0x%04X:\n", message.Method) n += nn; if err != nil { return n, err }
fmt.Fprintf(out, "\t\tmessage := &Message%s { }\n", message.Name) nn, err = this.generateType(typ)
fmt.Fprintf(out, "\t\terr := message.UnmarshalBinary(data)\n") n += nn; if err != nil { return n, err }
fmt.Fprintf(out, "\t\tif err != nil { return nil, err }\n") nn, err = this.println()
fmt.Fprintf(out, "\t\treturn message, nil\n") n += nn; if err != nil { return n, err }
}
fmt.Fprintf(out, "\tdefault: return nil, hopp.ErrUnknownMethod\n") // Tag method
fmt.Fprintf(out, "\t}\n") // to be honest we probably don't need this method at all
fmt.Fprintf(out, "}\n\n") // nn, err = this.iprintf("\n// Tag returns the preferred TAPE tag.\n")
return nil // n += nn; if err != nil { return n, err }
// nn, err = this.iprintf("func (this *%s) Tag() tape.Tag {\n", name)
// n += nn; if err != nil { return n, err }
// this.push()
// nn, err = this.iprintf("return ")
// n += nn; if err != nil { return n, err }
// nn, err = this.generateTag(typ, "(*this)")
// n += nn; if err != nil { return n, err }
// nn, err = this.println()
// n += nn; if err != nil { return n, err }
// this.pop()
// nn, err = this.iprintf("}\n")
// n += nn; if err != nil { return n, err }
// EncodeValue method
nn, err = this.iprintf(
"\n// EncodeValue encodes the value of this type without the " +
"tag. The value is\n// encoded according to the parameters " +
"specified by the tag, if possible.\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"func (this *%s) EncodeValue(encoder *tape.Encoder, tag tape.Tag) (n int, err error) {\n",
name)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("var nn int\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateEncodeValue(typ, "(*this)", "tag")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
// DecodeValue method
nn, err = this.iprintf(
"\n // DecodeValue decodes the value of this type without " +
"the tag. The value is\n// decoded according to the " +
"parameters specified by the tag, if possible.\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"func (this *%s) DecodeValue(decoder *tape.Decoder, tag tape.Tag) (n int, err error) {\n",
name)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.generateDecodeValue(typ, "this", "tag")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
return n, nil
} }
func (this *Protocol) defineMessage(out io.Writer, message Message) error { // generateMessage generates the structure, as well as encoding decoding
fmt.Fprintln(out, comment("//", fmt.Sprintf("(%d) %s\n", message.Method, message.Doc))) // functions for the given message.
fmt.Fprintf(out, "type Message%s struct {\n", message.Name) func (this *Generator) generateMessage(method uint16, message Message) (n int, err error) {
for _, field := range message.Fields { nn, err := this.iprintf(
typ, err := this.ResolveType(field.Type) "\n// %s represents the protocol message M%04X %s.\n",
if err != nil { return err } message.Name, method, message.Name)
if field.Doc != "" { nn, err = this.iprintf("type %s ", this.resolveMessageName(message.Name))
fmt.Fprintf(out, "\t%s\n", comment("\t//", field.Doc)) n += nn; if err != nil { return n, err }
} nn, err = this.generateType(message.Type)
if field.Optional { n += nn; if err != nil { return n, err }
typ = fmt.Sprintf("hopp.Option[%s]", typ) nn, err = this.println()
} n += nn; if err != nil { return n, err }
fmt.Fprintf(
out, "\t/* %d */ %s %s\n",
field.Tag, field.Name, typ)
}
fmt.Fprintf(out, "}\n\n")
fmt.Fprintf(out, "// Method returns the method number of the message.\n") // Encode method
fmt.Fprintf(out, "func (msg Message%s) Method() uint16 {\n", message.Name) nn, err = this.iprintf("\n// Encode encodes this message's tag and value.\n")
fmt.Fprintf(out, "\treturn %d\n", message.Method) n += nn; if err != nil { return n, err }
fmt.Fprintf(out, "}\n\n") nn, err = this.iprintf(
return nil "func(this %s) Encode(encoder *tape.Encoder) (n int, err error) {\n",
this.resolveMessageName(message.Name))
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("tag := ")
n += nn; if err != nil { return n, err }
nn, err = this.generateTag(message.Type, "(*this)")
n += nn; if err != nil { return n, err }
nn, err = this.println()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err := encoder.WriteUint8()\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.generateEncodeValue(message.Type, "(*this)", "tag")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("return n, nil\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
// TODO decode method
return n, nil
} }
func (this *Protocol) marshalMessage(out io.Writer, message Message) error { // generateEncodeValue generates code to encode a value of a specified type. It
fmt.Fprintf(out, "// MarshalBinary encodes the data in this message into a buffer.\n") // pulls from the variable (or parenthetical statement) specified by
fmt.Fprintf(out, "func (msg *Message%s) MarshalBinary() ([]byte, error) {\n", message.Name) // valueSource, and the value will be encoded according to the tag stored in
requiredCount := 0 // the variable (or parenthetical statement) specified by tagSource.
for _, field := range message.Fields { // the code generated is a BLOCK and expects these variables to be defined:
if !field.Optional { requiredCount ++ } //
// - encoder *tape.Encoder
// - n int
// - err error
// - nn int
func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource string) (n int, err error) {
switch typ := typ.(type) {
case TypeInt:
// SI: (none)
// LI: <value: IntN>
if typ.Bits <= 5 {
// SI stores the value in the tag, so we write nothing here
break
} }
fmt.Fprintf(out, "\tsize := 0\n") nn, err := this.iprintf("nn, err = encoder.WriteInt%d(%s)\n", bitsToBytes(typ.Bits), valueSource)
fmt.Fprintf(out, "\tcount := %d\n", requiredCount) n += nn; if err != nil { return n, err }
for _, field := range message.Fields { nn, err = this.generateErrorCheck()
fmt.Fprintf(out, "\toffset%s := size\n", field.Name) n += nn; if err != nil { return n, err }
if field.Optional { case TypeFloat:
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) // FP: <value: FloatN>
fmt.Fprintf(out, "\t\tcount ++\n") nn, err := this.iprintf("nn, err = encoder.WriteFloat%d(%s)\n", bitsToBytes(typ.Bits), valueSource)
fmt.Fprintf(out, "\t\t") n += nn; if err != nil { return n, err }
err := this.marshalSizeOf(out, field) nn, err = this.generateErrorCheck()
if err != nil { return err } n += nn; if err != nil { return n, err }
fmt.Fprintf(out, " }\n") case TypeString:
// see TypeBuffer
nn, err := this.generateEncodeValue(TypeBuffer { }, valueSource, tagSource)
n += nn; if err != nil { return n, err }
case TypeBuffer:
// SBA: <data: U8>*
// LBA: <length: UN> <data: U8>*
nn, err := this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf(
"nn, err = encoder.WriteUintN(%s.CN(), uint64(len(%s)))\n",
tagSource, valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n", tagSource)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
case TypeArray:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
nn, err := this.iprintf(
"nn, err = encoder.WriteUintN(%s.CN(), uint64(len(%s)))\n",
tagSource, valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("{\n")
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("itemTag := ")
n += nn; if err != nil { return n, err }
nn, err = this.generateTN(typ.Element)
n += nn; if err != nil { return n, err }
nn, err = this.println()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("for _, item := range %s {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("tag := ")
n += nn; if err != nil { return n, err }
nn, err = this.generateTag(typ.Element, "item")
n += nn; if err != nil { return n, err }
nn, err = this.println()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("tag.Is(tape.SBA) { continue }\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("tag.CN() > itemTag.CN() { largest = tag }\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if itemTag.Is(tape.SBA) { itemTag += 1 << 5 }\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("for _, item := range %s {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.generateEncodeValue(typ.Element, "item", "itemTag")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
case TypeTable:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf(
"nn, err = tape.EncodeAny(encoder, %s, %s)\n",
valueSource, tagSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
case TypeTableDefined:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf(
"nn, err = encoder.WriteUintN(%s.CN(), %d)\n",
tagSource, len(typ.Fields))
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("{\n")
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("var tag tape.Tag\n")
n += nn; if err != nil { return n, err }
for key, field := range typ.Fields {
nn, err = this.iprintf("nn, err = encoder.WriteUint16(0x%04X)\n", key)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("tag = ")
n += nn; if err != nil { return n, err }
fieldSource := fmt.Sprintf("%s.%s", valueSource, field.Name)
nn, err = this.generateTag(field.Type, fieldSource)
n += nn; if err != nil { return n, err }
nn, err = this.println()
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = encoder.WriteUint8(uint8(tag))\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
nn, err = this.generateEncodeValue(field.Type, fieldSource, "tag")
n += nn; if err != nil { return n, err }
}
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
case TypeNamed:
// WHATEVER: [WHATEVER]
nn, err := this.iprintf("nn, err = %s.EncodeValue(encoder, %s)\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
}
return n, nil
}
// generateDencodeValue generates code to decode a value of a specified type. It
// overwrites memory pointed to by the variable (or parenthetical statement)
// specified by valueSource, and the value will be encoded according to the tag
// stored in the variable (or parenthetical statement) specified by tagSource.
// the code generated is a BLOCK and expects these variables to be defined:
//
// - decoder *tape.Decoder
// - n int
// - err error
// - nn int
func (this *Generator) generateDecodeValue(typ Type, valueSource, tagSource string) (n int, err error) {
// TODO
}
func (this *Generator) generateErrorCheck() (n int, err error) {
return this.iprintf("n += nn; if err != nil { return n, err }\n")
}
// generateTag generates the preferred TN and CN for the given type and value.
// The generated code is INLINE.
func (this *Generator) generateTag(typ Type, source string) (n int, err error) {
switch typ := typ.(type) {
case TypeInt:
if typ.Bits <= 5 {
nn, err := this.printf("tape.TagSI")
n += nn; if err != nil { return n, err }
} else { } else {
fmt.Fprintf(out, "\t{") nn, err := this.printf("tape.TagLI.WithCN(%d)", bitsToCN(typ.Bits))
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) n += nn; if err != nil { return n, err }
fmt.Fprintf(out, "\t\t")
err := this.marshalSizeOf(out, field)
if err != nil { return err }
fmt.Fprintf(out, " }\n")
} }
case TypeFloat:
nn, err := this.printf("tape.TagFP.WithCN(%d)", bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err }
case TypeString:
nn, err := this.generateTag(TypeBuffer { }, source)
n += nn; if err != nil { return n, err }
case TypeBuffer:
nn, err := this.printf("bufferTag(%s)", source)
n += nn; if err != nil { return n, err }
case TypeArray:
nn, err := this.printf("arrayTag(tape.TagOTA.WithCN(tape.IntBytes(uint64(len(%s))))", source)
n += nn; if err != nil { return n, err }
case TypeTable:
nn, err := this.printf("tape.TagKTV.WithCN(tape.IntBytes(uint64(len(%s))))", source)
n += nn; if err != nil { return n, err }
case TypeTableDefined:
nn, err := this.printf("tape.TagKTV.WithCN(%d)", tape.IntBytes(uint64(len(typ.Fields))))
n += nn; if err != nil { return n, err }
case TypeNamed:
resolved, err := this.resolveTypeName(typ.Name)
if err != nil { return n, err }
nn, err := this.generateTag(resolved, source)
n += nn; if err != nil { return n, err }
} }
fmt.Fprintf(out, "\tif size > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n")
fmt.Fprintf(out, "\tif count > 0xFFFF { return nil, hopp.ErrPayloadTooLarge}\n") return n, nil
fmt.Fprintf(out, "\tbuffer := make([]byte, 2 + 4 * count + size)\n") }
fmt.Fprintf(out, "\ttape.EncodeI16(buffer[:2], uint16(count))\n")
for _, field := range message.Fields { // generateTN generates the appropriate TN for the given type. The generated
if field.Optional { // code is INLINE. The generated tag will have a CN as zero. For types that
fmt.Fprintf(out, "\tif value, ok := msg.%s.Get(); ok {\n", field.Name) // change TN based on their length, the TN capable of supporting more
fmt.Fprintf(out, "\t\t") // information is chosen.
err := this.marshalField(out, field) func (this *Generator) generateTN(typ Type) (n int, err error) {
if err != nil { return err } switch typ := typ.(type) {
fmt.Fprintf(out, "}\n") case TypeInt:
if typ.Bits <= 5 {
nn, err := this.printf("tape.TagSI")
n += nn; if err != nil { return n, err }
} else { } else {
fmt.Fprintf(out, "\t{") nn, err := this.printf("tape.TagLI")
fmt.Fprintf(out, "\tvalue := msg.%s\n", field.Name) n += nn; if err != nil { return n, err }
fmt.Fprintf(out, "\t\t")
err := this.marshalField(out, field)
if err != nil { return err }
fmt.Fprintf(out, "}\n")
} }
case TypeFloat:
nn, err := this.printf("tape.TagFP",)
n += nn; if err != nil { return n, err }
case TypeString:
nn, err := this.generateTN(TypeBuffer { })
n += nn; if err != nil { return n, err }
case TypeBuffer:
nn, err := this.printf("tape.TagLBA")
n += nn; if err != nil { return n, err }
case TypeArray:
nn, err := this.printf("tape.TagOTA")
n += nn; if err != nil { return n, err }
case TypeTable:
nn, err := this.printf("tape.TagKTV")
n += nn; if err != nil { return n, err }
case TypeTableDefined:
nn, err := this.printf("tape.TagKTV")
n += nn; if err != nil { return n, err }
case TypeNamed:
resolved, err := this.resolveTypeName(typ.Name)
if err != nil { return n, err }
nn, err := this.generateTN(resolved)
n += nn; if err != nil { return n, err }
} }
fmt.Fprintf(out, "\treturn buffer, nil\n")
fmt.Fprintf(out, "}\n\n") return n, nil
return nil
} }
func (this *Protocol) marshalSizeOf(out io.Writer, field Field) error { func (this *Generator) generateType(typ Type) (n int, err error) {
switch field.Type { switch typ := typ.(type) {
case "I8": fmt.Fprintf(out, "size += 1; _ = value") case TypeInt:
case "I16": fmt.Fprintf(out, "size += 2; _ = value") if err := this.validateIntBitSize(typ.Bits); err != nil {
case "I32": fmt.Fprintf(out, "size += 4; _ = value") return n, err
case "I64": fmt.Fprintf(out, "size += 8; _ = value")
case "U8": fmt.Fprintf(out, "size += 1; _ = value")
case "U16": fmt.Fprintf(out, "size += 2; _ = value")
case "U32": fmt.Fprintf(out, "size += 4; _ = value")
case "U64": fmt.Fprintf(out, "size += 8; _ = value")
case "I8Array": fmt.Fprintf(out, "size += len(value)")
case "I16Array": fmt.Fprintf(out, "size += len(value) * 2")
case "I32Array": fmt.Fprintf(out, "size += len(value) * 4")
case "I64Array": fmt.Fprintf(out, "size += len(value) * 8")
case "U8Array": fmt.Fprintf(out, "size += len(value)")
case "U16Array": fmt.Fprintf(out, "size += len(value) * 2")
case "U32Array": fmt.Fprintf(out, "size += len(value) * 4")
case "U64Array": fmt.Fprintf(out, "size += len(value) * 8")
case "String": fmt.Fprintf(out, "size += len(value)")
case "StringArray":
fmt.Fprintf(
out,
"for _, el := range value { size += 2 + len(el) }")
default:
return fmt.Errorf("unknown type: %s", field.Type)
} }
return nil if typ.Signed {
} nn, err := this.printf("int%d", typ.Bits)
n += nn; if err != nil { return n, err }
func (this *Protocol) marshalField(out io.Writer, field Field) error {
switch field.Type {
case "I8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name)
case "I16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name)
case "I32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name)
case "I64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name)
case "U8": fmt.Fprintf(out, "tape.EncodeI8(buffer[offset%s:], value)", field.Name)
case "U16": fmt.Fprintf(out, "tape.EncodeI16(buffer[offset%s:], value)", field.Name)
case "U32": fmt.Fprintf(out, "tape.EncodeI32(buffer[offset%s:], value)", field.Name)
case "U64": fmt.Fprintf(out, "tape.EncodeI64(buffer[offset%s:], value)", field.Name)
case "I8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name)
case "I16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name)
case "I32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name)
case "I64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name)
case "U8Array": fmt.Fprintf(out, "tape.EncodeI8Array(buffer[offset%s:], value)", field.Name)
case "U16Array": fmt.Fprintf(out, "tape.EncodeI16Array(buffer[offset%s:], value)", field.Name)
case "U32Array": fmt.Fprintf(out, "tape.EncodeI32Array(buffer[offset%s:], value)", field.Name)
case "U64Array": fmt.Fprintf(out, "tape.EncodeI64Array(buffer[offset%s:], value)", field.Name)
case "String": fmt.Fprintf(out, "tape.EncodeString(buffer[offset%s:], value)", field.Name)
case "StringArray": fmt.Fprintf(out, "tape.EncodeStringArray(buffer[offset%s:], value)", field.Name)
default:
return fmt.Errorf("unknown type: %s", field.Type)
}
return nil
}
func (this *Protocol) unmarshalMessage(out io.Writer, message Message) error {
fmt.Fprintf(out, "// UnmarshalBinary dencodes the data from a buffer int this message.\n")
fmt.Fprintf(out,
"func (msg *Message%s) UnmarshalBinary(buffer []byte) error {\n",
message.Name)
if len(message.Fields) < 1 {
fmt.Fprintf(out, "\t// no fields\n")
fmt.Fprintf(out, "\treturn nil\n")
fmt.Fprintf(out, "}\n\n")
return nil
}
fmt.Fprintf(out, "\tpairs, err := tape.DecodePairs(buffer)\n")
fmt.Fprintf(out, "\tif err != nil { return err }\n")
requiredTotal := 0
for _, field := range message.Fields {
if field.Optional {
requiredTotal ++
}
}
if requiredTotal > 0 {
fmt.Fprintf(out, "\tfoundRequired := 0\n")
}
fmt.Fprintf(out, "\tfor tag, data := range pairs {\n")
fmt.Fprintf(out, "\t\tswitch tag {\n")
for _, field := range message.Fields {
fmt.Fprintf(out, "\t\tcase %d:\n", field.Tag)
fmt.Fprintf(out, "\t\t\t")
err := this.unmarshalField(out, field)
if err != nil { return err }
fmt.Fprintf(out, "\n")
fmt.Fprintf(out, "\t\t\tif err != nil { return err }\n")
if field.Optional {
fmt.Fprintf(out, "\t\t\tmsg.%s = hopp.O(value)\n", field.Name)
} else { } else {
fmt.Fprintf(out, "\t\t\tmsg.%s = value\n", field.Name) nn, err := this.printf("uint%d", typ.Bits)
if requiredTotal > 0 { n += nn; if err != nil { return n, err }
fmt.Fprintf(out, "\t\t\tfoundRequired ++\n")
} }
} case TypeFloat:
} switch typ.Bits {
fmt.Fprintf(out, "\t\t}\n") case 16:
fmt.Fprintf(out, "\t}\n") nn, err := this.print("float32")
if requiredTotal > 0 { n += nn; if err != nil { return n, err }
fmt.Fprintf(out, case 32, 64:
"\tif foundRequired != %d { return hopp.ErrTablePairMissing }\n", nn, err := this.printf("float%d", typ.Bits)
requiredTotal) n += nn; if err != nil { return n, err }
}
fmt.Fprintf(out, "\treturn nil\n")
fmt.Fprintf(out, "}\n\n")
return nil
}
func (this *Protocol) unmarshalField(out io.Writer, field Field) error {
typ, err := this.ResolveType(field.Type)
if err != nil { return err }
switch field.Type {
case "I8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
case "I16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
case "I32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
case "I64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
case "U8": fmt.Fprintf(out, "value, err := tape.DecodeI8[%s](data)", typ)
case "U16": fmt.Fprintf(out, "value, err := tape.DecodeI16[%s](data)", typ)
case "U32": fmt.Fprintf(out, "value, err := tape.DecodeI32[%s](data)", typ)
case "U64": fmt.Fprintf(out, "value, err := tape.DecodeI64[%s](data)", typ)
case "I8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
case "I16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
case "I32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
case "I64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
case "U8Array": fmt.Fprintf(out, "value, err := tape.DecodeI8Array[%s](data)", typ)
case "U16Array": fmt.Fprintf(out, "value, err := tape.DecodeI16Array[%s](data)", typ)
case "U32Array": fmt.Fprintf(out, "value, err := tape.DecodeI32Array[%s](data)", typ)
case "U64Array": fmt.Fprintf(out, "value, err := tape.DecodeI64Array[%s](data)", typ)
case "String": fmt.Fprintf(out, "value, err := tape.DecodeString[%s](data)", typ)
case "StringArray": fmt.Fprintf(out, "value, err := tape.DecodeStringArray[%s](data)", typ)
default: default:
return fmt.Errorf("unknown type: %s", field.Type) return n, fmt.Errorf("floats of size %d are unsupported on this platform", typ.Bits)
} }
return nil case TypeString:
nn, err := this.print("string")
n += nn; if err != nil { return n, err }
case TypeBuffer:
nn, err := this.print("[]byte")
n += nn; if err != nil { return n, err }
case TypeArray:
nn, err := this.print("[]")
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ.Element)
n += nn; if err != nil { return n, err }
case TypeTable:
nn, err := this.print("Table")
n += nn; if err != nil { return n, err }
case TypeTableDefined:
nn, err := this.generateTypeTableDefined(typ)
n += nn; if err != nil { return n, err }
case TypeNamed:
actual, err := this.resolveTypeName(typ.Name)
if err != nil { return n, err }
nn, err := this.generateType(actual)
n += nn; if err != nil { return n, err }
}
return n, nil
} }
func comment(prefix, text string) string { func (this *Generator) generateTypeTableDefined(typ TypeTableDefined) (n int, err error) {
return prefix + " " + strings.ReplaceAll(strings.TrimSpace(text), "\n", "\n" + prefix + " ") nn, err := this.print("struct {\n")
n += nn; if err != nil { return n, err }
this.push()
for _, key := range slices.Sorted(maps.Keys(typ.Fields)) {
field := typ.Fields[key]
nn, err := this.iprintf("%s ", field.Name)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(field.Type)
n += nn; if err != nil { return n, err }
nn, err = this.print("\n")
n += nn; if err != nil { return n, err }
}
this.pop()
nn, err = this.iprint("}")
n += nn; if err != nil { return n, err }
return n, nil
}
func (this *Generator) validateIntBitSize(size int) error {
switch size {
case 8, 16, 32, 64: return nil
default: return fmt.Errorf("integers of size %d are unsupported on this platform", size)
}
}
func (this *Generator) validateFloatBitSize(size int) error {
switch size {
case 16, 32, 64: return nil
default: return fmt.Errorf("floats of size %d are unsupported on this platform", size)
}
}
func (this *Generator) push() {
this.nestingLevel ++
}
func (this *Generator) pop() {
if this.nestingLevel < 1 {
panic("cannot pop when nesting level is less than 1")
}
this.nestingLevel --
}
func (this *Generator) indent() string {
return strings.Repeat("\t", this.nestingLevel)
}
func (this *Generator) print(args ...any) (n int, err error) {
return fmt.Fprint(this.Output, args...)
}
func (this *Generator) println(args ...any) (n int, err error) {
return fmt.Fprintln(this.Output, args...)
}
func (this *Generator) printf(format string, args ...any) (n int, err error) {
return fmt.Fprintf(this.Output, format, args...)
}
func (this *Generator) iprint(args ...any) (n int, err error) {
return fmt.Fprint(this.Output, this.indent() + fmt.Sprint(args...))
}
func (this *Generator) iprintln(args ...any) (n int, err error) {
return fmt.Fprintln(this.Output, this.indent() + fmt.Sprint(args...))
}
func (this *Generator) iprintf(format string, args ...any) (n int, err error) {
return fmt.Fprintf(this.Output, this.indent() + format, args...)
}
func (this *Generator) resolveMessageName(message string) string {
return "Message" + message
}
func (this *Generator) resolveTypeName(name string) (Type, error) {
switch name {
case "U8": return TypeInt { Bits: 8 }, nil
case "U16": return TypeInt { Bits: 16 }, nil
case "U32": return TypeInt { Bits: 32 }, nil
case "U64": return TypeInt { Bits: 64 }, nil
case "U128": return TypeInt { Bits: 128 }, nil
case "U256": return TypeInt { Bits: 256 }, nil
case "I8": return TypeInt { Bits: 8, Signed: true }, nil
case "I16": return TypeInt { Bits: 16, Signed: true }, nil
case "I32": return TypeInt { Bits: 32, Signed: true }, nil
case "I64": return TypeInt { Bits: 64, Signed: true }, nil
case "I128": return TypeInt { Bits: 128, Signed: true }, nil
case "I256": return TypeInt { Bits: 256, Signed: true }, nil
case "F16": return TypeFloat { Bits: 16 }, nil
case "F32": return TypeFloat { Bits: 32 }, nil
case "F64": return TypeFloat { Bits: 64 }, nil
case "F128": return TypeFloat { Bits: 128 }, nil
case "F256": return TypeFloat { Bits: 256 }, nil
case "String": return TypeString { }, nil
case "Buffer": return TypeBuffer { }, nil
case "Table": return TypeTable { }, nil
}
if typ, ok := this.protocol.Types[name]; ok {
if typ, ok := typ.(TypeNamed); ok {
return this.resolveTypeName(typ.Name)
}
return typ, nil
}
return nil, fmt.Errorf("no type exists called %s", name)
}
func bitsToBytes(bits int) int {
return int(math.Ceil(float64(bits) / 8.0))
}
func bitsToCN(bits int) int {
return bitsToBytes(bits) - 1
} }

230
generate/lex.go Normal file
View File

@ -0,0 +1,230 @@
package generate
import "io"
import "bufio"
import "unicode"
import "unicode/utf8"
import "git.tebibyte.media/sashakoshka/goparse"
const (
TokenMethod parse.TokenKind = iota
TokenKey
TokenIdent
TokenComma
TokenLBrace
TokenRBrace
TokenLBracket
TokenRBracket
)
var tokenNames = map[parse.TokenKind] string {
TokenMethod: "Method",
TokenKey: "Key",
TokenIdent: "Ident",
TokenComma: "Comma",
TokenLBrace: "LBrace",
TokenRBrace: "RBrace",
TokenLBracket: "LBracket",
TokenRBracket: "RBracket",
}
func Lex(fileName string, reader io.Reader) (parse.Lexer, error) {
lex := &lexer {
fileName: fileName,
lineScanner: bufio.NewScanner(reader),
}
lex.nextRune()
return lex, nil
}
type lexer struct {
fileName string
lineScanner *bufio.Scanner
rune rune
line string
lineFood string
offset int
row int
column int
eof bool
}
func (this *lexer) Next() (parse.Token, error) {
token, err := this.nextInternal()
if err == io.EOF { err = this.errUnexpectedEOF() }
return token, err
}
func (this *lexer) nextInternal() (token parse.Token, err error) {
err = this.skipWhitespace()
token.Position = this.pos()
if this.eof {
token.Kind = parse.EOF
err = nil
return
}
if err != nil { return }
appendRune := func () {
token.Value += string(this.rune)
err = this.nextRune()
}
doNumber := func () {
for isDigit(this.rune) {
appendRune()
if this.eof { err = nil; return }
if err != nil { return }
}
}
defer func () {
newPos := this.pos()
newPos.End -- // TODO figure out why tf we have to do this
token.Position = token.Position.Union(newPos)
} ()
switch {
// Method
case this.rune == 'M':
token.Kind = TokenMethod
err = this.nextRune()
if err != nil { return }
doNumber()
if this.eof { err = nil; return }
// Key
case isDigit(this.rune):
token.Kind = TokenKey
doNumber()
if this.eof { err = nil; return }
// Ident
case unicode.IsUpper(this.rune):
token.Kind = TokenIdent
for unicode.IsLetter(this.rune) || isDigit(this.rune) {
appendRune()
if this.eof { err = nil; return }
if err != nil { return }
}
// Comma
case this.rune == ',':
token.Kind = TokenComma
appendRune()
if this.eof { err = nil; return }
// LBrace
case this.rune == '{':
token.Kind = TokenLBrace
appendRune()
if this.eof { err = nil; return }
// RBrace
case this.rune == '}':
token.Kind = TokenRBrace
appendRune()
if this.eof { err = nil; return }
// LBracket
case this.rune == '[':
token.Kind = TokenLBracket
appendRune()
if this.eof { err = nil; return }
// RBracket
case this.rune == ']':
token.Kind = TokenRBracket
appendRune()
if this.eof { err = nil; return }
case unicode.IsPrint(this.rune):
err = parse.Errorf (
this.pos(), "unexpected rune '%c'",
this.rune)
default:
err = parse.Errorf (
this.pos(), "unexpected rune %U",
this.rune)
}
return
}
func (this *lexer) nextRune() error {
if this.lineFood == "" {
ok := this.lineScanner.Scan()
if ok {
this.line = this.lineScanner.Text()
this.lineFood = this.line
this.rune = '\n'
this.column = 0
this.row ++
} else {
err := this.lineScanner.Err()
if err == nil {
this.eof = true
return io.EOF
} else {
return err
}
}
} else {
var ch rune
var size int
for ch == 0 && this.lineFood != "" {
ch, size = utf8.DecodeRuneInString(this.lineFood)
this.lineFood = this.lineFood[size:]
}
this.rune = ch
this.column ++
}
return nil
}
func (this *lexer) skipWhitespace() error {
err := this.skipComment()
if err != nil { return err }
for isWhitespace(this.rune) {
err := this.nextRune()
if err != nil { return err }
err = this.skipComment()
if err != nil { return err }
}
return nil
}
func (this *lexer) skipComment() error {
if this.rune == ';' {
for this.rune != '\n' {
err := this.nextRune()
if err != nil { return err }
}
}
return nil
}
func (this *lexer) pos() parse.Position {
return parse.Position {
File: this.fileName,
Line: this.lineScanner.Text(),
Row: this.row - 1,
Start: this.column - 1,
End: this.column,
}
}
func (this *lexer) errUnexpectedEOF() error {
return parse.Errorf(this.pos(), "unexpected EOF")
}
func isWhitespace(char rune) bool {
switch char {
case ' ', '\t', '\r', '\n': return true
default: return false
}
}
func isDigit(char rune) bool {
return char >= '0' && char <= '9'
}
func isHexDigit(char rune) bool {
return isDigit(char) || char >= 'a' && char <= 'f' || char >= 'A' && char <= 'F'
}

54
generate/lex_test.go Normal file
View File

@ -0,0 +1,54 @@
package generate
import "strings"
import "testing"
import "git.tebibyte.media/sashakoshka/goparse"
func TestLex(test *testing.T) {
lexer, err := Lex("test.pdl", strings.NewReader(`
M0001 User {
0000 Name String,
0001 Users []User,
0002 Followers U32,
}`))
if err != nil { test.Fatal(parse.Format(err)) }
correctTokens := []parse.Token {
tok(TokenMethod, "0001"),
tok(TokenIdent, "User"),
tok(TokenLBrace, "{"),
tok(TokenKey, "0000"),
tok(TokenIdent, "Name"),
tok(TokenIdent, "String"),
tok(TokenComma, ","),
tok(TokenKey, "0001"),
tok(TokenIdent, "Users"),
tok(TokenLBracket, "["),
tok(TokenRBracket, "]"),
tok(TokenIdent, "User"),
tok(TokenComma, ","),
tok(TokenKey, "0002"),
tok(TokenIdent, "Followers"),
tok(TokenIdent, "U32"),
tok(TokenComma, ","),
tok(TokenRBrace, "}"),
tok(parse.EOF, ""),
}
for index, correct := range correctTokens {
got, err := lexer.Next()
if err != nil { test.Fatal(parse.Format(err)) }
if got.Kind != correct.Kind || got.Value != correct.Value {
test.Logf("token %d mismatch", index)
test.Log("GOT:", tokenNames[got.Kind], got.Value)
test.Fatal("CORRECT:", tokenNames[correct.Kind], correct.Value)
}
}
}
func tok(kind parse.TokenKind, value string) parse.Token {
return parse.Token {
Kind: kind,
Value: value,
}
}

185
generate/parse.go Normal file
View File

@ -0,0 +1,185 @@
package generate
import "io"
import "strconv"
import "git.tebibyte.media/sashakoshka/goparse"
func Parse(lx parse.Lexer) (*Protocol, error) {
protocol := defaultProtocol()
par := parser {
Parser: parse.Parser {
Lexer: lx,
TokenNames: tokenNames,
},
protocol: &protocol,
}
err := par.parse()
if err != nil { return nil, err }
return par.protocol, nil
}
func defaultProtocol() Protocol {
return Protocol {
Messages: make(map[uint16] Message),
Types: map[string] Type { },
}
}
func ParseReader(reader io.Reader) (*Protocol, error) {
lx, err := Lex("test.pdl", reader)
if err != nil { return nil, err }
return Parse(lx)
}
type parser struct {
parse.Parser
protocol *Protocol
}
func (this *parser) parse() error {
err := this.Next()
if err != nil { return err }
for this.Token.Kind != parse.EOF {
err = this.parseTopLevel()
if err != nil { return err }
}
return nil
}
func (this *parser) parseTopLevel() error {
err := this.ExpectDesc("message or typedef", TokenMethod, TokenIdent)
if err != nil { return err }
if this.EOF() { return nil }
switch this.Kind() {
case TokenMethod: return this.parseMessage()
case TokenIdent: return this.parseTypedef()
}
panic("bug")
}
func (this *parser) parseMessage() error {
err := this.Expect(TokenMethod)
if err != nil { return err }
method, err := this.parseHexNumber(this.Value(), 0xFFFF)
if err != nil { return err }
err = this.ExpectNext(TokenIdent)
if err != nil { return err }
name := this.Value()
err = this.Next()
if err != nil { return err }
typ, err := this.parseType()
if err != nil { return err }
this.protocol.Messages[uint16(method)] = Message {
Name: name,
Type: typ,
}
return nil
}
func (this *parser) parseTypedef() error {
err := this.Expect(TokenIdent)
if err != nil { return err }
name := this.Value()
err = this.Next()
if err != nil { return err }
typ, err := this.parseType()
if err != nil { return err }
this.protocol.Types[name] = typ
return nil
}
func (this *parser) parseType() (Type, error) {
err := this.ExpectDesc("type", TokenIdent, TokenLBracket, TokenLBrace)
if err != nil { return nil, err }
switch this.Kind() {
case TokenIdent:
return this.parseTypeNamed()
case TokenLBracket:
return this.parseTypeArray()
case TokenLBrace:
return this.parseTypeTable()
}
panic("bug")
}
func (this *parser) parseTypeNamed() (TypeNamed, error) {
err := this.Expect(TokenIdent)
if err != nil { return TypeNamed { }, err }
name := this.Value()
err = this.Next()
if err != nil { return TypeNamed { }, err }
return TypeNamed { Name: name }, nil
}
func (this *parser) parseTypeArray() (TypeArray, error) {
err := this.Expect(TokenLBracket)
if err != nil { return TypeArray { }, err }
err = this.ExpectNext(TokenRBracket)
if err != nil { return TypeArray { }, err }
err = this.Next()
if err != nil { return TypeArray { }, err }
typ, err := this.parseType()
if err != nil { return TypeArray { }, err }
return TypeArray { Element: typ }, nil
}
func (this *parser) parseTypeTable() (TypeTableDefined, error) {
err := this.Expect(TokenLBrace)
if err != nil { return TypeTableDefined { }, err }
err = this.Next()
if err != nil { return TypeTableDefined { }, err }
typ := TypeTableDefined {
Fields: make(map[uint16] Field),
}
for {
err := this.ExpectDesc("table field", TokenKey, TokenRBrace)
if err != nil { return TypeTableDefined { }, err }
if this.Is(TokenRBrace) {
break
}
key, field, err := this.parseField()
if err != nil { return TypeTableDefined { }, err }
typ.Fields[key] = field
err = this.Expect(TokenComma, TokenRBrace)
if err != nil { return TypeTableDefined { }, err }
if this.Is(TokenRBrace) {
break
}
err = this.Next()
if err != nil { return TypeTableDefined { }, err }
}
err = this.Next()
if err != nil { return TypeTableDefined { }, err }
return typ, nil
}
func (this *parser) parseField() (uint16, Field, error) {
err := this.Expect(TokenKey)
if err != nil { return 0, Field { }, err }
key, err := this.parseHexNumber(this.Value(), 0xFFFF)
if err != nil { return 0, Field { }, err }
err = this.ExpectNext(TokenIdent)
if err != nil { return 0, Field { }, err }
name := this.Value()
err = this.Next()
if err != nil { return 0, Field { }, err }
typ, err := this.parseType()
if err != nil { return 0, Field { }, err }
return uint16(key), Field {
Name: name,
Type: typ,
}, nil
}
func (this *parser) parseHexNumber(input string, maxValue int64) (int64, error) {
number, err := strconv.ParseInt(input, 16, 64)
if err != nil {
return 0, parse.Errorf(this.Pos(), "%v", err)
}
if maxValue > 0 && number > maxValue {
return 0, parse.Errorf(this.Pos(), "value too large (max %X)", maxValue)
}
return number, nil
}

68
generate/parse_test.go Normal file
View File

@ -0,0 +1,68 @@
package generate
import "fmt"
import "strings"
import "testing"
import "git.tebibyte.media/sashakoshka/goparse"
func TestParse(test *testing.T) {
correct := defaultProtocol()
correct.Messages[0x0000] = Message {
Name: "Connect",
Type: TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeNamed { Name: "String" } },
0x0001: Field { Name: "Password", Type: TypeNamed { Name: "String" } },
},
},
}
correct.Messages[0x0001] = Message {
Name: "UserList",
Type: TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Users", Type: TypeArray { Element: TypeNamed { Name: "User" } } },
},
},
}
correct.Types["User"] = TypeTableDefined {
Fields: map[uint16] Field {
0x0000: Field { Name: "Name", Type: TypeNamed { Name: "String" } },
0x0001: Field { Name: "Bio", Type: TypeNamed { Name: "String" } },
0x0002: Field { Name: "Followers", Type: TypeNamed { Name: "U32" } },
},
}
test.Log("CORRECT:", &correct)
got, err := ParseReader(strings.NewReader(`
M0000 Connect {
0000 Name String,
0001 Password String,
}
M0001 UserList {
0000 Users []User,
}
User {
0000 Name String,
0001 Bio String,
0002 Followers U32,
}
`))
if err != nil { test.Fatal(parse.Format(err)) }
test.Log("GOT: ", got)
correctStr := fmt.Sprint(&correct)
gotStr := fmt.Sprint(got)
if correctStr != gotStr {
test.Error("not equal")
for index := range min(len(correctStr), len(gotStr)) {
if correctStr[index] == gotStr[index] { continue }
test.Log("C:", correctStr[max(0, index - 8):min(len(correctStr), index + 8)])
test.Log("G:", gotStr[max(0, index - 8):min(len(gotStr), index + 8)])
break
}
test.FailNow()
}
}

View File

@ -1,244 +1,47 @@
package generate package generate
import "io"
import "fmt"
import "errors"
import "strconv"
import "strings"
import "github.com/gomarkdown/markdown"
import "github.com/gomarkdown/markdown/ast"
import "github.com/gomarkdown/markdown/parser"
// Protocol describes a protocol.
type Protocol struct { type Protocol struct {
Messages []Message Messages map[uint16] Message
Types map[string] Type
} }
// Message describes a protocol message.
type Message struct { type Message struct {
Doc string
Method uint16
Name string Name string
Fields []Field Type Type
}
type Type interface {
}
type TypeInt struct {
Bits int
Signed bool
}
type TypeFloat struct {
Bits int
}
type TypeString struct { }
type TypeBuffer struct { }
type TypeArray struct {
Element Type
}
type TypeTable struct { }
type TypeTableDefined struct {
Fields map[uint16] Field
} }
// Field describes a named value within a message.
type Field struct { type Field struct {
Doc string
Tag uint16
Name string Name string
Optional bool Type Type
Type string
} }
// ParseReader parses a protocol definition from a reader. type TypeNamed struct {
func ParseReader(reader io.Reader) (*Protocol, error) { Name string
data, err := io.ReadAll(reader)
if err != nil { return nil, err }
protocol := new(Protocol)
err = protocol.UnmarshalText(data)
if err != nil { return nil, err }
return protocol, nil
}
// UnmarshalText unmarshals markdown-formatted text data into the protocol.
func (this *Protocol) UnmarshalText(text []byte) error {
var state int; const (
stateIdle = iota
stateMessage
stateMessageDoc
stateMessageField
)
var message *Message
addMessage := func(method uint16, name string) {
this.Messages = append(this.Messages, Message {
Method: method,
Name: name,
})
message = &this.Messages[len(this.Messages) - 1]
}
root := markdown.Parse(text, parser.New())
for _, node := range root.GetChildren() {
if node, ok := node.(*ast.Heading); ok {
if node.Level == 2 {
if removeBreaks(flatten(node)) == "Messages" {
state = stateMessage
continue
}
}
if node.Level > 3 {
state = stateIdle
continue
}
if state != stateIdle && node.Level == 3 {
heading := removeBreaks(flatten(node))
method, name, err := splitMessageHeading(heading)
if err != nil { return err }
addMessage(method, name)
state = stateMessageDoc
}
}
if state == stateIdle { continue }
if message == nil { continue }
// TODO when we are adding text content to the doc comment, it
// might be wise to do stuff like indent lists and quotes so
// that go doc renders them correctly
switch node := node.(type) {
case *ast.Paragraph:
if message.Doc != "" { message.Doc += "\n\n" }
message.Doc += removeBreaks(flatten(node))
case *ast.BlockQuote:
if message.Doc != "" { message.Doc += "\n\n> " }
message.Doc += removeBreaks(flatten(node))
case *ast.List:
// FIXME format the list
if message.Doc != "" { message.Doc += "\n\n" }
message.Doc += removeBreaks(flatten(node))
case *ast.Table:
fields, err := processFieldTable(node)
if err != nil { return err}
message.Fields = append(message.Fields, fields...)
}
}
return nil
}
func processFieldTable(node *ast.Table) ([]Field, error) {
fields := []Field { }
children := node.GetChildren()
if len(children) != 2 {
return nil, errors.New("malformed field table")
}
// get columns
columns := []string { }
if header, ok := children[0].(*ast.TableHeader); ok {
children := header.GetChildren()
if len(children) != 1 {
return nil, errors.New("malformed field table header")
}
if row, ok := header.Children[0].(*ast.TableRow); ok {
for _, cell := range row.GetChildren() {
if cell, ok := cell.(*ast.TableCell); ok {
columns = append(columns, flatten(cell))
}
}
} else {
return nil, errors.New("malformed field table header")
}
for index, column := range columns {
columns[index] = strings.ToLower(column)
}
} else {
return nil, errors.New("malformed field table: no header")
}
// get data
if body, ok := children[1].(*ast.TableBody); ok {
for _, node := range body.GetChildren() {
if row, ok := node.(*ast.TableRow); ok {
children := row.GetChildren()
if len(children) != len(columns) {
return nil, errors.New (
"malformed field table row: wrong " +
"number of columns")
}
field := Field { }
for index, node := range children {
if cell, ok := node.(*ast.TableCell); ok {
text := flatten(cell)
switch columns[index] {
case "tag":
tag, err := parseTag(text)
if err != nil { return nil, err }
field.Tag = tag
case "name":
field.Name = text
case "required":
field.Optional = !parseBool(text)
case "optional":
field.Optional = parseBool(text)
case "type":
field.Type = text
case "comment", "purpose", "documentation":
field.Doc = text
}
}}
fields = append(fields, field)
}}
} else {
return nil, errors.New("malformed field table: no body")
}
return fields, nil
}
type nodeFlattener struct {
text string
}
func (this *nodeFlattener) String() string { return this.text }
func (this *nodeFlattener) Visit(node ast.Node, entering bool) ast.WalkStatus {
if entering {
if node := node.AsLeaf(); node != nil {
this.text += string(node.Literal)
}
}
return ast.GoToNext
}
func flatten(node ast.Node) string {
flattener := new(nodeFlattener)
ast.Walk(node, flattener)
return flattener.text
}
func removeBreaks(text string) string {
text = strings.ReplaceAll(text, "\n", " ")
text = strings.ReplaceAll(text, "\r", "")
return text
}
func parseBool(text string) bool {
switch(strings.ToLower(text)) {
case "yes": return true
case "no": return false
case "true": return true
case "false": return false
}
return false
}
func parseTag(text string) (uint16, error) {
tag, err := strconv.ParseUint(text, 10, 16)
if err != nil {
return 0, fmt.Errorf("malformed tag '%s': %w", text, err)
}
return uint16(tag), nil
}
func splitMessageHeading(text string) (uint16, string, error) {
text = strings.TrimSpace(text)
methodText, name, ok := strings.Cut(text, " ")
if !ok {
return 0, "", fmt.Errorf(
"malformed message heading '%s': no message name",
text)
}
method, err := strconv.ParseUint(methodText, 16, 16)
if err != nil {
return 0, "", fmt.Errorf(
"malformed method number '%s': %w",
methodText, err)
}
name = strings.TrimSpace(name)
return uint16(method), name, nil
} }

16
go.mod
View File

@ -4,19 +4,5 @@ go 1.23.0
require ( require (
git.tebibyte.media/sashakoshka/go-util v0.9.1 git.tebibyte.media/sashakoshka/go-util v0.9.1
github.com/gomarkdown/markdown v0.0.0-20241205020045-f7e15b2f3e62 git.tebibyte.media/sashakoshka/goparse v0.2.0
github.com/quic-go/quic-go v0.48.2
)
require (
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
go.uber.org/mock v0.4.0 // indirect
golang.org/x/crypto v0.26.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.28.0 // indirect
golang.org/x/sys v0.23.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
) )

60
go.sum
View File

@ -1,60 +1,4 @@
git.tebibyte.media/sashakoshka/go-util v0.9.1 h1:eGAbLwYhOlh4aq/0w+YnJcxT83yPhXtxnYMzz6K7xGo= git.tebibyte.media/sashakoshka/go-util v0.9.1 h1:eGAbLwYhOlh4aq/0w+YnJcxT83yPhXtxnYMzz6K7xGo=
git.tebibyte.media/sashakoshka/go-util v0.9.1/go.mod h1:0Q1t+PePdx6tFYkRuJNcpM1Mru7wE6X+it1kwuOH+6Y= git.tebibyte.media/sashakoshka/go-util v0.9.1/go.mod h1:0Q1t+PePdx6tFYkRuJNcpM1Mru7wE6X+it1kwuOH+6Y=
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= git.tebibyte.media/sashakoshka/goparse v0.2.0 h1:uQmKvOCV2AOlCHEDjg9uclZCXQZzq2PxaXfZ1aIMiQI=
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= git.tebibyte.media/sashakoshka/goparse v0.2.0/go.mod h1:tSQwfuD+EujRoKr6Y1oaRy74ZynatzkRLxjE3sbpCmk=
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/gomarkdown/markdown v0.0.0-20241205020045-f7e15b2f3e62 h1:pbAFUZisjG4s6sxvRJvf2N7vhpCvx2Oxb3PmS6pDO1g=
github.com/gomarkdown/markdown v0.0.0-20241205020045-f7e15b2f3e62/go.mod h1:JDGcbDT52eL4fju3sZ4TeHGsQwhG9nbDV21aMyhwPoA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE=
golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -0,0 +1,94 @@
package testutil
import "fmt"
import "slices"
import "strings"
// Snake lets you compare blocks of data where the ordering of certain parts may
// be swapped every which way. It is designed for comparing the encoding of
// maps where the ordering of individual elements is inconsistent.
//
// The snake is divided into sectors, which hold a number of variations. For a
// sector to be satisfied by some data, some ordering of it must match the data
// exactly. for the snake to be satisfied by some data, its sectors must match
// the data in order, but the internal ordering of each sector doesn't matter.
type Snake [] [] []byte
// snake sector variation
// S returns a new snake.
func S(data ...byte) Snake {
return (Snake { }).Add(data...)
}
// AddVar returns a new snake with the given sector added on to it. Successive
// calls of this method can be chained together to create a big ass snake.
func (sn Snake) AddVar(sector ...[]byte) Snake {
slice := make(Snake, len(sn) + 1)
copy(slice, sn)
slice[len(slice) - 1] = sector
return slice
}
// Add is like AddVar, but adds a sector with only one variation, which means it
// does not vary, hence why the method is called that.
func (sn Snake) Add(data ...byte) Snake {
return sn.AddVar(data)
}
// Check determines if the data satisfies the snake.
func (sn Snake) Check(data []byte) (ok bool, n int) {
left := data
variations := map[int] []byte { }
for _, sector := range sn {
clear(variations)
for key, variation := range sector {
variations[key] = variation
}
for len(variations) > 0 {
found := false
for key, variation := range variations {
if len(left) < len(variation) { continue }
if !slices.Equal(left[:len(variation)], variation) { continue }
n += len(variation)
left = data[n:]
delete(variations, key)
found = true
}
if !found { return false, n }
}
}
if n < len(data) {
return false, n
}
return true, n
}
func (sn Snake) String() string {
if len(sn) == 0 || len(sn[0]) == 0 || len(sn[0][0]) == 0 {
return "EMPTY"
}
out := strings.Builder { }
for index, sector := range sn {
if index > 0 { out.WriteString(" : ") }
out.WriteRune('[')
for index, variation := range sector {
if index > 0 { out.WriteString(" / ") }
for _, byt := range variation {
fmt.Fprintf(&out, "%02x", byt)
}
}
out.WriteRune(']')
}
return out.String()
}
// HexBytes formats bytes into a hexadecimal string.
func HexBytes(data []byte) string {
if len(data) == 0 { return "EMPTY" }
out := strings.Builder { }
for _, byt := range data {
fmt.Fprintf(&out, "%02x", byt)
}
return out.String()
}

View File

@ -0,0 +1,66 @@
package testutil
import "testing"
func TestSnakeA(test *testing.T) {
snake := S(1, 6).AddVar(
[]byte { 1 },
[]byte { 2 },
[]byte { 3 },
[]byte { 4 },
[]byte { 5 },
).Add(9)
test.Log(snake)
ok, n := snake.Check([]byte { 1, 6, 1, 2, 3, 4, 5, 9 })
if !ok { test.Fatal("false negative:", n) }
ok, n = snake.Check([]byte { 1, 6, 5, 4, 3, 2, 1, 9 })
if !ok { test.Fatal("false negative:", n) }
ok, n = snake.Check([]byte { 1, 6, 3, 1, 4, 2, 5, 9 })
if !ok { test.Fatal("false negative:", n) }
ok, n = snake.Check([]byte { 1, 6, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 1, 2, 3, 4, 5, 6, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 0, 2, 3, 4, 5, 6, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 7, 1, 4, 2, 5, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 7, 3, 1, 4, 2, 5, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 7, 3, 1, 4, 2, 5, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 1, 2, 3, 4, 5, 9, 10})
if ok { test.Fatal("false positive:", n) }
}
func TestSnakeB(test *testing.T) {
snake := S(1, 6).AddVar(
[]byte { 1 },
[]byte { 2 },
).Add(9).AddVar(
[]byte { 3, 2 },
[]byte { 0 },
[]byte { 1, 1, 2, 3 },
)
test.Log(snake)
ok, n := snake.Check([]byte { 1, 6, 1, 2, 9, 3, 2, 0, 1, 1, 2, 3})
if !ok { test.Fatal("false negative:", n) }
ok, n = snake.Check([]byte { 1, 6, 2, 1, 9, 0, 1, 1, 2, 3, 3, 2})
if !ok { test.Fatal("false negative:", n) }
ok, n = snake.Check([]byte { 1, 6, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 1, 2, 9 })
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 9, 3, 2, 0, 1, 1, 2, 3})
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 2, 9, 0, 1, 1, 2, 3, 3, 2})
if ok { test.Fatal("false positive:", n) }
ok, n = snake.Check([]byte { 1, 6, 1, 2, 9, 3, 2, 1, 1, 2, 3})
if ok { test.Fatal("false positive:", n) }
}

View File

@ -1,9 +1,8 @@
package hopp package hopp
import "net" import "net"
import "context" import "errors"
import "crypto/tls" import "crypto/tls"
import "github.com/quic-go/quic-go"
// Listener is an object which listens for incoming HOPP connections. // Listener is an object which listens for incoming HOPP connections.
type Listener interface { type Listener interface {
@ -17,7 +16,8 @@ type Listener interface {
} }
// Listen listens for incoming HOPP connections. The network must be one of // Listen listens for incoming HOPP connections. The network must be one of
// "quic", "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". // "quic", "quic4", (IPv4-only) "quic6" (IPv6-only), or "unix". For now, quic is
// not supported.
func Listen(network, address string) (Listener, error) { func Listen(network, address string) (Listener, error) {
switch network { switch network {
case "quic", "quic4", "quic6": return ListenQUIC(network, address, nil) case "quic", "quic4", "quic6": return ListenQUIC(network, address, nil)
@ -30,19 +30,8 @@ func Listen(network, address string) (Listener, error) {
// The network must be one of "quic", "quic4", (IPv4-only) or "quic6" // The network must be one of "quic", "quic4", (IPv4-only) or "quic6"
// (IPv6-only). // (IPv6-only).
func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error) { func ListenQUIC(network, address string, tlsConf *tls.Config) (Listener, error) {
tlsConf = tlsConfig(tlsConf) // tlsConf = tlsConfig(tlsConf)
quicConf := quicConfig() return nil, errors.New("quic is not yet implemented")
udpNetwork, err := quicNetworkToUDPNetwork(network)
if err != nil { return nil, err }
addr, err := net.ResolveUDPAddr(udpNetwork, address)
if err != nil { return nil, err }
udpListener, err := net.ListenUDP(udpNetwork, addr)
if err != nil { return nil, err }
quicListener, err := quic.Listen(udpListener, tlsConf, quicConf)
if err != nil { return nil, err }
return &listenerQUIC {
underlying: quicListener,
}, nil
} }
// ListenUnix listens for incoming HOPP connections using a Unix domain socket // ListenUnix listens for incoming HOPP connections using a Unix domain socket
@ -58,24 +47,6 @@ func ListenUnix(network, address string) (Listener, error) {
}, nil }, nil
} }
type listenerQUIC struct {
underlying *quic.Listener
}
func (this *listenerQUIC) Accept() (Conn, error) {
conn, err := this.underlying.Accept(context.Background())
if err != nil { return nil, err }
return AdaptB(quicMultiConn { underlying: conn }), nil
}
func (this *listenerQUIC) Close() error {
return this.underlying.Close()
}
func (this *listenerQUIC) Addr() net.Addr {
return this.underlying.Addr()
}
type listenerUnix struct { type listenerUnix struct {
underlying *net.UnixListener underlying *net.UnixListener
} }

View File

@ -4,11 +4,16 @@ import "io"
import "fmt" import "fmt"
import "net" import "net"
import "sync" import "sync"
import "sync/atomic"
import "git.tebibyte.media/sashakoshka/hopp/tape" import "git.tebibyte.media/sashakoshka/hopp/tape"
import "git.tebibyte.media/sashakoshka/go-util/sync" import "git.tebibyte.media/sashakoshka/go-util/sync"
// TODO investigate why 30 never reaches the server, causing it to wait for ever
// and never close the connection, causing the client to also wait forever
const closeMethod = 0xFFFF const closeMethod = 0xFFFF
const int64Max = int64((^uint64(0)) >> 1) const int64Max = int64((^uint64(0)) >> 1)
const defaultChunkSize = 0x1000
// Party represents a side of a connection. // Party represents a side of a connection.
type Party bool; const ( type Party bool; const (
@ -16,7 +21,16 @@ type Party bool; const (
ClientSide Party = true ClientSide Party = true
) )
func (party Party) String() string {
if party == ServerSide {
return "server"
} else {
return "client"
}
}
type a struct { type a struct {
sizeLimit int64
underlying net.Conn underlying net.Conn
party Party party Party
transID int64 transID int64
@ -32,6 +46,7 @@ type a struct {
// oriented transport such as TCP or UNIX domain stream sockets. // oriented transport such as TCP or UNIX domain stream sockets.
func AdaptA(underlying net.Conn, party Party) Conn { func AdaptA(underlying net.Conn, party Party) Conn {
conn := &a { conn := &a {
sizeLimit: defaultSizeLimit,
underlying: underlying, underlying: underlying,
party: party, party: party,
transMap: make(map[int64] *transA), transMap: make(map[int64] *transA),
@ -49,7 +64,7 @@ func AdaptA(underlying net.Conn, party Party) Conn {
func (this *a) Close() error { func (this *a) Close() error {
close(this.done) close(this.done)
return this.underlying.Close() return nil
} }
func (this *a) LocalAddr() net.Addr { func (this *a) LocalAddr() net.Addr {
@ -63,30 +78,41 @@ func (this *a) RemoteAddr() net.Addr {
func (this *a) OpenTrans() (Trans, error) { func (this *a) OpenTrans() (Trans, error) {
this.transLock.Lock() this.transLock.Lock()
defer this.transLock.Unlock() defer this.transLock.Unlock()
if this.transID == int64Max {
return nil, fmt.Errorf("could not open transaction: %w", ErrIntegerOverflow)
}
id := this.transID id := this.transID
this.transID ++
trans := &transA { trans := &transA {
parent: this, parent: this,
id: id, id: id,
incoming: usync.NewGate[incomingMessage](), incoming: usync.NewGate[incomingMessage](),
} }
this.transMap[id] = trans this.transMap[id] = trans
if this.transID == int64Max { if this.party == ClientSide {
return nil, fmt.Errorf("could not open transaction: %w", ErrIntegerOverflow)
}
this.transID ++ this.transID ++
} else {
this.transID --
}
return trans, nil return trans, nil
} }
func (this *a) AcceptTrans() (Trans, error) { func (this *a) AcceptTrans() (Trans, error) {
eof := fmt.Errorf("could not accept transaction: %w", io.EOF)
select { select {
case trans := <- this.transChan: case trans := <- this.transChan:
if trans == nil {
return nil, eof
}
return trans, nil return trans, nil
case <- this.done: case <- this.done:
return nil, fmt.Errorf("could not accept transaction: %w", io.EOF) return nil, eof
} }
} }
func (this *a) SetSizeLimit(limit int64) {
this.sizeLimit = limit
}
func (this *a) unlistTransactionSafe(id int64) { func (this *a) unlistTransactionSafe(id int64) {
this.transLock.Lock() this.transLock.Lock()
defer this.transLock.Unlock() defer this.transLock.Unlock()
@ -96,27 +122,32 @@ func (this *a) unlistTransactionSafe(id int64) {
func (this *a) sendMessageSafe(trans int64, method uint16, data []byte) error { func (this *a) sendMessageSafe(trans int64, method uint16, data []byte) error {
this.sendLock.Lock() this.sendLock.Lock()
defer this.sendLock.Unlock() defer this.sendLock.Unlock()
return encodeMessageA(this.underlying, trans, method, data) return encodeMessageA(this.underlying, this.sizeLimit, trans, method, data)
} }
func (this *a) receive() { func (this *a) receive() {
defer func() { defer func() {
this.underlying.Close() this.underlying.Close()
close(this.transChan)
this.transLock.Lock() this.transLock.Lock()
defer this.transLock.Unlock() defer this.transLock.Unlock()
for _, trans := range this.transMap { for _, trans := range this.transMap {
trans.closeDontUnlist() trans.closeDontUnlist()
} }
clear(this.transMap) clear(this.transMap)
this.underlying.Close()
}() }()
// receive MMBs in a loop and forward them to transactions until shit
// starts closing
for { for {
transID, method, payload, err := decodeMessageA(this.underlying) transID, method, chunked, payload, err := decodeMessageA(this.underlying, this.sizeLimit)
if err != nil { if err != nil {
this.err = fmt.Errorf("could not receive message: %w", err) this.err = fmt.Errorf("could not receive message: %w", err)
return return
} }
err = this.receiveMultiplex(transID, method, payload) err = this.multiplexMMB(transID, method, chunked, payload)
if err != nil { if err != nil {
this.err = fmt.Errorf("could not receive message: %w", err) this.err = fmt.Errorf("could not receive message: %w", err)
return return
@ -124,7 +155,7 @@ func (this *a) receive() {
} }
} }
func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) error { func (this *a) multiplexMMB(transID int64, method uint16, chunked bool, payload []byte) error {
if transID == 0 { return ErrMessageMalformed } if transID == 0 { return ErrMessageMalformed }
trans, err := func() (*transA, error) { trans, err := func() (*transA, error) {
@ -133,6 +164,12 @@ func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) er
trans, ok := this.transMap[transID] trans, ok := this.transMap[transID]
if !ok { if !ok {
// check if this is a superfluous close message and just
// do nothing if so
if method == closeMethod {
return nil, nil
}
// it is forbidden for the other party to initiate a transaction // it is forbidden for the other party to initiate a transaction
// with an ID from this party // with an ID from this party
if this.party == partyFromTransID(transID) { if this.party == partyFromTransID(transID) {
@ -150,28 +187,49 @@ func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) er
}() }()
if err != nil { return err } if err != nil { return err }
if trans == nil {
return nil
}
if method == closeMethod {
return trans.Close()
} else {
trans.incoming.Send(incomingMessage { trans.incoming.Send(incomingMessage {
method: method, method: method,
chunked: chunked,
payload: payload, payload: payload,
}) })
}
return nil return nil
} }
// most methods in transA don't need to be goroutine safe except those marked
// as such
type transA struct { type transA struct {
parent *a parent *a
id int64 id int64
incoming usync.Gate[incomingMessage] incoming usync.Gate[incomingMessage]
currentReader io.Reader
currentWriter io.Closer
writeBuffer []byte
closed atomic.Bool
} }
func (this *transA) Close() error { func (this *transA) Close() error {
// MUST be goroutine safe
err := this.closeDontUnlist() err := this.closeDontUnlist()
this.parent.unlistTransactionSafe(this.ID()) this.parent.unlistTransactionSafe(this.ID())
return err return err
} }
func (this *transA) closeDontUnlist() error { func (this *transA) closeDontUnlist() (err error) {
this.Send(closeMethod, nil) // MUST be goroutine safe
return this.incoming.Close() this.incoming.Close()
if !this.closed.Load() {
err = this.Send(closeMethod, nil)
}
this.closed.Store(true)
return err
} }
func (this *transA) ID() int64 { func (this *transA) ID() int64 {
@ -182,58 +240,213 @@ func (this *transA) Send(method uint16, data []byte) error {
return this.parent.sendMessageSafe(this.id, method, data) return this.parent.sendMessageSafe(this.id, method, data)
} }
func (this *transA) SendWriter(method uint16) (io.WriteCloser, error) {
// close previous writer if necessary
if this.currentWriter != nil {
this.currentWriter.Close()
this.currentWriter = nil
}
// create new writer
writer := &writerA {
parent: this,
// there is only ever one writer at a time, so they can all
// share a buffer
buffer: this.writeBuffer[:0],
method: method,
chunkSize: defaultChunkSize,
open: true,
}
this.currentWriter = writer
return writer, nil
}
func (this *transA) Receive() (method uint16, data []byte, err error) { func (this *transA) Receive() (method uint16, data []byte, err error) {
receive := this.incoming.Receive() method, reader, err := this.ReceiveReader()
if err != nil { return 0, nil, err }
data, err = io.ReadAll(reader)
if err != nil { return 0, nil, err }
return method, data, nil
}
func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
// if the transaction has been closed, return an io.EOF
if this.closed.Load() {
return 0, nil, io.EOF
}
// drain previous reader if necessary
if this.currentReader != nil {
io.Copy(io.Discard, this.currentReader)
}
// create new reader
reader := &readerA {
parent: this,
}
method, err := reader.pull()
if err != nil { return 0, nil, err}
this.currentReader = reader
return method, reader, nil
}
type readerA struct {
parent *transA
leftover []byte
eof bool
}
// pull pulls the next MMB in this message from the transaction.
func (this *readerA) pull() (uint16, error) {
// if the previous message ended the chain, return an io.EOF
if this.eof {
return 0, io.EOF
}
// get an MMB from the transaction we are a part of
receive := this.parent.incoming.Receive()
if receive != nil { if receive != nil {
if message, ok := <- receive; ok { if message, ok := <- receive; ok {
if message.method != closeMethod { if message.method != closeMethod {
return message.method, message.payload, nil this.leftover = append(this.leftover, message.payload...)
if !message.chunked {
this.eof = true
}
return message.method, nil
} }
} }
} }
// close and return error on failure // close and return error on failure
this.Close() this.eof = true
if this.parent.err == nil { this.parent.Close()
return 0, nil, fmt.Errorf("could not receive message: %w", io.EOF) if this.parent.parent.err == nil {
return 0, fmt.Errorf("could not receive message: %w", io.EOF)
} else { } else {
return 0, nil, this.parent.err return 0, this.parent.parent.err
} }
} }
func (this *readerA) Read(buffer []byte) (int, error) {
if len(this.leftover) == 0 {
if this.eof { return 0, io.EOF }
this.pull()
}
copied := copy(buffer, this.leftover)
this.leftover = this.leftover[copied:]
return copied, nil
}
type writerA struct {
parent *transA
buffer []byte
method uint16
chunkSize int64
open bool
}
func (this *writerA) Write(data []byte) (n int, err error) {
if !this.open { return 0, io.EOF }
toSend := data
for len(toSend) > 0 {
nn, err := this.writeOne(toSend)
n += nn
toSend = toSend[nn:]
if err != nil { return n, err }
}
return n, nil
}
func (this *writerA) Close() error {
this.open = false
return nil
}
func (this *writerA) writeOne(data []byte) (n int, err error) {
data = data[:min(len(data), int(this.chunkSize))]
// if there is more room, append to the buffer and exit
if int64(len(this.buffer) + len(data)) <= this.chunkSize {
this.buffer = append(this.buffer, data...)
n = len(data)
// if have a full chunk, flush
if int64(len(this.buffer)) == this.chunkSize {
err = this.flush()
if err != nil { return n, err }
}
return n, nil
}
// if not, flush and store as much as we can in the buffer
err = this.flush()
if err != nil { return n, err }
this.buffer = append(this.buffer, data...)
return n, nil
}
func (this *writerA) flush() error {
return this.parent.parent.sendMessageSafe(this.parent.id, this.method, this.buffer)
}
type incomingMessage struct { type incomingMessage struct {
method uint16 method uint16
chunked bool
payload []byte payload []byte
} }
func encodeMessageA(writer io.Writer, trans int64, method uint16, data []byte) error { func encodeMessageA(
buffer := make([]byte, 12 + len(data)) writer io.Writer,
sizeLimit int64,
trans int64,
method uint16,
data []byte,
) error {
if int64(len(data)) > sizeLimit {
return ErrPayloadTooLarge
}
buffer := make([]byte, 18 + len(data))
tape.EncodeI64(buffer[:8], trans) tape.EncodeI64(buffer[:8], trans)
tape.EncodeI16(buffer[8:10], method) tape.EncodeI16(buffer[8:10], method)
length, ok := tape.U16CastSafe(len(data)) tape.EncodeI64(buffer[10:18], uint64(len(data)))
if !ok { return ErrPayloadTooLarge } copy(buffer[18:], data)
tape.EncodeI16(buffer[10:12], length)
copy(buffer[12:], data)
_, err := writer.Write(buffer) _, err := writer.Write(buffer)
return err return err
} }
func decodeMessageA(reader io.Reader) (int64, uint16, []byte, error) { func decodeMessageA(
headerBuffer := [12]byte { } reader io.Reader,
_, err := io.ReadFull(reader, headerBuffer[:]) sizeLimit int64,
if err != nil { return 0, 0, nil, err } ) (
transID, err := tape.DecodeI64[int64](headerBuffer[:8]) transID int64,
if err != nil { return 0, 0, nil, err } method uint16,
method, err := tape.DecodeI16[uint16](headerBuffer[8:10]) chunked bool,
if err != nil { return 0, 0, nil, err } payloadBuffer []byte,
length, err := tape.DecodeI16[uint16](headerBuffer[10:12]) err error,
if err != nil { return 0, 0, nil, err } ) {
payloadBuffer := make([]byte, int(length)) headerBuffer := [18]byte { }
_, err = io.ReadFull(reader, headerBuffer[:])
if err != nil { return 0, 0, false, nil, err }
transID, err = tape.DecodeI64[int64](headerBuffer[:8])
if err != nil { return 0, 0, false, nil, err }
method, err = tape.DecodeI16[uint16](headerBuffer[8:10])
if err != nil { return 0, 0, false, nil, err }
size, err := tape.DecodeI64[uint64](headerBuffer[10:18])
if err != nil { return 0, 0, false, nil, err }
chunked, size = splitCCBSize(size)
if size > uint64(sizeLimit) {
return 0, 0, false, nil, ErrPayloadTooLarge
}
payloadBuffer = make([]byte, int(size))
_, err = io.ReadFull(reader, payloadBuffer) _, err = io.ReadFull(reader, payloadBuffer)
if err != nil { return 0, 0, nil, err } if err != nil { return 0, 0, false, nil, err }
return transID, method, payloadBuffer, nil return transID, method, chunked, payloadBuffer, nil
} }
func partyFromTransID(id int64) Party { func partyFromTransID(id int64) Party {
return id > 0 return id > 0
} }
func splitCCBSize(size uint64) (bool, uint64) {
return size >> 63 > 1, size & 0x7FFFFFFFFFFFFFFF
}

View File

@ -25,47 +25,18 @@ func TestConnA(test *testing.T) {
"When the impostor is sus!", "When the impostor is sus!",
} }
network := "tcp" clientFunc := func(a Conn) {
addr := "localhost:7959"
// server
listener, err := net.Listen(network, addr)
if err != nil { test.Fatal(err) }
defer listener.Close()
go func() {
test.Log("SERVER listening")
conn, err := listener.Accept()
if err != nil { test.Error("SERVER", err); return }
defer conn.Close()
a := AdaptA(conn, ServerSide)
trans, err := a.OpenTrans()
if err != nil { test.Error("SERVER", err); return }
defer trans.Close()
for method, payload := range payloads {
test.Log("SERVER", method, payload)
err := trans.Send(uint16(method), []byte(payload))
if err != nil { test.Error("SERVER", err); return }
}
}()
// client
test.Log("CLIENT dialing")
conn, err := net.Dial(network, addr)
if err != nil { test.Fatal("CLIENT", err) }
test.Log("CLIENT dialed")
a := AdaptA(conn, ClientSide)
defer a.Close()
test.Log("CLIENT accepting transaction") test.Log("CLIENT accepting transaction")
trans, err := a.AcceptTrans() trans, err := a.AcceptTrans()
if err != nil { test.Fatal("CLIENT", err) } if err != nil { test.Fatal("CLIENT", err) }
test.Log("CLIENT accepted transaction") test.Log("CLIENT accepted transaction")
defer trans.Close() test.Cleanup(func() { trans.Close() })
for method, payload := range payloads { for method, payload := range payloads {
test.Log("CLIENT waiting...") test.Log("CLIENT waiting...")
gotMethod, gotPayloadBytes, err := trans.Receive() gotMethod, gotPayloadBytes, err := trans.Receive()
if err != nil { test.Fatal("CLIENT", err) } if err != nil { test.Fatal("CLIENT", err) }
gotPayload := string(gotPayloadBytes) gotPayload := string(gotPayloadBytes)
test.Log("CLIENT", gotMethod, gotPayload) test.Log("CLIENT m:", gotMethod, "p:", gotPayload)
if int(gotMethod) != method { if int(gotMethod) != method {
test.Errorf("CLIENT method not equal") test.Errorf("CLIENT method not equal")
} }
@ -73,22 +44,112 @@ func TestConnA(test *testing.T) {
test.Errorf("CLIENT payload not equal") test.Errorf("CLIENT payload not equal")
} }
} }
_, _, err = trans.Receive() test.Log("CLIENT waiting for transaction close...")
gotMethod, gotPayload, err := trans.Receive()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
test.Fatal("CLIENT wrong error:", err) test.Error("CLIENT wrong error:", err)
test.Error("CLIENT method:", gotMethod)
test.Error("CLIENT payload:", gotPayload)
test.Fatal("CLIENT ok byeeeeeeeeeeeee")
} }
test.Log("CLIENT done") }
// TODO test error from trans/connection closed by other side
serverFunc := func(a Conn) {
trans, err := a.OpenTrans()
if err != nil { test.Error("SERVER", err); return }
test.Cleanup(func() { trans.Close() })
for method, payload := range payloads {
test.Log("SERVER m:", method, "p:", payload)
err := trans.Send(uint16(method), []byte(payload))
if err != nil { test.Error("SERVER", err); return }
}
test.Log("SERVER closing connection")
}
clientServerEnvironment(test, clientFunc, serverFunc)
}
func TestTransOpenCloseA(test *testing.T) {
// currently:
//
// | data sent | data recvd | close sent | close recvd
// 10 | X | X | X | server hangs
// 20 | X | X | X | client hangs
// 30 | X | | X |
//
// when a close message is recvd, it tries to push to the trans and
// hangs on trans.incoming.Send, which hangs on sending the value to the
// underlying channel. why is this?
//
// check if we are really getting values from the channel when pulling
// from the trans channel when we are expecting a close.
clientFunc := func(conn Conn) {
// 10
trans, err := conn.OpenTrans()
if err != nil { test.Error("CLIENT", err); return }
test.Log("CLIENT sending 10")
trans.Send(10, []byte("hi"))
trans.Close()
// 20
test.Log("CLIENT awaiting 20")
trans, err = conn.AcceptTrans()
if err != nil { test.Error("CLIENT", err); return }
test.Cleanup(func() { trans.Close() })
gotMethod, gotPayload, err := trans.Receive()
if err != nil { test.Error("CLIENT", err); return }
test.Logf("CLIENT m: %d p: %s", gotMethod, gotPayload)
if gotMethod != 20 { test.Error("CLIENT wrong method")}
// 30
trans, err = conn.OpenTrans()
if err != nil { test.Error("CLIENT", err); return }
test.Log("CLIENT sending 30")
trans.Send(30, []byte("good"))
trans.Close()
}
serverFunc := func(conn Conn) {
// 10
test.Log("SERVER awaiting 10")
trans, err := conn.AcceptTrans()
if err != nil { test.Error("SERVER", err); return }
test.Cleanup(func() { trans.Close() })
gotMethod, gotPayload, err := trans.Receive()
if err != nil { test.Error("SERVER", err); return }
test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload)
if gotMethod != 10 { test.Error("SERVER wrong method")}
// 20
trans, err = conn.OpenTrans()
if err != nil { test.Error("SERVER", err); return }
test.Log("SERVER sending 20")
trans.Send(20, []byte("hi how r u"))
trans.Close()
// 30
test.Log("SERVER awaiting 30")
trans, err = conn.AcceptTrans()
if err != nil { test.Error("SERVER", err); return }
test.Cleanup(func() { trans.Close() })
gotMethod, gotPayload, err = trans.Receive()
if err != nil { test.Error("SERVER", err); return }
test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload)
if gotMethod != 30 { test.Error("SERVER wrong method")}
}
clientServerEnvironment(test, clientFunc, serverFunc)
} }
func TestEncodeMessageA(test *testing.T) { func TestEncodeMessageA(test *testing.T) {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
err := encodeMessageA(buffer, 0x5800FEABC3104F04, 0x6B12, payload) err := encodeMessageA(buffer, defaultSizeLimit, 0x5800FEABC3104F04, 0x6B12, payload)
correct := []byte { correct := []byte {
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04, 0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
0x6B, 0x12, 0x6B, 0x12,
0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
} }
if err != nil { if err != nil {
@ -102,19 +163,19 @@ func TestEncodeMessageA(test *testing.T) {
func TestEncodeMessageAErr(test *testing.T) { func TestEncodeMessageAErr(test *testing.T) {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
payload := make([]byte, 0x10000) payload := make([]byte, 0x10000)
err := encodeMessageA(buffer, 0x5800FEABC3104F04, 0x6B12, payload) err := encodeMessageA(buffer, 0x20, 0x5800FEABC3104F04, 0x6B12, payload)
if !errors.Is(err, ErrPayloadTooLarge) { if !errors.Is(err, ErrPayloadTooLarge) {
test.Fatalf("wrong error: %v", err) test.Fatalf("wrong error: %v", err)
} }
} }
func TestDecodeMessageA(test *testing.T) { func TestDecodeMessageA(test *testing.T) {
transID, method, payload, err := decodeMessageA(bytes.NewReader([]byte { transID, method, _, payload, err := decodeMessageA(bytes.NewReader([]byte {
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04, 0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
0x6B, 0x12, 0x6B, 0x12,
0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
})) }), defaultSizeLimit)
if err != nil { if err != nil {
test.Fatal(err) test.Fatal(err)
} }
@ -131,13 +192,76 @@ func TestDecodeMessageA(test *testing.T) {
} }
func TestDecodeMessageAErr(test *testing.T) { func TestDecodeMessageAErr(test *testing.T) {
_, _, _, err := decodeMessageA(bytes.NewReader([]byte { _, _, _, _, err := decodeMessageA(bytes.NewReader([]byte {
0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04, 0x58, 0x00, 0xFE, 0xAB, 0xC3, 0x10, 0x4F, 0x04,
0x6B, 0x12, 0x6B, 0x12,
0x01, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x06,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
})) }), defaultSizeLimit)
if !errors.Is(err, io.ErrUnexpectedEOF) { if !errors.Is(err, io.ErrUnexpectedEOF) {
test.Fatalf("wrong error: %v", err) test.Fatalf("wrong error: %v", err)
} }
} }
func TestEncodeDecodeMessageA(test *testing.T) {
correctTransID := int64(2)
correctMethod := uint16(30)
correctPayload := []byte("good")
buffer := bytes.Buffer { }
err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, correctPayload)
if err != nil { test.Fatal(err) }
transID, method, chunked, payload, err := decodeMessageA(&buffer, defaultSizeLimit)
if got, correct := transID, int64(2); got != correct {
test.Fatalf("not equal: %v %v", got, correct)
}
if got, correct := method, uint16(30); got != correct {
test.Fatalf("not equal: %v %v", got, correct)
}
if chunked {
test.Fatalf("message should not be chunked")
}
if got, correct := payload, correctPayload; !slices.Equal(got, correct) {
test.Fatalf("not equal: %v %v", got, correct)
}
}
func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), serverFunc func(conn Conn)) {
network := "tcp"
addr := "localhost:7959"
// server
listener, err := net.Listen(network, addr)
if err != nil { test.Fatal(err) }
test.Cleanup(func() { listener.Close() })
go func() {
test.Log("SERVER listening")
conn, err := listener.Accept()
if err != nil { test.Error("SERVER", err); return }
defer conn.Close()
test.Cleanup(func() { conn.Close() })
a := AdaptA(conn, ServerSide)
test.Cleanup(func() { a.Close() })
serverFunc(a)
test.Log("SERVER closing")
}()
// client
test.Log("CLIENT dialing")
conn, err := net.Dial(network, addr)
if err != nil { test.Fatal("CLIENT", err) }
test.Log("CLIENT dialed")
a := AdaptA(conn, ClientSide)
test.Cleanup(func() { a.Close() })
clientFunc(a)
test.Log("CLIENT waiting for connection close...")
trans, err := a.AcceptTrans()
if !errors.Is(err, io.EOF) {
test.Error("CLIENT wrong error:", err)
test.Fatal("CLIENT trans:", trans)
}
test.Log("CLIENT DONE")
conn.Close()
}

View File

@ -2,19 +2,23 @@ package hopp
import "io" import "io"
import "net" import "net"
import "bytes"
import "errors"
import "context" import "context"
import "git.tebibyte.media/sashakoshka/hopp/tape" import "git.tebibyte.media/sashakoshka/hopp/tape"
// B implements METADAPT-B over a multiplexed stream-oriented transport such as // B implements METADAPT-B over a multiplexed stream-oriented transport such as
// QUIC. // QUIC.
type b struct { type b struct {
sizeLimit int64
underlying MultiConn underlying MultiConn
} }
// AdaptB returns a connection implementing METADAPT-B over a singular stream- // AdaptB returns a connection implementing METADAPT-B over a multiplexed
// oriented transport such as TCP or UNIX domain stream sockets. // stream-oriented transport such as QUIC.
func AdaptB(underlying MultiConn) Conn { func AdaptB(underlying MultiConn) Conn {
return &b { return &b {
sizeLimit: defaultSizeLimit,
underlying: underlying, underlying: underlying,
} }
} }
@ -34,33 +38,105 @@ func (this *b) RemoteAddr() net.Addr {
func (this *b) OpenTrans() (Trans, error) { func (this *b) OpenTrans() (Trans, error) {
stream, err := this.underlying.OpenStream() stream, err := this.underlying.OpenStream()
if err != nil { return nil, err } if err != nil { return nil, err }
return transB { underlying: stream }, nil return this.newTrans(stream), nil
} }
func (this *b) AcceptTrans() (Trans, error) { func (this *b) AcceptTrans() (Trans, error) {
stream, err := this.underlying.AcceptStream(context.Background()) stream, err := this.underlying.AcceptStream(context.Background())
if err != nil { return nil, err } if err != nil { return nil, err }
return transB { underlying: stream }, nil return this.newTrans(stream), nil
}
func (this *b) SetSizeLimit(limit int64) {
this.sizeLimit = limit
}
func (this *b) newTrans(underlying Stream) *transB {
return &transB {
sizeLimit: this.sizeLimit,
underlying: underlying,
}
} }
type transB struct { type transB struct {
sizeLimit int64
underlying Stream underlying Stream
currentData io.Reader
currentWriter *writerB
} }
func (trans transB) Close() error { func (this *transB) Close() error {
return trans.underlying.Close() return this.underlying.Close()
} }
func (trans transB) ID() int64 { func (this *transB) ID() int64 {
return trans.underlying.ID() return this.underlying.ID()
} }
func (trans transB) Send(method uint16, data []byte) error { func (this *transB) Send(method uint16, data []byte) error {
return encodeMessageB(trans.underlying, method, data) return encodeMessageB(this.underlying, this.sizeLimit, method, data)
} }
func (trans transB) Receive() (uint16, []byte, error) { func (this *transB) SendWriter(method uint16) (io.WriteCloser, error) {
return decodeMessageB(trans.underlying) if this.currentWriter != nil {
this.currentWriter.Close()
}
// TODO: come up with a fix that allows us to pipe data through the
// writer. as of now, it just reads whatever is written into a buffer
// and sends the message on close. we should probably introduce chunked
// encoding to METADAPT-B to fix this. the implementation would be
// simpler than on METADAPT-A, but most of the code could just be
// copied over.
writer := &writerB {
parent: this,
method: method,
}
this.currentWriter = writer
return writer, nil
}
func (this *transB) Receive() (uint16, []byte, error) {
// get a reader for the next message
method, size, data, err := this.receiveReader()
if err != nil { return 0, nil, err }
// read the entire thing
payloadBuffer := make([]byte, int(size))
_, err = io.ReadFull(data, payloadBuffer)
if err != nil { return 0, nil, err }
// we have used up the reader by now so we can forget it exists
this.currentData = nil
return method, payloadBuffer, nil
}
func (this *transB) ReceiveReader() (uint16, io.Reader, error) {
method, _, data, err := this.receiveReader()
return method, data, err
}
func (this *transB) receiveReader() (uint16, int64, io.Reader, error) {
// decode the message
method, size, data, err := decodeMessageB(this.underlying, this.sizeLimit)
if err != nil { return 0, 0, nil, err }
// discard current reader if there is one
if this.currentData == nil {
io.Copy(io.Discard, this.currentData)
}
this.currentData = data
return method, size, data, nil
}
type writerB struct {
parent *transB
buffer bytes.Buffer
method uint16
}
func (this *writerB) Write(data []byte) (int, error) {
return this.buffer.Write(data)
}
func (this *writerB) Close() error {
return this.parent.Send(this.method, this.buffer.Bytes())
} }
// MultiConn represens a multiplexed stream-oriented transport for use in // MultiConn represens a multiplexed stream-oriented transport for use in
@ -84,27 +160,42 @@ type Stream interface {
ID() int64 ID() int64
} }
func encodeMessageB(writer io.Writer, method uint16, data []byte) error { func encodeMessageB(writer io.Writer, sizeLimit int64, method uint16, data []byte) error {
buffer := make([]byte, 4 + len(data)) if int64(len(data)) > sizeLimit {
return ErrPayloadTooLarge
}
buffer := make([]byte, 10 + len(data))
tape.EncodeI16(buffer[:2], method) tape.EncodeI16(buffer[:2], method)
length, ok := tape.U16CastSafe(len(data)) tape.EncodeI64(buffer[2:10], uint64(len(data)))
if !ok { return ErrPayloadTooLarge } copy(buffer[10:], data)
tape.EncodeI16(buffer[2:4], length)
copy(buffer[4:], data)
_, err := writer.Write(buffer) _, err := writer.Write(buffer)
return err return err
} }
func decodeMessageB(reader io.Reader) (uint16, []byte, error) { func decodeMessageB(
headerBuffer := [4]byte { } reader io.Reader,
_, err := io.ReadFull(reader, headerBuffer[:]) sizeLimit int64,
if err != nil { return 0, nil, err } ) (
method, err := tape.DecodeI16[uint16](headerBuffer[:2]) method uint16,
if err != nil { return 0, nil, err } size int64,
length, err := tape.DecodeI16[uint16](headerBuffer[2:4]) data io.Reader,
if err != nil { return 0, nil, err } err error,
payloadBuffer := make([]byte, int(length)) ) {
_, err = io.ReadFull(reader, payloadBuffer) headerBuffer := [10]byte { }
if err != nil { return 0, nil, err } _, err = io.ReadFull(reader, headerBuffer[:])
return method, payloadBuffer, nil if err != nil {
if errors.Is(err, io.EOF) { return 0, 0, nil, io.ErrUnexpectedEOF }
return 0, 0, nil, err
}
method, err = tape.DecodeI16[uint16](headerBuffer[:2])
if err != nil { return 0, 0, nil, err }
length, err := tape.DecodeI64[uint64](headerBuffer[2:10])
if err != nil { return 0, 0, nil, err }
if length > uint64(sizeLimit) {
return 0, 0, nil, ErrPayloadTooLarge
}
return method, int64(length), &io.LimitedReader {
R: reader,
N: int64(length),
}, nil
} }

View File

@ -9,9 +9,9 @@ import "testing"
func TestEncodeMessageB(test *testing.T) { func TestEncodeMessageB(test *testing.T) {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
err := encodeMessageB(buffer, 0x6B12, payload) err := encodeMessageB(buffer, defaultSizeLimit, 0x6B12, payload)
correct := []byte { correct := []byte {
0x6B, 0x12, 0x6B, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x06, 0x00, 0x06,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
} }
@ -26,24 +26,25 @@ func TestEncodeMessageB(test *testing.T) {
func TestEncodeMessageBErr(test *testing.T) { func TestEncodeMessageBErr(test *testing.T) {
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
payload := make([]byte, 0x10000) payload := make([]byte, 0x10000)
err := encodeMessageB(buffer, 0x6B12, payload) err := encodeMessageB(buffer, 255, 0x6B12, payload)
if !errors.Is(err, ErrPayloadTooLarge) { if !errors.Is(err, ErrPayloadTooLarge) {
test.Fatalf("wrong error: %v", err) test.Fatalf("wrong error: %v", err)
} }
} }
func TestDecodeMessageB(test *testing.T) { func TestDecodeMessageB(test *testing.T) {
method, payload, err := decodeMessageB(bytes.NewReader([]byte { method, _, data, err := decodeMessageB(bytes.NewReader([]byte {
0x6B, 0x12, 0x6B, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x06, 0x00, 0x06,
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
})) }), defaultSizeLimit)
if err != nil { if err != nil {
test.Fatal(err) test.Fatal(err)
} }
if got, correct := method, uint16(0x6B12); got != correct { if got, correct := method, uint16(0x6B12); got != correct {
test.Fatalf("not equal: %v %v", got, correct) test.Fatalf("not equal: %v %v", got, correct)
} }
payload, _ := io.ReadAll(data)
correctPayload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } correctPayload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 }
if got, correct := payload, correctPayload; !slices.Equal(got, correct) { if got, correct := payload, correctPayload; !slices.Equal(got, correct) {
test.Fatalf("not equal: %v %v", got, correct) test.Fatalf("not equal: %v %v", got, correct)
@ -51,11 +52,9 @@ func TestDecodeMessageB(test *testing.T) {
} }
func TestDecodeMessageBErr(test *testing.T) { func TestDecodeMessageBErr(test *testing.T) {
_, _, err := decodeMessageB(bytes.NewReader([]byte { _, _, _, err := decodeMessageB(bytes.NewReader([]byte {
0x6B, 0x12, 0x6B, 0x12, 0x00, 0x00, 0x00, 0x00,
0x01, 0x06, }), defaultSizeLimit)
0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
}))
if !errors.Is(err, io.ErrUnexpectedEOF) { if !errors.Is(err, io.ErrUnexpectedEOF) {
test.Fatalf("wrong error: %v", err) test.Fatalf("wrong error: %v", err)
} }

View File

@ -1,54 +0,0 @@
package hopp
import "net"
import "context"
import "github.com/quic-go/quic-go"
var _ MultiConn = quicMultiConn { }
type quicMultiConn struct {
underlying quic.Connection
}
func (conn quicMultiConn) Close() error {
return conn.underlying.CloseWithError(0, "good bye")
}
func (conn quicMultiConn) LocalAddr() net.Addr {
return conn.underlying.LocalAddr()
}
func (conn quicMultiConn) RemoteAddr() net.Addr {
return conn.underlying.RemoteAddr()
}
func (conn quicMultiConn) AcceptStream(ctx context.Context) (Stream, error) {
strea, err := conn.underlying.AcceptStream(ctx)
if err != nil { return nil, err }
return quicStream { underlying: strea }, nil
}
func (conn quicMultiConn) OpenStream() (Stream, error) {
strea, err := conn.underlying.OpenStream()
if err != nil { return nil, err }
return quicStream { underlying: strea }, nil
}
type quicStream struct {
underlying quic.Stream
}
func (strea quicStream) Read(buffer []byte) (n int, err error) {
return strea.underlying.Read(buffer)
}
func (strea quicStream) Write(buffer []byte) (n int, err error) {
return strea.underlying.Read(buffer)
}
func (strea quicStream) Close() error {
return strea.underlying.Close()
}
func (strea quicStream) ID() int64 {
return int64(strea.underlying.StreamID())
}

134
tape/decode.go Normal file
View File

@ -0,0 +1,134 @@
package tape
import "io"
import "math"
import "bufio"
// Decodable is any type that can decode itself from a decoder.
type Decodable interface {
// Decode reads data from decoder, replacing the data of the object. It
// returns the amount of bytes written, and an error if the write
// stopped early.
Decode(decoder *Decoder) (n int, err error)
}
// Decoder decodes data from an [io.Reader].
type Decoder struct {
bufio.Reader
}
// NewDecoder creates a new decoder that reads from reader.
func NewDecoder(reader io.Reader) *Decoder {
decoder := &Decoder { }
decoder.Reader.Reset(reader)
return decoder
}
// ReadFull calls [io.ReadFull] on the reader.
func (this *Decoder) ReadFull(buffer []byte) (n int, err error) {
return io.ReadFull(this, buffer)
}
// ReadInt8 decodes an 8-bit signed integer from the input reader.
func (this *Decoder) ReadInt8() (value int8, n int, err error) {
uncasted, n, err := this.ReadUint8()
return int8(uncasted), n, err
}
// ReadUint8 decodes an 8-bit unsigned integer from the input reader.
func (this *Decoder) ReadUint8() (value uint8, n int, err error) {
buffer := [1]byte { }
n, err = this.ReadFull(buffer[:])
return uint8(buffer[0]), n, err
}
// ReadInt16 decodes an 16-bit signed integer from the input reader.
func (this *Decoder) ReadInt16() (value int16, n int, err error) {
uncasted, n, err := this.ReadUint16()
return int16(uncasted), n, err
}
// ReadUint16 decodes an 16-bit unsigned integer from the input reader.
func (this *Decoder) ReadUint16() (value uint16, n int, err error) {
buffer := [2]byte { }
n, err = this.ReadFull(buffer[:])
return uint16(buffer[0]) << 8 |
uint16(buffer[1]), n, err
}
// ReadInt32 decodes an 32-bit signed integer from the input reader.
func (this *Decoder) ReadInt32() (value int32, n int, err error) {
uncasted, n, err := this.ReadUint32()
return int32(uncasted), n, err
}
// ReadUint32 decodes an 32-bit unsigned integer from the input reader.
func (this *Decoder) ReadUint32() (value uint32, n int, err error) {
buffer := [4]byte { }
n, err = this.ReadFull(buffer[:])
return uint32(buffer[0]) << 24 |
uint32(buffer[1]) << 16 |
uint32(buffer[2]) << 8 |
uint32(buffer[3]), n, err
}
// ReadInt64 decodes an 64-bit signed integer from the input reader.
func (this *Decoder) ReadInt64() (value int64, n int, err error) {
uncasted, n, err := this.ReadUint64()
return int64(uncasted), n, err
}
// ReadUint64 decodes an 64-bit unsigned integer from the input reader.
func (this *Decoder) ReadUint64() (value uint64, n int, err error) {
buffer := [8]byte { }
n, err = this.ReadFull(buffer[:])
return uint64(buffer[0]) << 56 |
uint64(buffer[1]) << 48 |
uint64(buffer[2]) << 40 |
uint64(buffer[3]) << 32 |
uint64(buffer[4]) << 24 |
uint64(buffer[5]) << 16 |
uint64(buffer[6]) << 8 |
uint64(buffer[7]), n, err
}
// ReadIntN decodes an N-byte signed integer from the input reader.
func (this *Decoder) ReadIntN(bytes int) (value int64, n int, err error) {
uncasted, n, err := this.ReadUintN(bytes)
return int64(uncasted), n, err
}
// ReadUintN decodes an N-byte unsigned integer from the input reader.
func (this *Decoder) ReadUintN(bytes int) (value uint64, n int, err error) {
// TODO: don't make multiple read calls (without allocating)
buffer := [1]byte { }
for bytesLeft := bytes; bytesLeft > 0; bytesLeft -- {
nn, err := this.ReadFull(buffer[:])
n += nn; if err != nil { return 0, n, err }
value |= uint64(buffer[0]) << ((bytesLeft - 1) * 8)
}
// *read* integers too big, but don't return them.
if bytes > 8 { value = 0 }
return value, n, nil
}
// ReadFloat32 decldes a 32-bit floating point value from the input reader.
func (this *Decoder) ReadFloat32() (value float32, n int, err error) {
bits, nn, err := this.ReadUint32()
n += nn; if err != nil { return 0, n, err }
return math.Float32frombits(bits), n, nil
}
// ReadFloat64 decldes a 64-bit floating point value from the input reader.
func (this *Decoder) ReadFloat64() (value float64, n int, err error) {
bits, nn, err := this.ReadUint64()
n += nn; if err != nil { return 0, n, err }
return math.Float64frombits(bits), n, nil
}
// ReadTag decodes a [Tag] from the input reader.
func (this *Decoder) ReadTag() (value Tag, n int, err error) {
uncasted, nn, err := this.ReadUint8()
n += nn; if err != nil { return 0, n, err }
return Tag(uncasted), n, nil
}

384
tape/dynamic.go Normal file
View File

@ -0,0 +1,384 @@
package tape
// dont smoke reflection, kids!!!!!!!!!
// totally reflectric, reflectrified, etc. this is probably souper slow but
// certainly no slower than the built in json encoder i'd imagine.
// TODO: add support for struct tags: `tape:"0000"`, tape:"0001"` so they can get
// transformed into tables with a defined schema
import "fmt"
import "reflect"
var dummyMap map[uint16] any
// EncodeAny encodes an "any" value. Returns an error if the underlying type is
// unsupported. Supported types are:
//
// - int
// - int<N>
// - uint
// - uint<N>
// - string
// - []<supported type>
// - map[uint16]<supported type>
func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
// TODO use reflection for all of this to ignore type names
// primitives
switch value := value.(type) {
case int: return encoder.WriteInt32(int32(value))
case uint: return encoder.WriteUint32(uint32(value))
case int8: return encoder.WriteInt8(value)
case uint8: return encoder.WriteUint8(value)
case int16: return encoder.WriteInt16(value)
case uint16: return encoder.WriteUint16(value)
case int32: return encoder.WriteInt32(value)
case uint32: return encoder.WriteUint32(value)
case int64: return encoder.WriteInt64(value)
case uint64: return encoder.WriteUint64(value)
case string: return EncodeAny(encoder, []byte(value), tag)
case []byte:
if tag.Is(LBA) {
nn, err := encoder.WriteUintN(uint64(len(value)), tag.CN() + 1)
n += nn; if err != nil { return n, err }
}
nn, err := encoder.Write(value)
n += nn; if err != nil { return n, err }
return n, nil
}
// aggregates
reflectType := reflect.TypeOf(value)
switch reflectType.Kind() {
case reflect.Slice:
return encodeAnySlice(encoder, value, tag)
// case reflect.Array:
// return encodeAnySlice(encoder, reflect.ValueOf(value).Slice(0, reflectType.Len()).Interface(), tag)
case reflect.Map:
if reflectType.Key() == reflect.TypeOf(uint16(0)) {
return encodeAnyMap(encoder, value, tag)
}
return n, fmt.Errorf("cannot encode map key %T, key must be uint16", value)
}
return n, fmt.Errorf("cannot encode type %T", value)
}
// DecodeAny decodes data and places it into destination, which must be a
// pointer to a supported type. See [EncodeAny] for a list of supported types.
func DecodeAny(decoder *Decoder, destination any, tag Tag) (n int, err error) {
reflectDestination := reflect.ValueOf(destination)
if reflectDestination.Kind() != reflect.Pointer {
return n, fmt.Errorf("expected pointer destination, not %v", destination)
}
return decodeAny(decoder, reflectDestination.Elem(), tag)
}
// unknownSlicePlaceholder is inserted by skeletonValue and informs the program
// that the destination for the slice needs to be generated based on the item
// tag in the OTA.
type unknownSlicePlaceholder struct { }
var unknownSlicePlaceholderType = reflect.TypeOf(unknownSlicePlaceholder { })
// decodeAny is internal to [DecodeAny]. It takes in an addressable
// [reflect.Value] as the destination.
func decodeAny(decoder *Decoder, destination reflect.Value, tag Tag) (n int, err error) {
errWrongDestinationType := func(expected string) error {
panic(fmt.Errorf(
// return fmt.Errorf(
"expected %s destination, not %v",
expected, destination))
}
switch tag.WithoutCN() {
case SI:
// SI: (none)
err = setInt(destination, uint64(tag.CN()))
if err != nil { return n, err }
case LI:
// LI: <value: IntN>
nn, err := decodeAndSetInt(decoder, destination, tag.CN() + 1)
n += nn; if err != nil { return n, err }
case FP:
// FP: <value: FloatN>
nn, err := decodeAndSetFloat(decoder, destination, tag.CN() + 1)
n += nn; if err != nil { return n, err }
case SBA:
// SBA: <data: U8>*
buffer := make([]byte, tag.CN())
nn, err := decoder.Read(buffer)
n += nn; if err != nil { return n, err }
err = setByteArray(destination, buffer)
if err != nil { return n, err }
case LBA:
// LBA: <length: UN> <data: U8>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
buffer := make([]byte, length)
nn, err = decoder.Read(buffer)
n += nn; if err != nil { return n, err }
err = setByteArray(destination, buffer)
if err != nil { return n, err }
case OTA:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
oneTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
if destination.Kind() != reflect.Slice {
return n, errWrongDestinationType("slice")
}
if destination.Cap() < int(length) {
destination.Grow(destination.Cap() - int(length))
}
destination.SetLen(int(length))
for index := range length {
nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag)
n += nn; if err != nil { return n, err }
}
case KTV:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
table := destination
if table.Type() != reflect.TypeOf(dummyMap) {
return n, errWrongDestinationType("map[uint16] any")
}
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
table.Clear()
for _ = range length {
key, nn, err := decoder.ReadUint16()
n += nn; if err != nil { return n, err }
itemTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
value, err := skeletonValue(decoder, itemTag)
if err != nil { return n, err }
nn, err = decodeAny(decoder, value.Elem(), itemTag)
n += nn; if err != nil { return n, err }
table.SetMapIndex(reflect.ValueOf(key), value)
}
default:
return n, fmt.Errorf("unknown TN %d", tag.TN())
}
return n, nil
}
// TagAny returns the correct tag for an "any" value. Returns an error if the
// underlying type is unsupported. See [EncodeAny] for a list of supported
// types.
func TagAny(value any) (Tag, error) {
// TODO use reflection for all of this to ignore type names
// primitives
switch value := value.(type) {
case int, uint: return LI.WithCN(3), nil
case int8, uint8: return LI.WithCN(0), nil
case int16, uint16: return LI.WithCN(1), nil
case int32, uint32: return LI.WithCN(3), nil
case int64, uint64: return LI.WithCN(7), nil
case string: return bufferLenTag(len(value)), nil
case []byte: return bufferLenTag(len(value)), nil
}
// aggregates
reflectType := reflect.TypeOf(value)
switch reflectType.Kind() {
case reflect.Slice: return OTA.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil
case reflect.Array: return OTA.WithCN(reflectType.Len()), nil
case reflect.Map:
if reflectType.Key() == reflect.TypeOf(uint16(0)) {
return KTV.WithCN(IntBytes(uint64(reflect.ValueOf(value).Len())) - 1), nil
}
return 0, fmt.Errorf("cannot encode map key %T, key must be uint16", value)
}
return 0, fmt.Errorf("cannot get tag of type %T", value)
}
func encodeAnySlice(encoder *Encoder, value any, tag Tag) (n int, err error) {
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
reflectValue := reflect.ValueOf(value)
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
n += nn; if err != nil { return n, err }
reflectType := reflect.TypeOf(value)
oneTag, err := TagAny(reflect.Zero(reflectType.Elem()).Interface())
if err != nil { return n, err }
for index := 0; index < reflectValue.Len(); index += 1 {
item := reflectValue.Index(index).Interface()
itemTag, err := TagAny(item)
if err != nil { return n, err }
if itemTag.CN() > oneTag.CN() { oneTag = itemTag }
}
if oneTag.Is(SBA) { oneTag += 1 << 5 }
nn, err = encoder.WriteUint8(uint8(oneTag))
n += nn; if err != nil { return n, err }
for index := 0; index < reflectValue.Len(); index += 1 {
item := reflectValue.Index(index).Interface()
nn, err = EncodeAny(encoder, item, oneTag)
n += nn; if err != nil { return n, err }
}
return n, err
}
func encodeAnyMap(encoder *Encoder, value any, tag Tag) (n int, err error) {
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
reflectValue := reflect.ValueOf(value)
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
n += nn; if err != nil { return n, err }
iter := reflectValue.MapRange()
for iter.Next() {
key := iter.Key().Interface().(uint16)
value := iter.Value().Interface()
nn, err = encoder.WriteUint16(key)
n += nn; if err != nil { return n, err }
itemTag, err := TagAny(value)
if err != nil { return n, err }
nn, err = encoder.WriteUint8(uint8(itemTag))
n += nn; if err != nil { return n, err }
nn, err = EncodeAny(encoder, value, itemTag)
n += nn; if err != nil { return n, err }
}
return n, nil
}
// setInt expects a settable destination.
func setInt(destination reflect.Value, value uint64) error {
switch {
case destination.CanInt():
destination.Set(reflect.ValueOf(int64(value)).Convert(destination.Type()))
case destination.CanUint():
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
default:
return fmt.Errorf("cannot assign integer to %T", destination.Interface())
}
return nil
}
// setInt expects a settable destination.
func setFloat(destination reflect.Value, value float64) error {
if !destination.CanFloat() {
return fmt.Errorf("cannot assign float to %T", destination.Interface())
}
destination.Set(reflect.ValueOf(value).Convert(destination.Type()))
return nil
}
// setByteArrayexpects a settable destination.
func setByteArray(destination reflect.Value, value []byte) error {
typ := destination.Type()
if typ.Kind() != reflect.Slice {
return fmt.Errorf("cannot assign %T to ", value)
}
if typ.Elem() != reflect.TypeOf(byte(0)) {
return fmt.Errorf("cannot convert %T to *[]byte", value)
}
destination.Set(reflect.ValueOf(value))
return nil
}
// decodeAndSetInt expects a settable destination.
func decodeAndSetInt(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
value, nn, err := decoder.ReadUintN(bytes)
n += nn; if err != nil { return n, err }
return n, setInt(destination, value)
}
// decodeAndSetInt expects a settable destination.
func decodeAndSetFloat(decoder *Decoder, destination reflect.Value, bytes int) (n int, err error) {
switch bytes {
case 8:
value, nn, err := decoder.ReadFloat64()
n += nn; if err != nil { return n, err }
return n, setFloat(destination, float64(value))
case 4:
value, nn, err := decoder.ReadFloat32()
n += nn; if err != nil { return n, err }
return n, setFloat(destination, float64(value))
}
return n, fmt.Errorf("cannot decode float%d", bytes * 8)
}
// skeletonValue returns a pointer value. In order for it to be set, it must be
// dereferenced using Elem().
func skeletonValue(decoder *Decoder, tag Tag) (reflect.Value, error) {
typ, err := typeOf(decoder, tag)
if err != nil { return reflect.Value { }, err }
return reflect.New(typ), nil
}
// typeOf returns the type of the current tag being decoded. It does not use up
// the decoder, it only peeks.
func typeOf(decoder *Decoder, tag Tag) (reflect.Type, error) {
switch tag.WithoutCN() {
case SI:
return reflect.TypeOf(uint8(0)), nil
case LI:
switch tag.CN() {
case 0: return reflect.TypeOf(uint8(0)), nil
case 1: return reflect.TypeOf(uint16(0)), nil
case 3: return reflect.TypeOf(uint32(0)), nil
case 7: return reflect.TypeOf(uint64(0)), nil
}
return nil, fmt.Errorf("unknown CN %d for LI", tag.CN())
case FP:
switch tag.CN() {
case 3: return reflect.TypeOf(float32(0)), nil
case 7: return reflect.TypeOf(float64(0)), nil
}
return nil, fmt.Errorf("unknown CN %d for FP", tag.CN())
case SBA: return reflect.SliceOf(reflect.TypeOf(byte(0))), nil
case LBA: return reflect.SliceOf(reflect.TypeOf(byte(0))), nil
case OTA:
elemTag, dimension, err := peekSlice(decoder, tag)
if err != nil { return nil, err }
if elemTag.Is(OTA) { panic("peekSlice cannot return OTA") }
typ, err := typeOf(decoder, elemTag)
if err != nil { return nil, err }
for _ = range dimension {
typ = reflect.SliceOf(typ)
}
return typ, nil
case KTV: return reflect.TypeOf(dummyMap), nil
}
return nil, fmt.Errorf("unknown TN %d", tag.TN())
}
// peekSlice returns the element tag and dimension count of the OTA currently
// being decoded. It does not use up the decoder, it only peeks.
func peekSlice(decoder *Decoder, tag Tag) (Tag, int, error) {
offset := 0
dimension := 0
for {
elem, populated, n, err := peekSliceOnce(decoder, tag, offset)
if err != nil { return 0, 0, err }
offset += n
dimension += 1
if elem.Is(OTA) {
if !populated {
return LBA, dimension + 1, nil
}
} else {
return elem, dimension, nil
}
}
}
// peekSliceOnce returns the element tag of the OTA located offset bytes ahead
// of the current position. It does not use up the decoder, it only peeks. The n
// return value denotes how far away from 0 it peeked. If the OTA has more than
// zero items, populated will be set to true.
func peekSliceOnce(decoder *Decoder, tag Tag, offset int) (elem Tag, populated bool, n int, err error) {
lengthStart := offset
lengthEnd := lengthStart + tag.CN() + 1
elemTagStart := lengthEnd
elemTagEnd := elemTagStart + 1
headerBytes, err := decoder.Peek(elemTagEnd)
if err != nil { return 0, false, 0, err }
elem = Tag(headerBytes[len(headerBytes)])
for index := lengthStart; index < lengthEnd; index += 1 {
if headerBytes[index] > 0 {
populated = true
break
}
}
n = elemTagEnd
return
}

136
tape/dynamic_test.go Normal file
View File

@ -0,0 +1,136 @@
package tape
import "fmt"
import "bytes"
import "testing"
import "reflect"
import tu "git.tebibyte.media/sashakoshka/hopp/internal/testutil"
func TestEncodeAnyInt(test *testing.T) {
err := testEncodeAny(test, uint8(0xCA), LI.WithCN(0), tu.S(0xCA))
if err != nil { test.Fatal(err) }
err = testEncodeAny(test, 400, LI.WithCN(3), tu.S(
0, 0, 0x1, 0x90,
))
if err != nil { test.Fatal(err) }
}
func TestEncodeAnyTable(test *testing.T) {
err := testEncodeAny(test, map[uint16] any {
0xF3B9: 1,
0x0102: 2,
0x0000: "hi!",
0xFFFF: []uint16 { 0xBEE5, 0x7777 },
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
}, KTV.WithCN(0), tu.S(5).AddVar(
[]byte {
0xF3, 0xB9,
byte(LI.WithCN(3)),
0, 0, 0, 1,
},
[]byte {
0x01, 0x02,
byte(LI.WithCN(3)),
0, 0, 0, 2,
},
[]byte {
0, 0,
byte(SBA.WithCN(3)),
'h', 'i', '!',
},
[]byte {
0xFF, 0xFF,
byte(OTA.WithCN(0)), 2, byte(LI.WithCN(1)),
0xBE, 0xE5, 0x77, 0x77,
},
[]byte {
0x12, 0x34,
byte(OTA.WithCN(0)), 2, byte(OTA.WithCN(0)),
1, byte(LI.WithCN(1)),
0, 0x5,
2, byte(LI.WithCN(1)),
0, 0x17,
0xAA, 0xAA,
},
))
if err != nil { test.Fatal(err) }
}
func TestEncodeDecodeAnyMap(test *testing.T) {
err := testEncodeDecodeAny(test, map[uint16] any {
0xF3B9: 1,
0x0102: 2,
0x0000: "hi!",
0xFFFF: []uint16 { 0xBEE5, 0x7777 },
0x1234: [][]uint16 { []uint16 { 0x5 }, []uint16 { 0x17, 0xAAAA} },
}, nil)
if err != nil { test.Fatal(err) }
}
func encAny(value any) ([]byte, Tag, int, error) {
tag, err := TagAny(value)
if err != nil { return nil, 0, 0, err }
buffer := bytes.Buffer { }
encoder := NewEncoder(&buffer)
n, err := EncodeAny(encoder, value, tag)
if err != nil { return nil, 0, n, err }
encoder.Flush()
return buffer.Bytes(), tag, n, nil
}
func decAny(data []byte) (Tag, any, int, error) {
destination := map[uint16] any { }
tag, err := TagAny(destination)
if err != nil { return 0, nil, 0, err }
n, err := DecodeAny(NewDecoder(bytes.NewBuffer(data)), &destination, tag)
if err != nil { return 0, nil, n, err }
return tag, destination, n, nil
}
func testEncodeAny(test *testing.T, value any, correctTag Tag, correctBytes tu.Snake) error {
bytes, tag, n, err := encAny(value)
if err != nil { return err }
test.Log("n: ", n)
test.Log("tag: ", tag)
test.Log("got: ", tu.HexBytes(bytes))
test.Log("correct:", correctBytes)
if tag != correctTag {
return fmt.Errorf("tag not equal")
}
if ok, n := correctBytes.Check(bytes); !ok {
return fmt.Errorf("bytes not equal: %d", n)
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
return nil
}
func testEncodeDecodeAny(test *testing.T, value, correctValue any) error {
if correctValue == nil {
correctValue = value
}
test.Log("encoding...")
bytes, tag, n, err := encAny(value)
if err != nil { return err }
test.Log("n: ", n)
test.Log("tag:", tag)
test.Log("got:", tu.HexBytes(bytes))
test.Log("decoding...", tag)
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
_, decoded, n, err := decAny(bytes)
if err != nil { return err }
test.Log("got: ", decoded)
test.Log("correct:", correctValue)
if !reflect.DeepEqual(decoded, correctValue) {
return fmt.Errorf("values not equal")
}
if n != len(bytes) {
return fmt.Errorf("n not equal: %d != %d", n, len(bytes))
}
return nil
}

118
tape/encode.go Normal file
View File

@ -0,0 +1,118 @@
package tape
import "io"
import "math"
import "bufio"
// Encodable is any type that can write itself to an encoder.
type Encodable interface {
// Encode sends data to encoder. It returns the amount of bytes written,
// and an error if the write stopped early.
Encode(encoder *Encoder) (n int, err error)
}
// Encoder encodes data to an io.Writer.
type Encoder struct {
bufio.Writer
}
// NewEncoder creates a new encoder that writes to writer.
func NewEncoder(writer io.Writer) *Encoder {
encoder := &Encoder { }
encoder.Reset(writer)
return encoder
}
// WriteInt8 encodes an 8-bit signed integer to the output writer.
func (this *Encoder) WriteInt8(value int8) (n int, err error) {
return this.WriteUint8(uint8(value))
}
// WriteUint8 encodes an 8-bit unsigned integer to the output writer.
func (this *Encoder) WriteUint8(value uint8) (n int, err error) {
return this.Write([]byte { byte(value) })
}
// WriteInt16 encodes an 16-bit signed integer to the output writer.
func (this *Encoder) WriteInt16(value int16) (n int, err error) {
return this.WriteUint16(uint16(value))
}
// WriteUint16 encodes an 16-bit unsigned integer to the output writer.
func (this *Encoder) WriteUint16(value uint16) (n int, err error) {
return this.Write([]byte {
byte(value >> 8),
byte(value),
})
}
// WriteInt32 encodes an 32-bit signed integer to the output writer.
func (this *Encoder) WriteInt32(value int32) (n int, err error) {
return this.WriteUint32(uint32(value))
}
// WriteUint32 encodes an 32-bit unsigned integer to the output writer.
func (this *Encoder) WriteUint32(value uint32) (n int, err error) {
return this.Write([]byte {
byte(value >> 24),
byte(value >> 16),
byte(value >> 8),
byte(value),
})
}
// WriteInt64 encodes an 64-bit signed integer to the output writer.
func (this *Encoder) WriteInt64(value int64) (n int, err error) {
return this.WriteUint64(uint64(value))
}
// WriteUint64 encodes an 64-bit unsigned integer to the output writer.
func (this *Encoder) WriteUint64(value uint64) (n int, err error) {
return this.Write([]byte {
byte(value >> 56),
byte(value >> 48),
byte(value >> 40),
byte(value >> 32),
byte(value >> 24),
byte(value >> 16),
byte(value >> 8),
byte(value),
})
}
// WriteIntN encodes an N-byte signed integer to the output writer.
func (this *Encoder) WriteIntN(value int64, bytes int) (n int, err error) {
return this.WriteUintN(uint64(value), bytes)
}
// for Write/ReadUintN, increase buffers if go somehow gets support for over 64
// bit integers. we could also make an expanding int type in goutil to use here,
// or maybe there is one in the stdlib. keep the int64 versions as well though
// because its ergonomic.
// WriteUintN encodes an N-byte unsigned integer to the output writer.
func (this *Encoder) WriteUintN(value uint64, bytes int) (n int, err error) {
// TODO: don't make multiple write calls (without allocating)
buffer := [1]byte { }
for bytesLeft := bytes; bytesLeft > 0; bytesLeft -- {
buffer[0] = byte(value) >> ((bytesLeft - 1) * 8)
nn, err := this.Write(buffer[:])
n += nn; if err != nil { return n, err }
}
return n, nil
}
// WriteFloat32 encodes a 32-bit floating point value to the output writer.
func (this *Encoder) WriteFloat32(value float32) (n int, err error) {
return this.WriteUint32(math.Float32bits(value))
}
// WriteFloat64 encodes a 64-bit floating point value to the output writer.
func (this *Encoder) WriteFloat64(value float64) (n int, err error) {
return this.WriteUint64(math.Float64bits(value))
}
// WriteTag encodes a [Tag] to the output writer.
func (this *Encoder) WriteTag(value Tag) (n int, err error) {
return this.WriteUint8(uint8(value))
}

12
tape/measure.go Normal file
View File

@ -0,0 +1,12 @@
package tape
// IntBytes returns the number of bytes required to hold a given unsigned
// integer.
func IntBytes(value uint64) int {
bytes := 0
for value > 0 || bytes == 0 {
value >>= 8;
bytes ++
}
return bytes
}

21
tape/measure_test.go Normal file
View File

@ -0,0 +1,21 @@
package tape
import "testing"
func TestIntBytes(test *testing.T) {
if correct, got := 1, IntBytes(0); correct != got {
test.Fatal("wrong:", got)
}
if correct, got := 1, IntBytes(1); correct != got {
test.Fatal("wrong:", got)
}
if correct, got := 1, IntBytes(16); correct != got {
test.Fatal("wrong:", got)
}
if correct, got := 1, IntBytes(255); correct != got {
test.Fatal("wrong:", got)
}
if correct, got := 2, IntBytes(256); correct != got {
test.Fatal("wrong:", got)
}
}

View File

@ -1,83 +0,0 @@
package tape
import "iter"
// DecodePairs decodes message tag/value pairs from a byte slice. It returns an
// iterator over all pairs, where the first value is the tag and the second is
// the value. If data yielded by the iterator is retained, it must be copied
// first.
func DecodePairs(data []byte) (iter.Seq2[uint16, []byte], error) {
// determine section bounds
if len(data) < 2 { return nil, ErrDataTooLarge }
length16, _ := DecodeI16[uint16](data[0:2])
data = data[2:]
length := int(length16)
headerSize := length * 4
if len(data) < headerSize { return nil, ErrDataTooLarge }
valuesData := data[headerSize:]
// ensure the value buffer is big enough
var valuesSize int
for index := range length {
offset := index * 4
end, _ := DecodeI16[uint16](data[offset + 2:offset + 4])
valuesSize = int(end)
}
if valuesSize > len(valuesData) {
return nil, ErrDataTooLarge
}
// return iterator
return func(yield func(uint16, []byte) bool) {
start := uint16(0)
for index := range length {
offset := index * 4
key , _ := DecodeI16[uint16](data[offset + 0:offset + 2])
end, _ := DecodeI16[uint16](data[offset + 2:offset + 4])
// if nextValuesOffset < len(valuesData) {
if !yield(key, valuesData[start:end]) {
return
}
// } else {
// if !yield(key, nil) {
// return
// }
// }
start = end
}
}, nil
}
// EncodePairs encodes message tag/value pairs into a byte slice.
func EncodePairs(pairs map[uint16] []byte) ([]byte, error) {
// determine section bounds
headerSize := 2 + len(pairs) * 4
valuesSize := 0
for _, value := range pairs {
valuesSize += len(value)
}
// generate data
buffer := make([]byte, headerSize + valuesSize)
length16, ok := U16CastSafe(len(pairs))
if !ok { return nil, ErrDataTooLarge }
EncodeI16[uint16](buffer[0:2], length16)
index := 0
end := headerSize
for key, value := range pairs {
start := end
end += len(value)
tagOffset := 2 + index * 4
end16, ok := U16CastSafe(end - headerSize)
if !ok { return nil, ErrDataTooLarge }
// write tag and length
EncodeI16[uint16](buffer[tagOffset + 0:tagOffset + 2], key)
EncodeI16[uint16](buffer[tagOffset + 2:tagOffset + 4], end16)
// write value
copy(buffer[start:end], value)
index ++
}
return buffer, nil
}

View File

@ -1,62 +0,0 @@
package tape
import "slices"
import "testing"
func TestDecodePairs(test *testing.T) {
pairs := map[uint16] []byte {
3894: []byte("foo"),
7: []byte("br"),
}
got, err := DecodePairs([]byte {
0, 2,
0, 7, 0, 2,
15, 54, 0, 5,
98, 114,
102, 111, 111})
if err != nil { test.Fatal(err) }
length := 0
for key, value := range got {
test.Log(key, value)
if !slices.Equal(pairs[key], value) { test.Fatal("not equal") }
length ++
}
test.Log("length")
if length != len(pairs) { test.Fatal("wrong length") }
}
func TestEncodePairs(test *testing.T) {
pairs := map[uint16] []byte {
3894: []byte("foo"),
7: []byte("br"),
}
got, err := EncodePairs(pairs)
if err != nil { test.Fatal(err) }
test.Log(got)
valid := slices.Equal(got, []byte {
0, 2,
15, 54, 0, 3,
0, 7, 0, 5,
102, 111, 111,
98, 114}) ||
slices.Equal(got, []byte {
0, 2,
0, 7, 0, 2,
15, 54, 0, 5,
98, 114,
102, 111, 111})
if !valid { test.Fatal("not equal") }
}
func FuzzDecodePairs(fuzz *testing.F) {
fuzz.Add([]byte {
0, 2,
0, 7, 0, 2,
15, 54, 0, 5,
98, 114,
102, 111, 111})
fuzz.Fuzz(func(t *testing.T, buffer []byte) {
// ensure it does not panic :P
DecodePairs(buffer)
})
}

63
tape/tag.go Normal file
View File

@ -0,0 +1,63 @@
package tape
import "fmt"
type Tag byte; const (
SI Tag = 0 << 5 // Small integer
LI Tag = 1 << 5 // Large integer
FP Tag = 2 << 5 // Floating point
SBA Tag = 3 << 5 // Small byte array
LBA Tag = 4 << 5 // Large byte array
OTA Tag = 5 << 5 // One-tag array
KTV Tag = 6 << 5 // Key-tag-value table
TNMask Tag = 0xE0 // The entire TN bitfield
CNMask Tag = 0x1F // The entire CN bitfield
CNLimit Tag = 32 // All valid CNs are < CNLimit
)
func (tag Tag) TN() int {
return int(tag >> 5)
}
func (tag Tag) CN() int {
return int(tag & CNMask)
}
func (tag Tag) WithCN(cn int) Tag {
return (tag & TNMask) | Tag(cn % 32)
}
func (tag Tag) WithoutCN() Tag {
return tag.WithCN(0)
}
func (tag Tag) Is(other Tag) bool {
return tag.TN() == other.TN()
}
func (tag Tag) String() string {
tn := fmt.Sprint(tag.TN())
switch tag.WithoutCN() {
case SI: tn = "SI"
case LI: tn = "LI"
case FP: tn = "FP"
case SBA: tn = "SBA"
case LBA: tn = "LBA"
case OTA: tn = "OTA"
case KTV: tn = "KTV"
}
return fmt.Sprintf("%s:%d", tn, tag.CN())
}
// BufferTag returns the appropriate tag for a buffer.
func BufferTag(value []byte) Tag {
return bufferLenTag(len(value))
}
func bufferLenTag(length int) Tag {
if length < int(CNLimit) {
return SBA.WithCN(length)
} else {
return LBA.WithCN(IntBytes(uint64(length)))
}
}

View File

@ -1,311 +0,0 @@
// Package tape implements Table Pair Encoding.
package tape
import "fmt"
const dataMaxSize = 0xFFFF
const uint16Max = 0xFFFF
// Error enumerates common errors in this package.
type Error string; const (
ErrWrongBufferLength Error = "wrong buffer length"
ErrDataTooLarge Error = "data too large"
)
// Error implements the error interface.
func (err Error) Error() string {
return string(err)
}
// Int8 is any 8-bit integer.
type Int8 interface { ~uint8 | ~int8 }
// Int16 is any 16-bit integer.
type Int16 interface { ~uint16 | ~int16 }
// Int32 is any 32-bit integer.
type Int32 interface { ~uint32 | ~int32 }
// Int64 is any 64-bit integer.
type Int64 interface { ~uint64 | ~int64 }
// String is any string.
type String interface { ~string }
// DecodeI8 decodes an 8 bit integer from the given data.
func DecodeI8[T Int8](data []byte) (T, error) {
if len(data) != 1 { return 0, fmt.Errorf("decoding int8: %w", ErrWrongBufferLength) }
return T(data[0]), nil
}
// EncodeI8 encodes an 8 bit integer into the given buffer.
func EncodeI8[T Int8](buffer []byte, value T) error {
if len(buffer) != 1 { return fmt.Errorf("encoding int8: %w", ErrWrongBufferLength) }
buffer[0] = byte(value)
return nil
}
// DecodeI16 decodes a 16 bit integer from the given data.
func DecodeI16[T Int16](data []byte) (T, error) {
if len(data) != 2 { return 0, fmt.Errorf("decoding int16: %w", ErrWrongBufferLength) }
return T(data[0]) << 8 | T(data[1]), nil
}
// EncodeI16 encodes a 16 bit integer into the given buffer.
func EncodeI16[T Int16](buffer []byte, value T) error {
if len(buffer) != 2 { return fmt.Errorf("encoding int16: %w", ErrWrongBufferLength) }
buffer[0] = byte(value >> 8)
buffer[1] = byte(value)
return nil
}
// DecodeI32 decodes a 32 bit integer from the given data.
func DecodeI32[T Int32](data []byte) (T, error) {
if len(data) != 4 { return 0, fmt.Errorf("decoding int32: %w", ErrWrongBufferLength) }
return T(data[0]) << 24 |
T(data[1]) << 16 |
T(data[2]) << 8 |
T(data[3]), nil
}
// EncodeI32 encodes a 32 bit integer into the given buffer.
func EncodeI32[T Int32](buffer []byte, value T) error {
if len(buffer) != 4 { return fmt.Errorf("encoding int32: %w", ErrWrongBufferLength) }
buffer[0] = byte(value >> 24)
buffer[1] = byte(value >> 16)
buffer[2] = byte(value >> 8)
buffer[3] = byte(value)
return nil
}
// DecodeI64 decodes a 64 bit integer from the given data.
func DecodeI64[T Int64](data []byte) (T, error) {
if len(data) != 8 { return 0, fmt.Errorf("decoding int64: %w", ErrWrongBufferLength) }
return T(data[0]) << 56 |
T(data[1]) << 48 |
T(data[2]) << 40 |
T(data[3]) << 32 |
T(data[4]) << 24 |
T(data[5]) << 16 |
T(data[6]) << 8 |
T(data[7]), nil
}
// EncodeI64 encodes a 64 bit integer into the given buffer.
func EncodeI64[T Int64](buffer []byte, value T) error {
if len(buffer) != 8 { return fmt.Errorf("encoding int64: %w", ErrWrongBufferLength) }
buffer[0] = byte(value >> 56)
buffer[1] = byte(value >> 48)
buffer[2] = byte(value >> 40)
buffer[3] = byte(value >> 32)
buffer[4] = byte(value >> 24)
buffer[5] = byte(value >> 16)
buffer[6] = byte(value >> 8)
buffer[7] = byte(value)
return nil
}
// DecodeString decodes a string from the given data.
func DecodeString[T String](data []byte) (T, error) {
return T(data), nil
}
// EncodeString encodes a string into the given buffer.
func EncodeString[T String](data []byte, value T) error {
if len(data) != len(value) { return fmt.Errorf("encoding string: %w", ErrWrongBufferLength) }
copy(data, value)
return nil
}
// StringSize returns the size of a string. Returns 0 and an error if the size
// is too large.
func StringSize[T String](value T) (int, error) {
if len(value) > dataMaxSize { return 0, ErrDataTooLarge }
return len(value), nil
}
// DecodeStringArray decodes a packed string array from the given data.
func DecodeStringArray[T String](data []byte) ([]T, error) {
result := []T { }
for len(data) > 0 {
if len(data) < 2 { return nil, fmt.Errorf("decoding []string: %w", ErrWrongBufferLength) }
itemSize16, _ := DecodeI16[uint16](data[:2])
itemSize := int(itemSize16)
data = data[2:]
if len(data) < itemSize { return nil, fmt.Errorf("decoding []string: %w", ErrWrongBufferLength) }
result = append(result, T(data[:itemSize]))
data = data[itemSize:]
}
return result, nil
}
// EncodeStringArray encodes a packed string array into the given buffer.
func EncodeStringArray[T String](buffer []byte, value []T) error {
for _, item := range value {
length, err := StringSize(item)
if err != nil { return err }
if len(buffer) < 2 + length { return fmt.Errorf("encoding []string: %w", ErrWrongBufferLength) }
EncodeI16(buffer[:2], uint16(length))
buffer = buffer[2:]
copy(buffer, item)
buffer = buffer[length:]
}
if len(buffer) > 0 { return fmt.Errorf("encoding []string: %w", ErrWrongBufferLength) }
return nil
}
// StringArraySize returns the size of a packed string array. Returns 0 and an
// error if the size is too large.
func StringArraySize[T String](value []T) (int, error) {
total := 0
for _, item := range value {
total += 2 + len(item)
}
if total > dataMaxSize { return 0, ErrDataTooLarge }
return total, nil
}
// DecodeI8Array decodes a packed array of 8 bit integers from the given data.
func DecodeI8Array[T Int8](data []byte) ([]T, error) {
result := make([]T, len(data))
for index, item := range data {
result[index] = T(item)
}
return result, nil
}
// EncodeI8Array encodes a packed array of 8 bit integers into the given buffer.
func EncodeI8Array[T Int8](buffer []byte, value []T) error {
if len(buffer) != len(value) { return fmt.Errorf("encoding []int8: %w", ErrWrongBufferLength) }
for index, item := range value {
buffer[index] = byte(item)
}
return nil
}
// I8ArraySize returns the size of a packed 8 bit integer array. Returns 0 and
// an error if the size is too large.
func I8ArraySize[T Int8](value []T) (int, error) {
total := len(value)
if total > dataMaxSize { return 0, ErrDataTooLarge }
return total, nil
}
// DecodeI16Array decodes a packed array of 16 bit integers from the given data.
func DecodeI16Array[T Int16](data []byte) ([]T, error) {
if len(data) % 2 != 0 { return nil, fmt.Errorf("decoding []int16: %w", ErrWrongBufferLength) }
length := len(data) / 2
result := make([]T, length)
for index := range length {
offset := index * 2
result[index] = T(data[offset]) << 8 | T(data[offset + 1])
}
return result, nil
}
// EncodeI16Array encodes a packed array of 16 bit integers into the given buffer.
func EncodeI16Array[T Int16](buffer []byte, value []T) error {
if len(buffer) != len(value) * 2 { return fmt.Errorf("encoding []int16: %w", ErrWrongBufferLength) }
for _, item := range value {
buffer[0] = byte(item >> 8)
buffer[1] = byte(item)
buffer = buffer[2:]
}
return nil
}
// I16ArraySize returns the size of a packed 16 bit integer array. Returns 0 and
// an error if the size is too large.
func I16ArraySize[T Int16](value []T) (int, error) {
total := len(value) * 2
if total > dataMaxSize { return 0, ErrDataTooLarge }
return total, nil
}
// DecodeI32Array decodes a packed array of 32 bit integers from the given data.
func DecodeI32Array[T Int32](data []byte) ([]T, error) {
if len(data) % 4 != 0 { return nil, fmt.Errorf("decoding []int32: %w", ErrWrongBufferLength) }
length := len(data) / 4
result := make([]T, length)
for index := range length {
offset := index * 4
result[index] =
T(data[offset + 0]) << 24 |
T(data[offset + 1]) << 16 |
T(data[offset + 2]) << 8 |
T(data[offset + 3])
}
return result, nil
}
// EncodeI32Array encodes a packed array of 32 bit integers into the given buffer.
func EncodeI32Array[T Int32](buffer []byte, value []T) error {
if len(buffer) != len(value) * 4 { return fmt.Errorf("encoding []int32: %w", ErrWrongBufferLength) }
for _, item := range value {
buffer[0] = byte(item >> 24)
buffer[1] = byte(item >> 16)
buffer[2] = byte(item >> 8)
buffer[3] = byte(item)
buffer = buffer[4:]
}
return nil
}
// I32ArraySize returns the size of a packed 32 bit integer array. Returns 0 and
// an error if the size is too large.
func I32ArraySize[T Int32](value []T) (int, error) {
total := len(value) * 4
if total > dataMaxSize { return 0, ErrDataTooLarge }
return total, nil
}
// DecodeI64Array decodes a packed array of 32 bit integers from the given data.
func DecodeI64Array[T Int64](data []byte) ([]T, error) {
if len(data) % 8 != 0 { return nil, fmt.Errorf("decoding []int64: %w", ErrWrongBufferLength) }
length := len(data) / 8
result := make([]T, length)
for index := range length {
offset := index * 8
result[index] =
T(data[offset + 0]) << 56 |
T(data[offset + 1]) << 48 |
T(data[offset + 2]) << 40 |
T(data[offset + 3]) << 32 |
T(data[offset + 4]) << 24 |
T(data[offset + 5]) << 16 |
T(data[offset + 6]) << 8 |
T(data[offset + 7])
}
return result, nil
}
// EncodeI64Array encodes a packed array of 64 bit integers into the given buffer.
func EncodeI64Array[T Int64](buffer []byte, value []T) error {
if len(buffer) != len(value) * 8 { return fmt.Errorf("encoding []int64: %w", ErrWrongBufferLength) }
for _, item := range value {
buffer[0] = byte(item >> 56)
buffer[1] = byte(item >> 48)
buffer[2] = byte(item >> 40)
buffer[3] = byte(item >> 32)
buffer[4] = byte(item >> 24)
buffer[5] = byte(item >> 16)
buffer[6] = byte(item >> 8)
buffer[7] = byte(item)
buffer = buffer[8:]
}
return nil
}
// I64ArraySize returns the size of a packed 64 bit integer array. Returns 0 and
// an error if the size is too large.
func I64ArraySize[T Int64](value []T) (int, error) {
total := len(value) * 8
if total > dataMaxSize { return 0, ErrDataTooLarge }
return total, nil
}
// U16CastSafe safely casts an integer to a uint16. If an overflow or underflow
// occurs, it will return (0, false).
func U16CastSafe(n int) (uint16, bool) {
if n < uint16Max && n >= 0 {
return uint16(n), true
} else {
return 0, false
}
}

View File

@ -1,292 +0,0 @@
package tape
import "slices"
import "errors"
import "testing"
import "math/rand"
const largeNumberNTestRounds = 2048
const randStringBytes = "-abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func TestI8(test *testing.T) {
var buffer [16]byte
err := EncodeI8[uint8](buffer[:], 5)
if err.Error() != "encoding int8: wrong buffer length" { test.Fatal(err) }
err = EncodeI8[uint8](buffer[:0], 5)
if err.Error() != "encoding int8: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI8[uint8](buffer[:])
if err.Error() != "decoding int8: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI8[uint8](buffer[:0])
if err.Error() != "decoding int8: wrong buffer length" { test.Fatal(err) }
for number := range uint8(255) {
err := EncodeI8[uint8](buffer[:1], number)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI8[uint8](buffer[:1])
if err != nil { test.Fatal(err) }
if decoded != number {
test.Fatalf("%d != %d", decoded, number)
}
}
}
func TestI16(test *testing.T) {
var buffer [16]byte
err := EncodeI16[uint16](buffer[:], 5)
if err.Error() != "encoding int16: wrong buffer length" { test.Fatal(err) }
err = EncodeI16[uint16](buffer[:0], 5)
if err.Error() != "encoding int16: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI16[uint16](buffer[:])
if err.Error() != "decoding int16: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI16[uint16](buffer[:0])
if err.Error() != "decoding int16: wrong buffer length" { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
number := uint16(rand.Int())
err := EncodeI16[uint16](buffer[:2], number)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI16[uint16](buffer[:2])
if err != nil { test.Fatal(err) }
if decoded != number {
test.Fatalf("%d != %d", decoded, number)
}
}
}
func TestI32(test *testing.T) {
var buffer [16]byte
err := EncodeI32[uint32](buffer[:], 5)
if err.Error() != "encoding int32: wrong buffer length" { test.Fatal(err) }
err = EncodeI32[uint32](buffer[:0], 5)
if err.Error() != "encoding int32: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI32[uint32](buffer[:])
if err.Error() != "decoding int32: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI32[uint32](buffer[:0])
if err.Error() != "decoding int32: wrong buffer length" { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
number := uint32(rand.Int())
err := EncodeI32[uint32](buffer[:4], number)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI32[uint32](buffer[:4])
if err != nil { test.Fatal(err) }
if decoded != number {
test.Fatalf("%d != %d", decoded, number)
}
}
}
func TestI64(test *testing.T) {
var buffer [16]byte
err := EncodeI64[uint64](buffer[:], 5)
if err.Error() != "encoding int64: wrong buffer length" { test.Fatal(err) }
err = EncodeI64[uint64](buffer[:0], 5)
if err.Error() != "encoding int64: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI64[uint64](buffer[:])
if err.Error() != "decoding int64: wrong buffer length" { test.Fatal(err) }
_, err = DecodeI64[uint64](buffer[:0])
if err.Error() != "decoding int64: wrong buffer length" { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
number := uint64(rand.Int())
err := EncodeI64[uint64](buffer[:8], number)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI64[uint64](buffer[:8])
if err != nil { test.Fatal(err) }
if decoded != number {
test.Fatalf("%d != %d", decoded, number)
}
}
}
func TestString(test *testing.T) {
var buffer [16]byte
err := EncodeString[string](buffer[:], "hello")
if !errIs(err, ErrWrongBufferLength, "encoding string: wrong buffer length") { test.Fatal(err) }
err = EncodeString[string](buffer[:0], "hello")
if !errIs(err, ErrWrongBufferLength, "encoding string: wrong buffer length") { test.Fatal(err) }
_, err = DecodeString[string](buffer[:])
if err != nil { test.Fatal(err) }
_, err = DecodeString[string](buffer[:0])
if err != nil { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
length := rand.Intn(16)
str := randString(length)
err := EncodeString[string](buffer[:length], str)
if err != nil { test.Fatal(err) }
decoded, err := DecodeString[string](buffer[:length])
if err != nil { test.Fatal(err) }
if decoded != str {
test.Fatalf("%s != %s", decoded, str)
}
}
}
func TestI8Array(test *testing.T) {
var buffer [64]byte
err := EncodeI8Array[uint8](buffer[:], []uint8 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int8: wrong buffer length") { test.Fatal(err) }
err = EncodeI8Array[uint8](buffer[:0], []uint8 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int8: wrong buffer length") { test.Fatal(err) }
_, err = DecodeI8Array[uint8](buffer[:])
if err != nil { test.Fatal(err) }
_, err = DecodeI8Array[uint8](buffer[:0])
if err != nil { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
array := randInts[uint8](rand.Intn(16))
length, _ := I8ArraySize(array)
if length != len(array) { test.Fatalf("%d != %d", length, len(array)) }
err := EncodeI8Array[uint8](buffer[:length], array)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI8Array[uint8](buffer[:length])
if err != nil { test.Fatal(err) }
if !slices.Equal(decoded, array) {
test.Fatalf("%v != %v", decoded, array)
}
}
}
func TestI16Array(test *testing.T) {
var buffer [128]byte
err := EncodeI16Array[uint16](buffer[:], []uint16 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int16: wrong buffer length") { test.Fatal(err) }
err = EncodeI16Array[uint16](buffer[:0], []uint16 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int16: wrong buffer length") { test.Fatal(err) }
_, err = DecodeI16Array[uint16](buffer[:])
if err != nil { test.Fatal(err) }
_, err = DecodeI16Array[uint16](buffer[:0])
if err != nil { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
array := randInts[uint16](rand.Intn(16))
length, _ := I16ArraySize(array)
if length != 2 * len(array) { test.Fatalf("%d != %d", length, 2 * len(array)) }
err := EncodeI16Array[uint16](buffer[:length], array)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI16Array[uint16](buffer[:length])
if err != nil { test.Fatal(err) }
if !slices.Equal(decoded, array) {
test.Fatalf("%v != %v", decoded, array)
}
}
}
func TestI32Array(test *testing.T) {
var buffer [256]byte
err := EncodeI32Array[uint32](buffer[:], []uint32 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int32: wrong buffer length") { test.Fatal(err) }
err = EncodeI32Array[uint32](buffer[:0], []uint32 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int32: wrong buffer length") { test.Fatal(err) }
_, err = DecodeI32Array[uint32](buffer[:])
if err != nil { test.Fatal(err) }
_, err = DecodeI32Array[uint32](buffer[:0])
if err != nil { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
array := randInts[uint32](rand.Intn(16))
length, _ := I32ArraySize(array)
if length != 4 * len(array) { test.Fatalf("%d != %d", length, 4 * len(array)) }
err := EncodeI32Array[uint32](buffer[:length], array)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI32Array[uint32](buffer[:length])
if err != nil { test.Fatal(err) }
if !slices.Equal(decoded, array) {
test.Fatalf("%v != %v", decoded, array)
}
}
}
func TestI64Array(test *testing.T) {
var buffer [512]byte
err := EncodeI64Array[uint64](buffer[:], []uint64 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int64: wrong buffer length") { test.Fatal(err) }
err = EncodeI64Array[uint64](buffer[:0], []uint64 { 0, 4, 50, 19 })
if !errIs(err, ErrWrongBufferLength, "encoding []int64: wrong buffer length") { test.Fatal(err) }
_, err = DecodeI64Array[uint64](buffer[:])
if err != nil { test.Fatal(err) }
_, err = DecodeI64Array[uint64](buffer[:0])
if err != nil { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
array := randInts[uint64](rand.Intn(16))
length, _ := I64ArraySize(array)
if length != 8 * len(array) { test.Fatalf("%d != %d", length, 8 * len(array)) }
err := EncodeI64Array[uint64](buffer[:length], array)
if err != nil { test.Fatal(err) }
decoded, err := DecodeI64Array[uint64](buffer[:length])
if err != nil { test.Fatal(err) }
if !slices.Equal(decoded, array) {
test.Fatalf("%v != %v", decoded, array)
}
}
}
func TestStringArray(test *testing.T) {
var buffer [8192]byte
err := EncodeStringArray[string](buffer[:], []string { "0", "4", "50", "19" })
if !errIs(err, ErrWrongBufferLength, "encoding []string: wrong buffer length") { test.Fatal(err) }
err = EncodeStringArray[string](buffer[:0], []string { "0", "4", "50", "19" })
if !errIs(err, ErrWrongBufferLength, "encoding []string: wrong buffer length") { test.Fatal(err) }
_, err = DecodeStringArray[string](buffer[:0])
if err != nil { test.Fatal(err) }
for _ = range largeNumberNTestRounds {
array := randStrings[string](rand.Intn(16), 16)
length, _ := StringArraySize(array)
// TODO test length
err := EncodeStringArray[string](buffer[:length], array)
if err != nil { test.Fatal(err) }
decoded, err := DecodeStringArray[string](buffer[:length])
if err != nil { test.Fatal(err) }
if !slices.Equal(decoded, array) {
test.Fatalf("%v != %v", decoded, array)
}
}
}
func TestU16CastSafe(test *testing.T) {
number, ok := U16CastSafe(90_000)
if ok { test.Fatalf("false positive: %v, %v", number, ok) }
number, ok = U16CastSafe(-478)
if ok { test.Fatalf("false positive: %v, %v", number, ok) }
number, ok = U16CastSafe(3870)
if !ok { test.Fatalf("false negative: %v, %v", number, ok) }
if got, correct := number, uint16(3870); got != correct {
test.Fatalf("not equal: %v %v", got, correct)
}
number, ok = U16CastSafe(0)
if !ok { test.Fatalf("false negative: %v, %v", number, ok) }
if got, correct := number, uint16(0); got != correct {
test.Fatalf("not equal: %v %v", got, correct)
}
}
func randString(length int) string {
buffer := make([]byte, length)
for index := range buffer {
buffer[index] = randStringBytes[rand.Intn(len(randStringBytes))]
}
return string(buffer)
}
func randInts[T interface { ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 }] (length int) []T {
buffer := make([]T, length)
for index := range buffer {
buffer[index] = T(rand.Int())
}
return buffer
}
func randStrings[T interface { ~string }] (length, maxItemLength int) []T {
buffer := make([]T, length)
for index := range buffer {
buffer[index] = T(randString(rand.Intn(maxItemLength)))
}
return buffer
}
func errIs(err error, wraps error, description string) bool {
return err != nil && (wraps == nil || errors.Is(err, wraps)) && err.Error() == description
}