16 Commits

Author SHA1 Message Date
12fbfa6293 hopp: Add SetDeadline methods to Conn and Trans 2025-09-05 18:48:12 -04:00
44fb561758 generate: Safely cast in the static decoder 2025-08-29 12:21:10 -04:00
04c352fad6 tape: Safely cast when dynamically encoding/decoding 2025-08-29 12:03:39 -04:00
0ea7e222cc generate: Respect limits when statically decoding 2025-08-28 12:53:58 -04:00
ae79a32309 generate: Respect limits when statically encoding 2025-08-28 12:51:55 -04:00
e28ab4dc6b tape: Respect limits when dynamically decoding 2025-08-28 12:31:49 -04:00
80161b37f7 tape: Respect limits when dynamically encoding
Still need: dynamic decoding, static decoding, static encoding
2025-08-28 09:31:33 -04:00
9d40b81e00 tape: Add limits to the API 2025-08-28 09:31:14 -04:00
80c7d25c73 hopp: Clarify documentation for Conn.SetSizeLimit 2025-08-28 09:03:27 -04:00
743a5d4ae0 generate: Fix float encoding 2025-08-27 22:59:01 -04:00
ea17e354a3 cmd/hopp-generate: Update generate command 2025-08-27 22:54:43 -04:00
4dc8a30ebd generate: Don't hardcode filename as test.pdl in errors 2025-08-27 22:54:18 -04:00
15c5f0b2b8 generate: More fixes for TestGenerateRunDecodeWrongType 2025-08-27 22:22:00 -04:00
087b6b6690 generate: Fix problems with TestGenerateRunDecodeWrongType 2025-08-27 18:33:55 -04:00
77bfc45fea generate: Cast strings and buffers when decoding 2025-08-27 14:55:10 -04:00
de6099fadc generate: cast integers when decoding 2025-08-27 00:43:10 -04:00
11 changed files with 299 additions and 48 deletions

View File

@@ -4,6 +4,7 @@ import "os"
import "fmt"
import "strings"
import "path/filepath"
import "git.tebibyte.media/sashakoshka/goparse"
import "git.tebibyte.media/sashakoshka/hopp/generate"
func main() {
@@ -18,7 +19,7 @@ func main() {
input, err := os.Open(source)
handleErr(1, err)
defer input.Close()
protocol, err := generate.ParseReader(input)
protocol, err := generate.ParseReader(source, input)
handleErr(1, err)
absDestination, err := filepath.Abs(destination)
@@ -30,14 +31,18 @@ func main() {
output, err := os.Create(destination)
handleErr(1, err)
err = protocol.Generate(output, packageName)
generator := generate.Generator {
Output: output,
PackageName: packageName,
}
_, err = generator.Generate(protocol)
handleErr(1, err)
fmt.Fprintf(os.Stderr, "%s: OK\n", name)
}
func handleErr(code int, err error) {
if err != nil {
fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], err)
fmt.Fprintf(os.Stderr, "%s: %v\n", os.Args[0], parse.Format(err))
os.Exit(code)
}
}

View File

