diff --git a/connection.go b/connection.go index 88ec7d4..07eb829 100644 --- a/connection.go +++ b/connection.go @@ -2,7 +2,7 @@ package hopp import "io" import "net" -// import "time" +import "time" const defaultSizeLimit int64 = 1024 * 1024 // 1 megabyte @@ -23,6 +23,9 @@ type Conn interface { // be called in a loop to avoid the connection locking up. AcceptTrans() (Trans, error) + // SetDeadline operates is [net.Conn.SetDeadline] but for OpenTrans + // and AcceptTrans calls. + SetDeadline(t time.Time) error // SetSizeLimit sets a limit (in bytes) for how large messages can be. // By default, this limit is 1 megabyte. Note that this is only // enforced when sending and receiving byte slices, and it does not @@ -41,8 +44,6 @@ type Trans interface { // ID returns the transaction ID. This must not change, and it must be // unique within the connection. This method is safe for concurrent use. ID() int64 - - // TODO: add methods for setting send and receive deadlines // Send sends a message. This method is not safe for concurrent use. Send(method uint16, data []byte) error @@ -59,4 +60,12 @@ type Trans interface { // previously opened through this function will be discarded. This // method is not safe for concurrent use, and neither is its result. ReceiveReader() (method uint16, data io.Reader, err error) + + // See the documentation for [net.Conn.SetDeadline]. + SetDeadline(time.Time) error + // TODO + // // See the documentation for [net.Conn.SetReadDeadline]. + // SetReadDeadline(t time.Time) error + // // See the documentation for [net.Conn.SetWriteDeadline]. + // SetWriteDeadline(t time.Time) error } diff --git a/metadapta.go b/metadapta.go index 4cbe2ce..cae8b12 100644 --- a/metadapta.go +++ b/metadapta.go @@ -1,9 +1,11 @@ package hopp import "io" +import "os" import "fmt" import "net" import "sync" +import "time" import "sync/atomic" import "git.tebibyte.media/sashakoshka/go-util/sync" @@ -108,6 +110,10 @@ func (this *a) AcceptTrans() (Trans, error) { } } +func (this *a) SetDeadline(t time.Time) error { + return this.underlying.SetDeadline(t) +} + func (this *a) SetSizeLimit(limit int64) { this.sizeLimit = limit } @@ -212,6 +218,10 @@ type transA struct { currentWriter io.Closer writeBuffer []byte closed atomic.Bool + closeErr error + + deadline *time.Timer + deadlineLock sync.Mutex } func (this *transA) Close() error { @@ -221,6 +231,11 @@ func (this *transA) Close() error { return err } +func (this *transA) closeWithError(err error) error { + this.closeErr = err + return this.Close() +} + func (this *transA) closeDontUnlist() (err error) { // MUST be goroutine safe this.incoming.Close() @@ -269,9 +284,9 @@ func (this *transA) Receive() (method uint16, data []byte, err error) { } func (this *transA) ReceiveReader() (uint16, io.Reader, error) { - // if the transaction has been closed, return an io.EOF - if this.closed.Load() { - return 0, nil, io.EOF + // if the transaction has been closed, return an appropriate error. + if err := this.errIfClosed(); err != nil { + return 0, nil, err } // drain previous reader if necessary @@ -289,6 +304,54 @@ func (this *transA) ReceiveReader() (uint16, io.Reader, error) { return method, reader, nil } +func (this *transA) SetDeadline(t time.Time) error { + this.deadlineLock.Lock() + defer this.deadlineLock.Unlock() + + if t == (time.Time { }) { + if this.deadline != nil { + this.deadline.Stop() + } + return nil + } + + until := time.Until(t) + if this.deadline == nil { + this.deadline.Reset(until) + return nil + } + this.deadline = time.AfterFunc(until, func () { + this.closeWithError(os.ErrDeadlineExceeded) + }) + return nil +} + +// TODO +// func (this *transA) SetReadDeadline(t time.Time) error { +// // TODO +// } +// +// func (this *transA) SetWriteDeadline(t time.Time) error { +// // TODO +// } + +func (this *transA) errIfClosed() error { + if !this.closed.Load() { + return nil + } + return this.bestErr() +} + +func (this *transA) bestErr() error { + if this.parent.err != nil { + return this.parent.err + } + if this.closeErr != nil { + return this.closeErr + } + return io.EOF +} + type readerA struct { parent *transA leftover []byte @@ -319,11 +382,7 @@ func (this *readerA) pull() (uint16, error) { // close and return error on failure this.eof = true this.parent.Close() - if this.parent.parent.err == nil { - return 0, fmt.Errorf("could not receive message: %w", io.EOF) - } else { - return 0, this.parent.parent.err - } + return 0, fmt.Errorf("could not receive message: %w", this.parent.bestErr()) } func (this *readerA) Read(buffer []byte) (int, error) { diff --git a/metadaptb.go b/metadaptb.go index b2f96e6..e1118fe 100644 --- a/metadaptb.go +++ b/metadaptb.go @@ -2,6 +2,7 @@ package hopp import "io" import "net" +import "time" import "bytes" import "errors" import "context" @@ -50,6 +51,10 @@ func (this *b) SetSizeLimit(limit int64) { this.sizeLimit = limit } +func (this *b) SetDeadline(t time.Time) error { + return this.underlying.SetDeadline(t) +} + func (this *b) newTrans(underlying Stream) *transB { return &transB { sizeLimit: this.sizeLimit, @@ -124,6 +129,10 @@ func (this *transB) receiveReader() (uint16, int64, io.Reader, error) { return method, size, data, nil } +func (this *transB) SetDeadline(t time.Time) error { + return this.underlying.SetDeadline(t) +} + type writerB struct { parent *transB buffer bytes.Buffer @@ -149,12 +158,16 @@ type MultiConn interface { AcceptStream(context.Context) (Stream, error) // OpenStream opens a new stream. OpenStream() (Stream, error) + // See the documentation for [net.Conn.SetDeadline]. + SetDeadline(time.Time) error } // Stream represents a single stream returned by a [MultiConn]. type Stream interface { // See documentation for [net.Conn]. io.ReadWriteCloser + // See the documentation for [net.Conn.SetDeadline]. + SetDeadline(time.Time) error // ID returns the stream ID ID() int64 }