diff --git a/metadapta.go b/metadapta.go index ca6a629..e9b8e56 100644 --- a/metadapta.go +++ b/metadapta.go @@ -10,6 +10,7 @@ import "git.tebibyte.media/sashakoshka/go-util/sync" const closeMethod = 0xFFFF const int64Max = int64((^uint64(0)) >> 1) + // Party represents a side of a connection. type Party bool; const ( ServerSide Party = false @@ -33,6 +34,7 @@ type a struct { // oriented transport such as TCP or UNIX domain stream sockets. func AdaptA(underlying net.Conn, party Party) Conn { conn := &a { + sizeLimit: defaultSizeLimit, underlying: underlying, party: party, transMap: make(map[int64] *transA), @@ -115,7 +117,7 @@ func (this *a) receive() { clear(this.transMap) }() for { - transID, method, chunked, payload, err := decodeMessageA(this.underlying) + transID, method, chunked, payload, err := decodeMessageA(this.underlying, this.sizeLimit) if err != nil { this.err = fmt.Errorf("could not receive message: %w", err) return @@ -277,7 +279,16 @@ func encodeMessageA(writer io.Writer, trans int64, method uint16, data []byte) e return err } -func decodeMessageA(reader io.Reader) (transID int64, method uint16, chunked bool, payloadBuffer []byte, err error) { +func decodeMessageA( + reader io.Reader, + sizeLimit int64, +) ( + transID int64, + method uint16, + chunked bool, + payloadBuffer []byte, + err error, +) { headerBuffer := [18]byte { } _, err = io.ReadFull(reader, headerBuffer[:]) if err != nil { return 0, 0, false, nil, err } @@ -288,6 +299,9 @@ func decodeMessageA(reader io.Reader) (transID int64, method uint16, chunked boo size, err := tape.DecodeI64[uint64](headerBuffer[10:18]) if err != nil { return 0, 0, false, nil, err } chunked, size = splitCCBSize(size) + if size > uint64(sizeLimit) { + return 0, 0, false, nil, ErrPayloadTooLarge + } payloadBuffer = make([]byte, int(size)) _, err = io.ReadFull(reader, payloadBuffer) if err != nil { return 0, 0, false, nil, err }