diff --git a/metadapta.go b/metadapta.go index 1879b15..c23d579 100644 --- a/metadapta.go +++ b/metadapta.go @@ -17,6 +17,7 @@ type Party bool; const ( ) type a struct { + sizeLimit int underlying net.Conn party Party transID int64 @@ -87,6 +88,10 @@ func (this *a) AcceptTrans() (Trans, error) { } } +func (this *a) SetSizeLimit(limit int) { + this.sizeLimit = limit +} + func (this *a) unlistTransactionSafe(id int64) { this.transLock.Lock() defer this.transLock.Unlock() @@ -110,13 +115,13 @@ func (this *a) receive() { clear(this.transMap) }() for { - transID, method, payload, err := decodeMessageA(this.underlying) + transID, method, chunked, payload, err := decodeMessageA(this.underlying) if err != nil { this.err = fmt.Errorf("could not receive message: %w", err) return } - err = this.receiveMultiplex(transID, method, payload) + err = this.receiveMultiplex(transID, method, chunked, payload) if err != nil { this.err = fmt.Errorf("could not receive message: %w", err) return @@ -124,7 +129,7 @@ func (this *a) receive() { } } -func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) error { +func (this *a) receiveMultiplex(transID int64, method uint16, chunked bool, payload []byte) error { if transID == 0 { return ErrMessageMalformed } trans, err := func() (*transA, error) { @@ -152,15 +157,17 @@ func (this *a) receiveMultiplex(transID int64, method uint16, payload []byte) er trans.incoming.Send(incomingMessage { method: method, + chunked: chunked, payload: payload, }) return nil } type transA struct { - parent *a - id int64 - incoming usync.Gate[incomingMessage] + parent *a + id int64 + incoming usync.Gate[incomingMessage] + currentReader io.Reader } func (this *transA) Close() error { @@ -183,26 +190,78 @@ func (this *transA) Send(method uint16, data []byte) error { } func (this *transA) Receive() (method uint16, data []byte, err error) { - receive := this.incoming.Receive() + method, reader, err := this.ReceiveReader() + if err != nil { return 0, nil, err } + data, err = io.ReadAll(reader) + if err != nil { return 0, nil, err } + return method, data, nil +} + +func (this *transA) ReceiveReader() (uint16, io.Reader, error) { + // drain previous reader if necessary + if this.currentReader != nil { + io.Copy(io.Discard, this.currentReader) + } + + // create new reader + reader := &readerA { + parent: this, + } + method, err := reader.pull() + if err != nil { return 0, nil, err} + this.currentReader = reader + return method, reader, nil +} + +type readerA struct { + parent *transA + leftover []byte + eof bool +} + +func (this *readerA) pull() (uint16, error) { + // if the previous message ended the chain, return an io.EOF + if this.eof { + return 0, io.EOF + } + + // get a message from the transaction we are a part of + receive := this.parent.incoming.Receive() if receive != nil { if message, ok := <- receive; ok { if message.method != closeMethod { - return message.method, message.payload, nil + this.leftover = append(this.leftover, message.payload...) + if !message.chunked { + this.eof = true + } } } } // close and return error on failure - this.Close() - if this.parent.err == nil { - return 0, nil, fmt.Errorf("could not receive message: %w", io.EOF) + this.eof = true + this.parent.Close() + if this.parent.parent.err == nil { + return 0, fmt.Errorf("could not receive message: %w", io.EOF) } else { - return 0, nil, this.parent.err + return 0, this.parent.parent.err } } +func (this *readerA) Read(buffer []byte) (int, error) { + if len(this.leftover) == 0 { + if this.eof { return 0, io.EOF } + this.pull() + } + + copied := copy(buffer, this.leftover) + this.leftover = this.leftover[copied:] + return copied, nil +} + type incomingMessage struct { method uint16 + chunked bool payload []byte } @@ -218,22 +277,27 @@ func encodeMessageA(writer io.Writer, trans int64, method uint16, data []byte) e return err } -func decodeMessageA(reader io.Reader) (int64, uint16, []byte, error) { - headerBuffer := [12]byte { } +func decodeMessageA(reader io.Reader) (int64, uint16, bool, []byte, error) { + headerBuffer := [18]byte { } _, err := io.ReadFull(reader, headerBuffer[:]) - if err != nil { return 0, 0, nil, err } + if err != nil { return 0, 0, false, nil, err } transID, err := tape.DecodeI64[int64](headerBuffer[:8]) - if err != nil { return 0, 0, nil, err } + if err != nil { return 0, 0, false, nil, err } method, err := tape.DecodeI16[uint16](headerBuffer[8:10]) - if err != nil { return 0, 0, nil, err } - length, err := tape.DecodeI16[uint16](headerBuffer[10:12]) - if err != nil { return 0, 0, nil, err } - payloadBuffer := make([]byte, int(length)) + if err != nil { return 0, 0, false, nil, err } + size, err := tape.DecodeI64[uint64](headerBuffer[10:18]) + if err != nil { return 0, 0, false, nil, err } + chunked, size := splitCCBSize(size) + payloadBuffer := make([]byte, int(size)) _, err = io.ReadFull(reader, payloadBuffer) - if err != nil { return 0, 0, nil, err } - return transID, method, payloadBuffer, nil + if err != nil { return 0, 0, false, nil, err } + return transID, method, chunked, payloadBuffer, nil } func partyFromTransID(id int64) Party { return id > 0 } + +func splitCCBSize(size uint64) (bool, uint64) { + return size >> 63 > 1, size & 0x7FFFFFFFFFFFFFFF +}