Support getting a reader for a message in METADAPT-B

This commit is contained in:
Sasha Koshka 2025-04-06 14:17:39 -04:00
parent fe8f2fc3ea
commit f4f8039fa0

View File

@ -5,6 +5,8 @@ import "net"
import "context" import "context"
import "git.tebibyte.media/sashakoshka/hopp/tape" import "git.tebibyte.media/sashakoshka/hopp/tape"
// TODO: change size limit to be int64
// B implements METADAPT-B over a multiplexed stream-oriented transport such as // B implements METADAPT-B over a multiplexed stream-oriented transport such as
// QUIC. // QUIC.
type b struct { type b struct {
@ -12,8 +14,8 @@ type b struct {
underlying MultiConn underlying MultiConn
} }
// AdaptB returns a connection implementing METADAPT-B over a singular stream- // AdaptB returns a connection implementing METADAPT-B over a multiplexed
// oriented transport such as TCP or UNIX domain stream sockets. // stream-oriented transport such as QUIC.
func AdaptB(underlying MultiConn) Conn { func AdaptB(underlying MultiConn) Conn {
return &b { return &b {
sizeLimit: defaultSizeLimit, sizeLimit: defaultSizeLimit,
@ -49,8 +51,8 @@ func (this *b) SetSizeLimit(limit int) {
this.sizeLimit = limit this.sizeLimit = limit
} }
func (this *b) newTrans(underlying Stream) transB { func (this *b) newTrans(underlying Stream) *transB {
return transB { return &transB {
sizeLimit: this.sizeLimit, sizeLimit: this.sizeLimit,
underlying: underlying, underlying: underlying,
} }
@ -59,22 +61,49 @@ func (this *b) newTrans(underlying Stream) transB {
type transB struct { type transB struct {
sizeLimit int sizeLimit int
underlying Stream underlying Stream
currentData io.Reader
} }
func (trans transB) Close() error { func (this *transB) Close() error {
return trans.underlying.Close() return this.underlying.Close()
} }
func (trans transB) ID() int64 { func (this *transB) ID() int64 {
return trans.underlying.ID() return this.underlying.ID()
} }
func (trans transB) Send(method uint16, data []byte) error { func (this *transB) Send(method uint16, data []byte) error {
return encodeMessageB(trans.underlying, trans.sizeLimit, method, data) return encodeMessageB(this.underlying, this.sizeLimit, method, data)
} }
func (trans transB) Receive() (uint16, []byte, error) { func (this *transB) Receive() (uint16, []byte, error) {
return decodeMessageB(trans.underlying, trans.sizeLimit) // get a reader for the next message
method, size, data, err := this.receiveReader()
if err != nil { return 0, nil, err }
// read the entire thing
payloadBuffer := make([]byte, int(size))
_, err = io.ReadFull(data, payloadBuffer)
if err != nil { return 0, nil, err }
// we have used up the reader by now so we can forget it exists
this.currentData = nil
return method, payloadBuffer, nil
}
func (this *transB) ReceiveReader() (uint16, io.Reader, error) {
method, _, data, err := this.receiveReader()
return method, data, err
}
func (this *transB) receiveReader() (uint16, int64, io.Reader, error) {
// decode the message
method, size, data, err := decodeMessageB(this.underlying, this.sizeLimit)
if err != nil { return 0, 0, nil, err }
// discard current reader if there is one
if this.currentData == nil {
io.Copy(io.Discard, this.currentData)
}
this.currentData = data
return method, size, data, nil
} }
// MultiConn represens a multiplexed stream-oriented transport for use in // MultiConn represens a multiplexed stream-oriented transport for use in
@ -110,19 +139,27 @@ func encodeMessageB(writer io.Writer, sizeLimit int, method uint16, data []byte)
return err return err
} }
func decodeMessageB(reader io.Reader, sizeLimit int) (uint16, []byte, error) { func decodeMessageB(
reader io.Reader,
sizeLimit int,
) (
method uint16,
size int64,
data io.Reader,
err error,
) {
headerBuffer := [10]byte { } headerBuffer := [10]byte { }
_, err := io.ReadFull(reader, headerBuffer[:]) _, err = io.ReadFull(reader, headerBuffer[:])
if err != nil { return 0, nil, err } if err != nil { return 0, 0, nil, err }
method, err := tape.DecodeI16[uint16](headerBuffer[:2]) method, err = tape.DecodeI16[uint16](headerBuffer[:2])
if err != nil { return 0, nil, err } if err != nil { return 0, 0, nil, err }
length, err := tape.DecodeI64[uint64](headerBuffer[2:10]) length, err := tape.DecodeI64[uint64](headerBuffer[2:10])
if err != nil { return 0, nil, err } if err != nil { return 0, 0, nil, err }
if length > uint64(sizeLimit) { if length > uint64(sizeLimit) {
return 0, nil, ErrPayloadTooLarge return 0, 0, nil, ErrPayloadTooLarge
} }
payloadBuffer := make([]byte, int(length)) return method, int64(length), &io.LimitedReader {
_, err = io.ReadFull(reader, payloadBuffer) R: reader,
if err != nil { return 0, nil, err } N: int64(length),
return method, payloadBuffer, nil }, nil
} }