Fix METADAPT-A not ever closing the connection properly

This commit is contained in:
2025-10-27 22:32:47 -04:00
parent 3daa66c4bc
commit 26b8174f92
2 changed files with 14 additions and 4 deletions

View File

@@ -6,6 +6,7 @@ import "fmt"
import "net" import "net"
import "sync" import "sync"
import "time" import "time"
import "context"
import "sync/atomic" import "sync/atomic"
import "git.tebibyte.media/sashakoshka/go-util/sync" import "git.tebibyte.media/sashakoshka/go-util/sync"
@@ -39,20 +40,23 @@ type a struct {
sendLock sync.Mutex sendLock sync.Mutex
transMap map[int64] *transA transMap map[int64] *transA
transChan chan *transA transChan chan *transA
done chan struct { } ctx context.Context
done func()
err error err error
} }
// AdaptA returns a connection implementing METADAPT-A over a singular stream- // AdaptA returns a connection implementing METADAPT-A over a singular stream-
// oriented transport such as TCP or UNIX domain stream sockets. // oriented transport such as TCP or UNIX domain stream sockets.
func AdaptA(underlying net.Conn, party Party) Conn { func AdaptA(underlying net.Conn, party Party) Conn {
ctx, done := context.WithCancel(context.Background())
conn := &a { conn := &a {
sizeLimit: defaultSizeLimit, sizeLimit: defaultSizeLimit,
underlying: underlying, underlying: underlying,
party: party, party: party,
transMap: make(map[int64] *transA), transMap: make(map[int64] *transA),
transChan: make(chan *transA), transChan: make(chan *transA),
done: make(chan struct { }), ctx: ctx,
done: done,
} }
if party == ClientSide { if party == ClientSide {
conn.transID = 1 conn.transID = 1
@@ -60,11 +64,15 @@ func AdaptA(underlying net.Conn, party Party) Conn {
conn.transID = -1 conn.transID = -1
} }
go conn.receive() go conn.receive()
go func() {
<- ctx.Done()
underlying.Close()
}()
return conn return conn
} }
func (this *a) Close() error { func (this *a) Close() error {
close(this.done) this.done()
return nil return nil
} }
@@ -105,7 +113,7 @@ func (this *a) AcceptTrans() (Trans, error) {
return nil, eof return nil, eof
} }
return trans, nil return trans, nil
case <- this.done: case <- this.ctx.Done():
return nil, eof return nil, eof
} }
} }
@@ -490,6 +498,7 @@ func decodeMessageA(
headerBuffer := [18]byte { } headerBuffer := [18]byte { }
_, err = io.ReadFull(reader, headerBuffer[:]) _, err = io.ReadFull(reader, headerBuffer[:])
if err != nil { return 0, 0, false, nil, err } if err != nil { return 0, 0, false, nil, err }
transID, err = decodeI64[int64](headerBuffer[:8]) transID, err = decodeI64[int64](headerBuffer[:8])
if err != nil { return 0, 0, false, nil, err } if err != nil { return 0, 0, false, nil, err }
method, err = decodeI16[uint16](headerBuffer[8:10]) method, err = decodeI16[uint16](headerBuffer[8:10])

View File

@@ -52,6 +52,7 @@ func TestConnA(test *testing.T) {
test.Error("CLIENT payload:", gotPayload) test.Error("CLIENT payload:", gotPayload)
test.Fatal("CLIENT ok byeeeeeeeeeeeee") test.Fatal("CLIENT ok byeeeeeeeeeeeee")
} }
test.Log("CLIENT transaction has closed")
} }
serverFunc := func(a Conn) { serverFunc := func(a Conn) {