Compare commits
8 Commits
80c7d25c73
...
message-si
| Author | SHA1 | Date | |
|---|---|---|---|
| 12fbfa6293 | |||
| 44fb561758 | |||
| 04c352fad6 | |||
| 0ea7e222cc | |||
| ae79a32309 | |||
| e28ab4dc6b | |||
| 80161b37f7 | |||
| 9d40b81e00 |
@@ -2,7 +2,7 @@ package hopp
|
||||
|
||||
import "io"
|
||||
import "net"
|
||||
// import "time"
|
||||
import "time"
|
||||
|
||||
const defaultSizeLimit int64 = 1024 * 1024 // 1 megabyte
|
||||
|
||||
@@ -23,6 +23,9 @@ type Conn interface {
|
||||
// be called in a loop to avoid the connection locking up.
|
||||
AcceptTrans() (Trans, error)
|
||||
|
||||
// SetDeadline operates is [net.Conn.SetDeadline] but for OpenTrans
|
||||
// and AcceptTrans calls.
|
||||
SetDeadline(t time.Time) error
|
||||
// SetSizeLimit sets a limit (in bytes) for how large messages can be.
|
||||
// By default, this limit is 1 megabyte. Note that this is only
|
||||
// enforced when sending and receiving byte slices, and it does not
|
||||
@@ -41,8 +44,6 @@ type Trans interface {
|
||||
// ID returns the transaction ID. This must not change, and it must be
|
||||
// unique within the connection. This method is safe for concurrent use.
|
||||
ID() int64
|
||||
|
||||
// TODO: add methods for setting send and receive deadlines
|
||||
|
||||
// Send sends a message. This method is not safe for concurrent use.
|
||||
Send(method uint16, data []byte) error
|
||||
@@ -59,4 +60,12 @@ type Trans interface {
|
||||
// previously opened through this function will be discarded. This
|
||||
// method is not safe for concurrent use, and neither is its result.
|
||||
ReceiveReader() (method uint16, data io.Reader, err error)
|
||||
|
||||
// See the documentation for [net.Conn.SetDeadline].
|
||||
SetDeadline(time.Time) error
|
||||
// TODO
|
||||
// // See the documentation for [net.Conn.SetReadDeadline].
|
||||
// SetReadDeadline(t time.Time) error
|
||||
// // See the documentation for [net.Conn.SetWriteDeadline].
|
||||
// SetWriteDeadline(t time.Time) error
|
||||
}
|
||||
|
||||
@@ -343,7 +343,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
|
||||
case TypeBuffer:
|
||||
// SBA: <data: U8>*
|
||||
// LBA: <length: UN> <data: U8>*
|
||||
nn, err := this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
|
||||
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf(
|
||||
@@ -355,14 +362,20 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
|
||||
nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.generateErrorCheck()
|
||||
n += nn; if err != nil { return n, err }
|
||||
case TypeArray:
|
||||
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
|
||||
nn, err := this.iprintf(
|
||||
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf(
|
||||
"nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n",
|
||||
valueSource, tagSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
@@ -417,7 +430,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
|
||||
n += nn; if err != nil { return n, err }
|
||||
case TypeTable:
|
||||
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
|
||||
nn, err := this.iprintf(
|
||||
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf(
|
||||
"nn, err = tape.EncodeAny(encoder, %s, %s)\n",
|
||||
valueSource, tagSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
@@ -425,7 +445,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
|
||||
n += nn; if err != nil { return n, err }
|
||||
case TypeTableDefined:
|
||||
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
|
||||
nn, err := this.iprintf(
|
||||
nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields))
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf(
|
||||
"nn, err = encoder.WriteUintN(%d, %s.CN())\n",
|
||||
len(typ.Fields), tagSource)
|
||||
n += nn; if err != nil { return n, err }
|
||||
@@ -568,7 +595,14 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("buffer := make([]byte, int(%s))\n", lengthVar)
|
||||
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("buffer := make([]byte, %s)\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
@@ -670,6 +704,13 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
|
||||
lengthVar := this.newTemporaryVar("length")
|
||||
nn, err := this.iprintf("var %s uint64\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.generateErrorCheck()
|
||||
@@ -739,6 +780,13 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
|
||||
lengthVar := this.newTemporaryVar("length")
|
||||
nn, err := this.iprintf("var %s uint64\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
|
||||
this.pop()
|
||||
nn, err = this.iprintf("}\n")
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
nn, err = this.generateErrorCheck()
|
||||
@@ -751,7 +799,7 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
|
||||
// problems
|
||||
|
||||
// read fields
|
||||
nn, err = this.iprintf("for _ = range int(%s) {\n", lengthVar)
|
||||
nn, err = this.iprintf("for _ = range %s {\n", lengthVar)
|
||||
n += nn; if err != nil { return n, err }
|
||||
this.push()
|
||||
// read field header
|
||||
|
||||
75
metadapta.go
75
metadapta.go
@@ -1,9 +1,11 @@
|
||||
package hopp
|
||||
|
||||
import "io"
|
||||
import "os"
|
||||
import "fmt"
|
||||
import "net"
|
||||
import "sync"
|
||||
import "time"
|
||||
import "sync/atomic"
|
||||
import "git.tebibyte.media/sashakoshka/go-util/sync"
|
||||
|
||||
@@ -108,6 +110,10 @@ func (this *a) AcceptTrans() (Trans, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (this *a) SetDeadline(t time.Time) error {
|
||||
return this.underlying.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (this *a) SetSizeLimit(limit int64) {
|
||||
this.sizeLimit = limit
|
||||
}
|
||||
@@ -212,6 +218,10 @@ type transA struct {
|
||||
currentWriter io.Closer
|
||||
writeBuffer []byte
|
||||
closed atomic.Bool
|
||||
closeErr error
|
||||
|
||||
deadline *time.Timer
|
||||
deadlineLock sync.Mutex
|
||||
}
|
||||
|
||||
func (this *transA) Close() error {
|
||||
@@ -221,6 +231,11 @@ func (this *transA) Close() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (this *transA) closeWithError(err error) error {
|
||||
this.closeErr = err
|
||||
return this.Close()
|
||||
}
|
||||
|
||||
func (this *transA) closeDontUnlist() (err error) {
|
||||
// MUST be goroutine safe
|
||||
this.incoming.Close()
|
||||
@@ -269,9 +284,9 @@ func (this *transA) Receive() (method uint16, data []byte, err error) {
|
||||
}
|
||||
|
||||
func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
|
||||
// if the transaction has been closed, return an io.EOF
|
||||
if this.closed.Load() {
|
||||
return 0, nil, io.EOF
|
||||
// if the transaction has been closed, return an appropriate error.
|
||||
if err := this.errIfClosed(); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// drain previous reader if necessary
|
||||
@@ -289,6 +304,54 @@ func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
|
||||
return method, reader, nil
|
||||
}
|
||||
|
||||
func (this *transA) SetDeadline(t time.Time) error {
|
||||
this.deadlineLock.Lock()
|
||||
defer this.deadlineLock.Unlock()
|
||||
|
||||
if t == (time.Time { }) {
|
||||
if this.deadline != nil {
|
||||
this.deadline.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
until := time.Until(t)
|
||||
if this.deadline == nil {
|
||||
this.deadline.Reset(until)
|
||||
return nil
|
||||
}
|
||||
this.deadline = time.AfterFunc(until, func () {
|
||||
this.closeWithError(os.ErrDeadlineExceeded)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO
|
||||
// func (this *transA) SetReadDeadline(t time.Time) error {
|
||||
// // TODO
|
||||
// }
|
||||
//
|
||||
// func (this *transA) SetWriteDeadline(t time.Time) error {
|
||||
// // TODO
|
||||
// }
|
||||
|
||||
func (this *transA) errIfClosed() error {
|
||||
if !this.closed.Load() {
|
||||
return nil
|
||||
}
|
||||
return this.bestErr()
|
||||
}
|
||||
|
||||
func (this *transA) bestErr() error {
|
||||
if this.parent.err != nil {
|
||||
return this.parent.err
|
||||
}
|
||||
if this.closeErr != nil {
|
||||
return this.closeErr
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
type readerA struct {
|
||||
parent *transA
|
||||
leftover []byte
|
||||
@@ -319,11 +382,7 @@ func (this *readerA) pull() (uint16, error) {
|
||||
// close and return error on failure
|
||||
this.eof = true
|
||||
this.parent.Close()
|
||||
if this.parent.parent.err == nil {
|
||||
return 0, fmt.Errorf("could not receive message: %w", io.EOF)
|
||||
} else {
|
||||
return 0, this.parent.parent.err
|
||||
}
|
||||
return 0, fmt.Errorf("could not receive message: %w", this.parent.bestErr())
|
||||
}
|
||||
|
||||
func (this *readerA) Read(buffer []byte) (int, error) {
|
||||
|
||||
13
metadaptb.go
13
metadaptb.go
@@ -2,6 +2,7 @@ package hopp
|
||||
|
||||
import "io"
|
||||
import "net"
|
||||
import "time"
|
||||
import "bytes"
|
||||
import "errors"
|
||||
import "context"
|
||||
@@ -50,6 +51,10 @@ func (this *b) SetSizeLimit(limit int64) {
|
||||
this.sizeLimit = limit
|
||||
}
|
||||
|
||||
func (this *b) SetDeadline(t time.Time) error {
|
||||
return this.underlying.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (this *b) newTrans(underlying Stream) *transB {
|
||||
return &transB {
|
||||
sizeLimit: this.sizeLimit,
|
||||
@@ -124,6 +129,10 @@ func (this *transB) receiveReader() (uint16, int64, io.Reader, error) {
|
||||
return method, size, data, nil
|
||||
}
|
||||
|
||||
func (this *transB) SetDeadline(t time.Time) error {
|
||||
return this.underlying.SetDeadline(t)
|
||||
}
|
||||
|
||||
type writerB struct {
|
||||
parent *transB
|
||||
buffer bytes.Buffer
|
||||
@@ -149,12 +158,16 @@ type MultiConn interface {
|
||||
AcceptStream(context.Context) (Stream, error)
|
||||
// OpenStream opens a new stream.
|
||||
OpenStream() (Stream, error)
|
||||
// See the documentation for [net.Conn.SetDeadline].
|
||||
SetDeadline(time.Time) error
|
||||
}
|
||||
|
||||
// Stream represents a single stream returned by a [MultiConn].
|
||||
type Stream interface {
|
||||
// See documentation for [net.Conn].
|
||||
io.ReadWriteCloser
|
||||
// See the documentation for [net.Conn.SetDeadline].
|
||||
SetDeadline(time.Time) error
|
||||
// ID returns the stream ID
|
||||
ID() int64
|
||||
}
|
||||
|
||||
@@ -46,9 +46,16 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
|
||||
case reflect.Uint32: return encoder.WriteUint32(uint32(reflectValue.Uint()))
|
||||
case reflect.Int64: return encoder.WriteInt64(int64(reflectValue.Int()))
|
||||
case reflect.Uint64: return encoder.WriteUint64(uint64(reflectValue.Uint()))
|
||||
case reflect.String: return EncodeAny(encoder, []byte(reflectValue.String()), tag)
|
||||
case reflect.String:
|
||||
if reflectValue.Len() > MaxStructureLength {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
return EncodeAny(encoder, []byte(reflectValue.String()), tag)
|
||||
}
|
||||
if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) {
|
||||
if reflectValue.Len() > MaxStructureLength {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
if tag.Is(LBA) {
|
||||
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
|
||||
n += nn; if err != nil { return n, err }
|
||||
@@ -64,8 +71,13 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
|
||||
case reflect.Slice:
|
||||
return encodeAnySlice(encoder, value, tag)
|
||||
// case reflect.Array:
|
||||
// TODO: we can encode arrays. but can we decode into them?
|
||||
// that's the fucken question. maybe we just do the first
|
||||
// return encodeAnySlice(encoder, reflect.ValueOf(value).Slice(0, reflectType.Len()).Interface(), tag)
|
||||
case reflect.Map:
|
||||
if reflectValue.Len() > MaxStructureLength {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
if reflectType.Key() == reflect.TypeOf(uint16(0)) {
|
||||
return encodeAnyMap(encoder, value, tag)
|
||||
}
|
||||
@@ -133,7 +145,11 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
|
||||
n += nn; if err != nil { return n, err }
|
||||
case SBA:
|
||||
// SBA: <data: U8>*
|
||||
buffer := make([]byte, tag.CN())
|
||||
length := tag.CN()
|
||||
if length > MaxStructureLength {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
buffer := make([]byte, length)
|
||||
nn, err := decoder.Read(buffer)
|
||||
n += nn; if err != nil { return n, err }
|
||||
setByteArray(destination, buffer)
|
||||
@@ -141,6 +157,9 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
|
||||
// LBA: <length: UN> <data: U8>*
|
||||
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
|
||||
n += nn; if err != nil { return n, err }
|
||||
if length > uint64(MaxStructureLength) {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
buffer := make([]byte, length)
|
||||
nn, err = decoder.Read(buffer)
|
||||
n += nn; if err != nil { return n, err }
|
||||
@@ -149,10 +168,15 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
|
||||
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
|
||||
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
|
||||
n += nn; if err != nil { return n, err }
|
||||
if length > uint64(MaxStructureLength) {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
lengthCast, err := Uint64ToIntSafe(length)
|
||||
if err != nil { return n, err }
|
||||
oneTag, nn, err := decoder.ReadTag()
|
||||
n += nn; if err != nil { return n, err }
|
||||
if destination.Cap() < int(length) {
|
||||
destination.Grow(int(length) - destination.Cap())
|
||||
if destination.Cap() < lengthCast {
|
||||
destination.Grow(lengthCast - destination.Cap())
|
||||
}
|
||||
// skip the rest of the array if the one tag doesn't
|
||||
// match up with the destination
|
||||
@@ -165,7 +189,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
|
||||
break
|
||||
}
|
||||
if err != nil { return n, err }
|
||||
destination.SetLen(int(length))
|
||||
destination.SetLen(lengthCast)
|
||||
for index := range length {
|
||||
nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag)
|
||||
n += nn
|
||||
@@ -179,6 +203,9 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
|
||||
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
|
||||
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
|
||||
n += nn; if err != nil { return n, err }
|
||||
if length > uint64(MaxStructureLength) {
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
destination.Clear()
|
||||
for _ = range length {
|
||||
key, nn, err := decoder.ReadUint16()
|
||||
|
||||
12
tape/error.go
Normal file
12
tape/error.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package tape
|
||||
|
||||
// Error enumerates common errors in this package.
|
||||
type Error string; const (
|
||||
ErrTooLong Error = "data structure too long"
|
||||
ErrTooLarge Error = "number too large"
|
||||
)
|
||||
|
||||
// Error implements the error interface.
|
||||
func (err Error) Error() string {
|
||||
return string(err)
|
||||
}
|
||||
26
tape/limits.go
Normal file
26
tape/limits.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package tape
|
||||
|
||||
// MaxStructureLength determines how long a TAPE data structure can be. This
|
||||
// applies to:
|
||||
//
|
||||
// - OTA
|
||||
// - SBA/LBA
|
||||
// - KTV
|
||||
//
|
||||
// By default it is set at 2^20 (about a million).
|
||||
// You shouldn't need to change this. If you do, it should only be set once at
|
||||
// the start of the program.
|
||||
var MaxStructureLength = 1024 * 1024
|
||||
|
||||
// MaxInt is the maximum value an int can hold. This varies depending on the
|
||||
// system.
|
||||
const MaxInt int = int(^uint(0) >> 1)
|
||||
|
||||
// Uint64ToIntSafe casts the input to an int if it can be done without overflow,
|
||||
// or returns an error otherwise.
|
||||
func Uint64ToIntSafe(input uint64) (int, error) {
|
||||
if input > uint64(MaxInt) {
|
||||
return 0, ErrTooLarge
|
||||
}
|
||||
return int(input), nil
|
||||
}
|
||||
Reference in New Issue
Block a user