diff --git a/generate/generate.go b/generate/generate.go index 9e0d132..5e85502 100644 --- a/generate/generate.go +++ b/generate/generate.go @@ -595,6 +595,13 @@ func (this *Generator) generateDecodeValue(typ Type, typeName, valueSource, tagS this.pop() nn, err = this.iprintf("}\n") n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } nn, err = this.iprintf("buffer := make([]byte, int(%s))\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.iprintf("nn, err = decoder.Read(buffer)\n") @@ -697,6 +704,13 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st lengthVar := this.newTemporaryVar("length") nn, err := this.iprintf("var %s uint64\n", lengthVar) n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck() @@ -766,6 +780,13 @@ func (this *Generator) generateDecodeBranch(hash [16]byte, typ Type, typeName st lengthVar := this.newTemporaryVar("length") nn, err := this.iprintf("var %s uint64\n", lengthVar) n += nn; if err != nil { return n, err } + nn, err = this.iprintf("if %s > uint64(tape.MaxStructureLength) {\n", lengthVar) + n += nn; if err != nil { return n, err } + this.push() + nn, err = this.iprintf("return n, tape.ErrTooLong\n") + this.pop() + nn, err = this.iprintf("}\n") + n += nn; if err != nil { return n, err } nn, err = this.iprintf("%s, nn, err = decoder.ReadUintN(int(tag.CN()))\n", lengthVar) n += nn; if err != nil { return n, err } nn, err = this.generateErrorCheck()