@@ -2,7 +2,7 @@ package hopp
import "io"
import "net"
// import "time"
import "time"
const defaultSizeLimit int64 = 1024 * 1024 // 1 megabyte
@@ -23,8 +23,13 @@ type Conn interface {
// be called in a loop to avoid the connection locking up.
AcceptTrans() (Trans, error)
// SetDeadline operates is [net.Conn.SetDeadline] but for OpenTrans
// and AcceptTrans calls.
SetDeadline(t time.Time) error
// SetSizeLimit sets a limit (in bytes) for how large messages can be.
// By default, this limit is 1 megabyte.
// By default, this limit is 1 megabyte. Note that this is only
// enforced when sending and receiving byte slices, and it does not
// apply to [Trans.SendWriter] or [Trans.ReceiveReader].
SetSizeLimit(limit int64)
}
@@ -39,8 +44,6 @@ type Trans interface {
// ID returns the transaction ID. This must not change, and it must be
// unique within the connection. This method is safe for concurrent use.
ID() int64
// TODO: add methods for setting send and receive deadlines
// Send sends a message. This method is not safe for concurrent use.
Send(method uint16, data []byte) error
@@ -57,4 +60,12 @@ type Trans interface {
// previously opened through this function will be discarded. This
// method is not safe for concurrent use, and neither is its result.
ReceiveReader() (method uint16, data io.Reader, err error)
// See the documentation for [net.Conn.SetDeadline].
SetDeadline(time.Time) error
// TODO
// // See the documentation for [net.Conn.SetReadDeadline].
// SetReadDeadline(t time.Time) error
// // See the documentation for [net.Conn.SetWriteDeadline].
// SetWriteDeadline(t time.Time) error
}

View File

@@ -332,7 +332,7 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s)\n", valueSource)
nn, err = this.printf("(%s))\n", valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
@@ -343,7 +343,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
case TypeBuffer:
// SBA: <data: U8>*
// LBA: <length: UN> <data: U8>*
nn, err := this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf(
@@ -355,14 +362,20 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
case TypeArray:
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
nn, err := this.iprintf(
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n",
valueSource, tagSource)
n += nn; if err != nil { return n, err }
@@ -417,7 +430,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
n += nn; if err != nil { return n, err }
case TypeTable:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf(
nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"nn, err = tape.EncodeAny(encoder, %s, %s)\n",
valueSource, tagSource)
n += nn; if err != nil { return n, err }
@@ -425,7 +445,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri
n += nn; if err != nil { return n, err }
case TypeTableDefined:
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
nn, err := this.iprintf(
nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields))
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf(
"nn, err = encoder.WriteUintN(%d, %s.CN())\n",
len(typ.Fields), tagSource)
n += nn; if err != nil { return n, err }
@@ -491,24 +518,57 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
// LI/LSI: <value: IntN>
if typ.Bits <= 5 {
// SI stores the value in the tag
nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
if typeName == "" {
nn, err := this.iprintf("*%s = uint8(%s.CN())\n", valueSource, tagSource)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s.CN())\n", valueSource, typeName, tagSource)
n += nn; if err != nil { return n, err }
}
break
}
prefix := "ReadUint"
if typ.Signed {
prefix = "ReadInt"
}
nn, err := this.iprintf("*%s, nn, err = decoder.%s%d()\n", valueSource, prefix, typ.Bits)
destinationVar := this.newTemporaryVar("destination")
nn, err := this.iprintf("var %s ", destinationVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ)
n += nn; if err != nil { return n, err }
nn, err = this.print("\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.%s%d()\n", destinationVar, prefix, typ.Bits)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if typeName == "" {
nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
n += nn; if err != nil { return n, err }
}
case TypeFloat:
// FP: <value: FloatN>
nn, err := this.iprintf("*%s, nn, err = decoder.ReadFloat%d()\n", valueSource, typ.Bits)
destinationVar := this.newTemporaryVar("destination")
nn, err := this.iprintf("var %s ", destinationVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateType(typ)
n += nn; if err != nil { return n, err }
nn, err = this.print("\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadFloat%d()\n", destinationVar, typ.Bits)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if typeName == "" {
nn, err := this.iprintf("*%s = %s\n", valueSource, destinationVar)
n += nn; if err != nil { return n, err }
} else {
nn, err := this.iprintf("*%s = %s(%s)\n", valueSource, typeName, destinationVar)
n += nn; if err != nil { return n, err }
}
case TypeString, TypeBuffer:
// SBA: <data: U8>*
// LBA: <length: UN> <data: U8>*
@@ -535,17 +595,29 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("buffer := make([]byte, int(%s))\n", lengthVar)
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("buffer := make([]byte, %s)\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n")
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
n += nn; if err != nil { return n, err }
if _, ok := typ.(TypeString); ok {
nn, err = this.iprintf("*%s = string(buffer)\n", valueSource)
n += nn; if err != nil { return n, err }
if typeName == "" {
if _, ok := typ.(TypeString); ok {
nn, err = this.iprintf("*%s = string(buffer)\n", valueSource)
n += nn; if err != nil { return n, err }
} else {
nn, err = this.iprintf("*%s = buffer\n", valueSource)
n += nn; if err != nil { return n, err }
}
} else {
nn, err = this.iprintf("*%s = buffer\n", valueSource)
nn, err = this.iprintf("*%s = %s(buffer)\n", valueSource, typeName)
n += nn; if err != nil { return n, err }
}
case TypeArray:
@@ -632,6 +704,13 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
lengthVar := this.newTemporaryVar("length")
nn, err := this.iprintf("var %s uint64\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
@@ -701,6 +780,13 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
lengthVar := this.newTemporaryVar("length")
nn, err := this.iprintf("var %s uint64\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
nn, err = this.iprintf("return n, tape.ErrTooLong\n")
this.pop()
nn, err = this.iprintf("}\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar)
n += nn; if err != nil { return n, err }
nn, err = this.generateErrorCheck()
@@ -713,7 +799,7 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st
// problems
// read fields
nn, err = this.iprintf("for _ = range int(%s) {\n", lengthVar)
nn, err = this.iprintf("for _ = range %s {\n", lengthVar)
n += nn; if err != nil { return n, err }
this.push()
// read field header
@@ -857,10 +943,10 @@ func (this *Generator) generateTag(typ Type, source string) (n int, err error) {
nn, err := this.printf("tape.FP.WithCN(%d)", bitsToCN(typ.Bits))
n += nn; if err != nil { return n, err }
case TypeString:
nn, err := this.printf("tape.StringTag(%s)", source)
nn, err := this.printf("tape.StringTag(string(%s))", source)
n += nn; if err != nil { return n, err }
case TypeBuffer:
nn, err := this.printf("tape.BufferTag(%s)", source)
nn, err := this.printf("tape.BufferTag([]byte(%s))", source)
n += nn; if err != nil { return n, err }
case TypeArray:
nn, err := this.printf("tape.OTA.WithCN(tape.IntBytes(uint64(len(%s))))", source)

View File

@@ -253,7 +253,7 @@ func TestGenerateRunDecodeWrongType(test *testing.T) {
Name: "Table",
Type: TypeTable { },
}
protocol.Messages[0x000C] = Message {
protocol.Messages[0x000D] = Message {
Name: "TableDefined",
Type: TypeTableDefined {
Fields: map[uint16] Field {
@@ -312,8 +312,8 @@ func TestGenerateRunDecodeWrongType(test *testing.T) {
reflectValue.Elem().Interface())
}
}
if n != len(data) - 1 {
log.Fatalf("n not equal: %d != %d", n, len(data) - 1)
if n != len(data) {
log.Fatalf("n not equal: %d != %d", n, len(data))
}
}
log.Println("- MessageInt8")
@@ -341,8 +341,20 @@ func TestGenerateRunDecodeWrongType(test *testing.T) {
if reflectValue.Len() != 0 {
log.Fatalf("len(destination) not zero: %v", reflectValue.Interface())
}
if n != len(data) - 1 {
log.Fatalf("n not equal: %d != %d", n, len(data) - 1)
if n != len(data) {
log.Fatalf("n not equal: %d != %d", n, len(data))
}
}
anyCase := func(destination Message) {
n, err := destination.Decode(tape.NewDecoder(bytes.NewBuffer(data)),)
if err != nil { log.Fatalf("error: %v | n: %d", err, n) }
reflectDestination := reflect.ValueOf(destination)
reflectValue := reflectDestination.Elem()
if reflectValue == reflect.Zero(reflectValue.Type()) {
log.Fatalf("len(destination) not zero: %v", reflectValue.Interface())
}
if n != len(data) {
log.Fatalf("n not equal: %d != %d", n, len(data))
}
}
// SBA/LBA types should only assign to other SBA/LBA types
@@ -360,9 +372,9 @@ func TestGenerateRunDecodeWrongType(test *testing.T) {
// tables should only assign to other tables
if index != 12 {
log.Println("- MessageTable")
{ var dest = new(MessageTable); arrayCase(&dest) }
{ var dest = make(MessageTable); arrayCase(&dest) }
log.Println("- MessageTableDefined")
{ var dest MessageTableDefined; arrayCase(&dest) }
{ var dest MessageTableDefined; anyCase(&dest) }
}
}
`)

View File

@@ -25,8 +25,8 @@ func defaultProtocol() Protocol {
}
}
func ParseReader(reader io.Reader) (*Protocol, error) {
lx, err := Lex("test.pdl", reader)
func ParseReader(fileName string, reader io.Reader) (*Protocol, error) {
lx, err := Lex(fileName, reader)
if err != nil { return nil, err }
return Parse(lx)
}

View File

@@ -33,7 +33,7 @@ func TestParse(test *testing.T) {
}
test.Log("CORRECT:", &correct)
got, err := ParseReader(strings.NewReader(`
got, err := ParseReader("test.pdl", strings.NewReader(`
M0000 Connect {
0000 Name String,
0001 Password String,

View File

@@ -1,9 +1,11 @@
package hopp
import "io"
import "os"
import "fmt"
import "net"
import "sync"
import "time"
import "sync/atomic"
import "git.tebibyte.media/sashakoshka/go-util/sync"
@@ -108,6 +110,10 @@ func (this *a) AcceptTrans() (Trans, error) {
}
}
func (this *a) SetDeadline(t time.Time) error {
return this.underlying.SetDeadline(t)
}
func (this *a) SetSizeLimit(limit int64) {
this.sizeLimit = limit
}
@@ -212,6 +218,10 @@ type transA struct {
currentWriter io.Closer
writeBuffer []byte
closed atomic.Bool
closeErr error
deadline *time.Timer
deadlineLock sync.Mutex
}
func (this *transA) Close() error {
@@ -221,6 +231,11 @@ func (this *transA) Close() error {
return err
}
func (this *transA) closeWithError(err error) error {
this.closeErr = err
return this.Close()
}
func (this *transA) closeDontUnlist() (err error) {
// MUST be goroutine safe
this.incoming.Close()
@@ -269,9 +284,9 @@ func (this *transA) Receive() (method uint16, data []byte, err error) {
}
func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
// if the transaction has been closed, return an io.EOF
if this.closed.Load() {
return 0, nil, io.EOF
// if the transaction has been closed, return an appropriate error.
if err := this.errIfClosed(); err != nil {
return 0, nil, err
}
// drain previous reader if necessary
@@ -289,6 +304,54 @@ func (this *transA) ReceiveReader() (uint16, io.Reader, error) {
return method, reader, nil
}
func (this *transA) SetDeadline(t time.Time) error {
this.deadlineLock.Lock()
defer this.deadlineLock.Unlock()
if t == (time.Time { }) {
if this.deadline != nil {
this.deadline.Stop()
}
return nil
}
until := time.Until(t)
if this.deadline == nil {
this.deadline.Reset(until)
return nil
}
this.deadline = time.AfterFunc(until, func () {
this.closeWithError(os.ErrDeadlineExceeded)
})
return nil
}
// TODO
// func (this *transA) SetReadDeadline(t time.Time) error {
// // TODO
// }
//
// func (this *transA) SetWriteDeadline(t time.Time) error {
// // TODO
// }
func (this *transA) errIfClosed() error {
if !this.closed.Load() {
return nil
}
return this.bestErr()
}
func (this *transA) bestErr() error {
if this.parent.err != nil {
return this.parent.err
}
if this.closeErr != nil {
return this.closeErr
}
return io.EOF
}
type readerA struct {
parent *transA
leftover []byte
@@ -319,11 +382,7 @@ func (this *readerA) pull() (uint16, error) {
// close and return error on failure
this.eof = true
this.parent.Close()
if this.parent.parent.err == nil {
return 0, fmt.Errorf("could not receive message: %w", io.EOF)
} else {
return 0, this.parent.parent.err
}
return 0, fmt.Errorf("could not receive message: %w", this.parent.bestErr())
}
func (this *readerA) Read(buffer []byte) (int, error) {

View File

@@ -2,6 +2,7 @@ package hopp
import "io"
import "net"
import "time"
import "bytes"
import "errors"
import "context"
@@ -50,6 +51,10 @@ func (this *b) SetSizeLimit(limit int64) {
this.sizeLimit = limit
}
func (this *b) SetDeadline(t time.Time) error {
return this.underlying.SetDeadline(t)
}
func (this *b) newTrans(underlying Stream) *transB {
return &transB {
sizeLimit: this.sizeLimit,
@@ -124,6 +129,10 @@ func (this *transB) receiveReader() (uint16, int64, io.Reader, error) {
return method, size, data, nil
}
func (this *transB) SetDeadline(t time.Time) error {
return this.underlying.SetDeadline(t)
}
type writerB struct {
parent *transB
buffer bytes.Buffer
@@ -149,12 +158,16 @@ type MultiConn interface {
AcceptStream(context.Context) (Stream, error)
// OpenStream opens a new stream.
OpenStream() (Stream, error)
// See the documentation for [net.Conn.SetDeadline].
SetDeadline(time.Time) error
}
// Stream represents a single stream returned by a [MultiConn].
type Stream interface {
// See documentation for [net.Conn].
io.ReadWriteCloser
// See the documentation for [net.Conn.SetDeadline].
SetDeadline(time.Time) error
// ID returns the stream ID
ID() int64
}

View File

@@ -46,9 +46,16 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
case reflect.Uint32: return encoder.WriteUint32(uint32(reflectValue.Uint()))
case reflect.Int64: return encoder.WriteInt64(int64(reflectValue.Int()))
case reflect.Uint64: return encoder.WriteUint64(uint64(reflectValue.Uint()))
case reflect.String: return EncodeAny(encoder, []byte(reflectValue.String()), tag)
case reflect.String:
if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong
}
return EncodeAny(encoder, []byte(reflectValue.String()), tag)
}
if reflectValue.CanConvert(reflect.TypeOf(dummyBuffer)) {
if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong
}
if tag.Is(LBA) {
nn, err := encoder.WriteUintN(uint64(reflectValue.Len()), tag.CN() + 1)
n += nn; if err != nil { return n, err }
@@ -64,8 +71,13 @@ func EncodeAny(encoder *Encoder, value any, tag Tag) (n int, err error) {
case reflect.Slice:
return encodeAnySlice(encoder, value, tag)
// case reflect.Array:
// TODO: we can encode arrays. but can we decode into them?
// that's the fucken question. maybe we just do the first
// return encodeAnySlice(encoder, reflect.ValueOf(value).Slice(0, reflectType.Len()).Interface(), tag)
case reflect.Map:
if reflectValue.Len() > MaxStructureLength {
return 0, ErrTooLong
}
if reflectType.Key() == reflect.TypeOf(uint16(0)) {
return encodeAnyMap(encoder, value, tag)
}
@@ -133,7 +145,11 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
n += nn; if err != nil { return n, err }
case SBA:
// SBA: <data: U8>*
buffer := make([]byte, tag.CN())
length := tag.CN()
if length > MaxStructureLength {
return 0, ErrTooLong
}
buffer := make([]byte, length)
nn, err := decoder.Read(buffer)
n += nn; if err != nil { return n, err }
setByteArray(destination, buffer)
@@ -141,6 +157,9 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
// LBA: <length: UN> <data: U8>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) {
return 0, ErrTooLong
}
buffer := make([]byte, length)
nn, err = decoder.Read(buffer)
n += nn; if err != nil { return n, err }
@@ -149,10 +168,15 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
// OTA: <length: UN> <elementTag: tape.Tag> <values>*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) {
return 0, ErrTooLong
}
lengthCast, err := Uint64ToIntSafe(length)
if err != nil { return n, err }
oneTag, nn, err := decoder.ReadTag()
n += nn; if err != nil { return n, err }
if destination.Cap() < int(length) {
destination.Grow(int(length) - destination.Cap())
if destination.Cap() < lengthCast {
destination.Grow(lengthCast - destination.Cap())
}
// skip the rest of the array if the one tag doesn't
// match up with the destination
@@ -165,7 +189,7 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
break
}
if err != nil { return n, err }
destination.SetLen(int(length))
destination.SetLen(lengthCast)
for index := range length {
nn, err := decodeAny(decoder, destination.Index(int(index)), oneTag)
n += nn
@@ -179,6 +203,9 @@ func decodeAnyOrError(decoder *Decoder, destination reflect.Value, tag Tag) (n i
// KTV: <length: UN> (<key: U16> <tag: Tag> <value>)*
length, nn, err := decoder.ReadUintN(tag.CN() + 1)
n += nn; if err != nil { return n, err }
if length > uint64(MaxStructureLength) {
return 0, ErrTooLong
}
destination.Clear()
for _ = range length {
key, nn, err := decoder.ReadUint16()

12
tape/error.go Normal file
View File

@@ -0,0 +1,12 @@
package tape
// Error enumerates common errors in this package.
type Error string; const (
ErrTooLong Error = "data structure too long"
ErrTooLarge Error = "number too large"
)
// Error implements the error interface.
func (err Error) Error() string {
return string(err)
}

26
tape/limits.go Normal file
View File

@@ -0,0 +1,26 @@
package tape
// MaxStructureLength determines how long a TAPE data structure can be. This
// applies to:
//
// - OTA
// - SBA/LBA
// - KTV
//
// By default it is set at 2^20 (about a million).
// You shouldn't need to change this. If you do, it should only be set once at
// the start of the program.
var MaxStructureLength = 1024 * 1024
// MaxInt is the maximum value an int can hold. This varies depending on the
// system.
const MaxInt int = int(^uint(0) >> 1)
// Uint64ToIntSafe casts the input to an int if it can be done without overflow,
// or returns an error otherwise.
func Uint64ToIntSafe(input uint64) (int, error) {
if input > uint64(MaxInt) {
return 0, ErrTooLarge
}
return int(input), nil
}