From ae79a323091350ca937c697816c79b38d6785114 Mon Sep 17 00:00:00 2001 From: Sasha Koshka Date: Thu, 28 Aug 2025 12:51:55 -0400 Subject: [PATCH] generate: Respect limits when statically encoding --- generate/generate.go | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/generate/generate.go b/generate/generate.go index 0d70fce..9e0d132 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -343,7 +343,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri case TypeBuffer: // SBA: * // LBA: * - nn, err := this.iprintf("if %s.Is(tape.LBA) {\n", tagSource) + nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s.Is(tape.LBA) {\n", tagSource) n += nn; if err != nil { return n, err } this.push() nn, err = this.iprintf( @@ -355,14 +362,20 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } - nn, err = this.iprintf("nn, err = encoder.Write([]byte(%s))\n", valueSource) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() n += nn; if err != nil { return n, err } case TypeArray: // OTA: * - nn, err := this.iprintf( + nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( "nn, err = encoder.WriteUintN(uint64(len(%s)), %s.CN())\n", valueSource, tagSource) n += nn; if err != nil { return n, err } @@ -417,7 +430,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } case TypeTable: // KTV: ( )* - nn, err := this.iprintf( + nn, err := this.iprintf("if len(%s) > tape.MaxStructureLength {\n", valueSource) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( "nn, err = tape.EncodeAny(encoder, %s, %s)\n", valueSource, tagSource) n += nn; if err != nil { return n, err } @@ -425,7 +445,14 @@ func (this *Generator) generateEncodeValue(typ Type, valueSource, tagSource stri n += nn; if err != nil { return n, err } case TypeTableDefined: // KTV: ( )* - nn, err := this.iprintf( + nn, err := this.iprintf("if %d > tape.MaxStructureLength {\n", len(typ.Fields)) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } + nn, err = this.iprintf( "nn, err = encoder.WriteUintN(%d, %s.CN())\n", len(typ.Fields), tagSource) n += nn; if err != nil { return n, err }