generate: Don't create a new decoder for a possibly nil reader

This commit is contained in:
Sasha Koshka 2025-11-19 13:13:40 -05:00
parent 3136dcbfdf
commit 0ac34b2f22
2 changed files with 22 additions and 4 deletions

View File

@ -1212,10 +1212,10 @@ func (this *Generator) generateReceive() (n int, err error) {
this.push()
nn, err = this.iprintf("method, reader, err := trans.ReceiveReader()\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("decoder := tape.NewDecoder(reader)\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("if err != nil { return nil, n, err }\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("decoder := tape.NewDecoder(reader)\n")
n += nn; if err != nil { return n, err }
nn, err = this.iprintf("switch method {\n")
n += nn; if err != nil { return n, err }
for method, message := range this.protocol.Messages {

View File

@ -547,7 +547,7 @@ func TestGenerateRunConn(test *testing.T) {
group.Add(2)
// server
listener, err := hopp.Listen("tcp", "localhost:9999", nil)
listener, err := hopp.Listen("tcp", "localhost:43957", nil)
if err != nil { log.Fatalln("SERVER listen:", err) }
go func() {
defer listener.Close()
@ -557,6 +557,7 @@ func TestGenerateRunConn(test *testing.T) {
if err != nil { log.Fatalln("SERVER accept:", err) }
trans, err := conn.AcceptTrans()
if err != nil { log.Fatalln("SERVER accept trans:", err) }
message, n, err := Receive(trans)
if err != nil { log.Fatalln("SERVER receive:", err) }
log.Println("SERVER got message", message)
@ -565,6 +566,15 @@ func TestGenerateRunConn(test *testing.T) {
if !ok { log.Fatalln("SERVER expected MessagePong") }
if casted != 77 { log.Fatalln("SERVER wrong message value") }
if n != 5 { log.Fatalln("SERVER wrong n value") }
message, n, err = Receive(trans)
if err != nil { log.Fatalln("SERVER receive:", err) }
log.Println("SERVER got message", message)
log.Println("SERVER got n", n)
casted, ok = message.(MessagePing)
if !ok { log.Fatalln("SERVER expected MessagePong") }
if casted != 78 { log.Fatalln("SERVER wrong message value") }
if n != 5 { log.Fatalln("SERVER wrong n value") }
}()
// client
@ -574,7 +584,7 @@ func TestGenerateRunConn(test *testing.T) {
log.Println("CLIENT dialing")
conn, err := hopp.Dial(
context.Background(),
"tcp", "localhost:9999",
"tcp", "localhost:43957",
nil)
if err != nil { log.Fatalln("CLIENT dial:", err) }
defer conn.Close()
@ -583,12 +593,20 @@ func TestGenerateRunConn(test *testing.T) {
log.Println("CLIENT opening trans")
trans, err := conn.OpenTrans()
if err != nil { log.Fatalln("CLIENT open trans:", err) }
message := MessagePing(77)
log.Println("CLIENT sending message")
n, err := Send(trans, &message)
if err != nil { log.Fatalln("CLIENT send:", err) }
log.Println("CLIENT sent n", n)
if n != 5 { log.Fatalln("CLIENT wrong n value") }
message = MessagePing(78)
log.Println("CLIENT sending message")
n, err = Send(trans, &message)
if err != nil { log.Fatalln("CLIENT send:", err) }
log.Println("CLIENT sent n", n)
if n != 5 { log.Fatalln("CLIENT wrong n value") }
}()
group.Wait()