diff --git a/metadapta.go b/metadapta.go index 8250478..9739500 100644 --- a/metadapta.go +++ b/metadapta.go @@ -4,9 +4,13 @@ import "io" import "fmt" import "net" import "sync" +import "sync/atomic" import "git.tebibyte.media/sashakoshka/hopp/tape" import "git.tebibyte.media/sashakoshka/go-util/sync" +// TODO investigate why 30 never reaches the server, causing it to wait for ever +// and never close the connection, causing the client to also wait forever + const closeMethod = 0xFFFF const int64Max = int64((^uint64(0)) >> 1) const defaultChunkSize = 0x1000 @@ -17,6 +21,14 @@ type Party bool; const ( ClientSide Party = true ) +func (party Party) String() string { + if party == ServerSide { + return "server" + } else { + return "client" + } +} + type a struct { sizeLimit int64 underlying net.Conn @@ -52,7 +64,7 @@ func AdaptA(underlying net.Conn, party Party) Conn { func (this *a) Close() error { close(this.done) - return this.underlying.Close() + return nil } func (this *a) LocalAddr() net.Addr { @@ -85,11 +97,15 @@ func (this *a) OpenTrans() (Trans, error) { } func (this *a) AcceptTrans() (Trans, error) { + eof := fmt.Errorf("could not accept transaction: %w", io.EOF) select { case trans := <- this.transChan: + if trans == nil { + return nil, eof + } return trans, nil case <- this.done: - return nil, fmt.Errorf("could not accept transaction: %w", io.EOF) + return nil, eof } } @@ -119,7 +135,11 @@ func (this *a) receive() { trans.closeDontUnlist() } clear(this.transMap) + this.underlying.Close() }() + + // receive MMBs in a loop and forward them to transactions until shit + // starts closing for { transID, method, chunked, payload, err := decodeMessageA(this.underlying, this.sizeLimit) if err != nil { @@ -127,7 +147,7 @@ func (this *a) receive() { return } - err = this.receiveMultiplex(transID, method, chunked, payload) + err = this.multiplexMMB(transID, method, chunked, payload) if err != nil { this.err = fmt.Errorf("could not receive message: %w", err) return @@ -135,7 +155,7 @@ func (this *a) receive() { } } -func (this *a) receiveMultiplex(transID int64, method uint16, chunked bool, payload []byte) error { +func (this *a) multiplexMMB(transID int64, method uint16, chunked bool, payload []byte) error { if transID == 0 { return ErrMessageMalformed } trans, err := func() (*transA, error) { @@ -144,6 +164,12 @@ func (this *a) receiveMultiplex(transID int64, method uint16, chunked bool, payl trans, ok := this.transMap[transID] if !ok { + // check if this is a superfluous close message and just + // do nothing if so + if method == closeMethod { + return nil, nil + } + // it is forbidden for the other party to initiate a transaction // with an ID from this party if this.party == partyFromTransID(transID) { @@ -161,14 +187,24 @@ func (this *a) receiveMultiplex(transID int64, method uint16, chunked bool, payl }() if err != nil { return err } - trans.incoming.Send(incomingMessage { - method: method, - chunked: chunked, - payload: payload, - }) + if trans == nil { + return nil + } + + if method == closeMethod { + return trans.Close() + } else { + trans.incoming.Send(incomingMessage { + method: method, + chunked: chunked, + payload: payload, + }) + } return nil } +// most methods in transA don't need to be goroutine safe except those marked +// as such type transA struct { parent *a id int64 @@ -176,18 +212,24 @@ type transA struct { currentReader io.Reader currentWriter io.Closer writeBuffer []byte + closed atomic.Bool } func (this *transA) Close() error { + // MUST be goroutine safe err := this.closeDontUnlist() this.parent.unlistTransactionSafe(this.ID()) return err } -func (this *transA) closeDontUnlist() error { - this.Send(closeMethod, nil) - this.parent.sendMessageSafe(this.id, 0xFFFF, nil) - return this.incoming.Close() +func (this *transA) closeDontUnlist() (err error) { + // MUST be goroutine safe + this.incoming.Close() + if !this.closed.Load() { + err = this.Send(closeMethod, nil) + } + this.closed.Store(true) + return err } func (this *transA) ID() int64 { @@ -228,6 +270,11 @@ func (this *transA) Receive() (method uint16, data []byte, err error) { } func (this *transA) ReceiveReader() (uint16, io.Reader, error) { + // if the transaction has been closed, return an io.EOF + if this.closed.Load() { + return 0, nil, io.EOF + } + // drain previous reader if necessary if this.currentReader != nil { io.Copy(io.Discard, this.currentReader) @@ -249,13 +296,14 @@ type readerA struct { eof bool } +// pull pulls the next MMB in this message from the transaction. func (this *readerA) pull() (uint16, error) { // if the previous message ended the chain, return an io.EOF if this.eof { return 0, io.EOF } - // get a message from the transaction we are a part of + // get an MMB from the transaction we are a part of receive := this.parent.incoming.Receive() if receive != nil { if message, ok := <- receive; ok { @@ -265,6 +313,9 @@ func (this *readerA) pull() (uint16, error) { this.eof = true } return message.method, nil + } else { + // signal parent transaction of closure + this.parent.closed.Store(true) } } } diff --git a/metadapta_test.go b/metadapta_test.go index c85d6d9..62dfdd9 100644 --- a/metadapta_test.go +++ b/metadapta_test.go @@ -45,9 +45,12 @@ func TestConnA(test *testing.T) { } } test.Log("CLIENT waiting for transaction close...") - _, _, err = trans.Receive() + gotMethod, gotPayload, err := trans.Receive() if !errors.Is(err, io.EOF) { - test.Fatal("CLIENT wrong error:", err) + test.Error("CLIENT wrong error:", err) + test.Error("CLIENT method:", gotMethod) + test.Error("CLIENT payload:", gotPayload) + test.Fatal("CLIENT ok byeeeeeeeeeeeee") } } @@ -66,6 +69,79 @@ func TestConnA(test *testing.T) { clientServerEnvironment(test, clientFunc, serverFunc) } +func TestTransOpenCloseA(test *testing.T) { + // currently: + // + // | data sent | data recvd | close sent | close recvd + // 10 | X | X | X | server hangs + // 20 | X | X | X | client hangs + // 30 | X | | X | + // + // when a close message is recvd, it tries to push to the trans and + // hangs on trans.incoming.Send, which hangs on sending the value to the + // underlying channel. why is this? + // + // check if we are really getting values from the channel when pulling + // from the trans channel when we are expecting a close. + + clientFunc := func(conn Conn) { + // 10 + trans, err := conn.OpenTrans() + if err != nil { test.Error("CLIENT", err); return } + test.Log("CLIENT sending 10") + trans.Send(10, []byte("hi")) + trans.Close() + + // 20 + test.Log("CLIENT awaiting 20") + trans, err = conn.AcceptTrans() + if err != nil { test.Error("CLIENT", err); return } + test.Cleanup(func() { trans.Close() }) + gotMethod, gotPayload, err := trans.Receive() + if err != nil { test.Error("CLIENT", err); return } + test.Logf("CLIENT m: %d p: %s", gotMethod, gotPayload) + if gotMethod != 20 { test.Error("CLIENT wrong method")} + + // 30 + trans, err = conn.OpenTrans() + if err != nil { test.Error("CLIENT", err); return } + test.Log("CLIENT sending 30") + trans.Send(30, []byte("good")) + trans.Close() + } + + serverFunc := func(conn Conn) { + // 10 + test.Log("SERVER awaiting 10") + trans, err := conn.AcceptTrans() + if err != nil { test.Error("SERVER", err); return } + test.Cleanup(func() { trans.Close() }) + gotMethod, gotPayload, err := trans.Receive() + if err != nil { test.Error("SERVER", err); return } + test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload) + if gotMethod != 10 { test.Error("SERVER wrong method")} + + // 20 + trans, err = conn.OpenTrans() + if err != nil { test.Error("SERVER", err); return } + test.Log("SERVER sending 20") + trans.Send(20, []byte("hi how r u")) + trans.Close() + + // 30 + test.Log("SERVER awaiting 30") + trans, err = conn.AcceptTrans() + if err != nil { test.Error("SERVER", err); return } + test.Cleanup(func() { trans.Close() }) + gotMethod, gotPayload, err = trans.Receive() + if err != nil { test.Error("SERVER", err); return } + test.Logf("SERVER m: %d p: %s", gotMethod, gotPayload) + if gotMethod != 30 { test.Error("SERVER wrong method")} + } + + clientServerEnvironment(test, clientFunc, serverFunc) +} + func TestEncodeMessageA(test *testing.T) { buffer := new(bytes.Buffer) payload := []byte { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 } @@ -127,6 +203,28 @@ func TestDecodeMessageAErr(test *testing.T) { } } +func TestEncodeDecodeMessageA(test *testing.T) { + correctTransID := int64(2) + correctMethod := uint16(30) + correctPayload := []byte("good") + buffer := bytes.Buffer { } + err := encodeMessageA(&buffer, defaultSizeLimit, correctTransID, correctMethod, correctPayload) + if err != nil { test.Fatal(err) } + transID, method, chunked, payload, err := decodeMessageA(&buffer, defaultSizeLimit) + if got, correct := transID, int64(2); got != correct { + test.Fatalf("not equal: %v %v", got, correct) + } + if got, correct := method, uint16(30); got != correct { + test.Fatalf("not equal: %v %v", got, correct) + } + if chunked { + test.Fatalf("message should not be chunked") + } + if got, correct := payload, correctPayload; !slices.Equal(got, correct) { + test.Fatalf("not equal: %v %v", got, correct) + } +} + func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), serverFunc func(conn Conn)) { network := "tcp" addr := "localhost:7959" @@ -145,6 +243,7 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server test.Cleanup(func() { a.Close() }) serverFunc(a) + test.Log("SERVER closing") }() // client @@ -158,9 +257,10 @@ func clientServerEnvironment(test *testing.T, clientFunc func(conn Conn), server clientFunc(a) test.Log("CLIENT waiting for connection close...") - _, err = a.AcceptTrans() + trans, err := a.AcceptTrans() if !errors.Is(err, io.EOF) { - test.Fatal("CLIENT wrong error:", err) + test.Error("CLIENT wrong error:", err) + test.Fatal("CLIENT trans:", trans) } test.Log("CLIENT DONE") conn.Close()