From 01fe30f02a67b13d3a79e0ce0e930cd864707849 Mon Sep 17 00:00:00 2001 From: Daniel Date: Sat, 23 Mar 2019 12:59:54 -0700 Subject: [PATCH] bug fix for function return and if statement (#160) * fix a bug where function return instruction is missing with if statement * remove script cancel context test --- compiler/compiler.go | 39 ++++++++++++++++++++++++++++++++++--- compiler/compiler_module.go | 5 +---- compiler/instructions.go | 13 +++++++++++++ runtime/vm_function_test.go | 11 +++++++++++ runtime/vm_if_test.go | 4 ++++ runtime/vm_module_test.go | 8 ++++++++ script/compiled_test.go | 12 +----------- 7 files changed, 74 insertions(+), 18 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index b00fef6..3345cd8 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -405,9 +405,7 @@ func (c *Compiler) Compile(node ast.Node) error { } // add OpReturn if function returns nothing - if !c.lastInstructionIs(OpReturn) { - c.emit(node, OpReturn, 0) - } + c.fixReturn(node) freeSymbols := c.symbolTable.FreeSymbols() numLocals := c.symbolTable.MaxSymbols() @@ -733,6 +731,41 @@ func (c *Compiler) changeOperand(opPos int, operand ...int) { c.replaceInstruction(opPos, inst) } +// fixReturn appends "return" statement at the end of the function if +// 1) the function does not have a "return" statement at the end. +// 2) or, there are jump instructions that jump to the end of the function. +func (c *Compiler) fixReturn(node ast.Node) { + var appendReturn bool + + if !c.lastInstructionIs(OpReturn) { + appendReturn = true + } else { + var lastOp Opcode + insts := c.scopes[c.scopeIndex].instructions + endPos := len(insts) + iterateInstructions(insts, func(pos int, opcode Opcode, operands []int) bool { + defer func() { lastOp = opcode }() + + switch opcode { + case OpJump, OpJumpFalsy, OpAndJump, OpOrJump: + dst := operands[0] + if dst == endPos && lastOp != OpReturn { + appendReturn = true + return false + } else if dst > endPos { + panic(fmt.Errorf("wrong jump position: %d (end: %d)", dst, endPos)) + } + } + + return true + }) + } + + if appendReturn { + c.emit(node, OpReturn, 0) + } +} + func (c *Compiler) emit(node ast.Node, opcode Opcode, operands ...int) int { filePos := source.NoPos if node != nil { diff --git a/compiler/compiler_module.go b/compiler/compiler_module.go index 55590e0..2c2bb5a 100644 --- a/compiler/compiler_module.go +++ b/compiler/compiler_module.go @@ -49,10 +49,7 @@ func (c *Compiler) compileModule(node ast.Node, moduleName, modulePath string, s return nil, err } - // add OpReturn (== export undefined) if export is missing - if !moduleCompiler.lastInstructionIs(OpReturn) { - moduleCompiler.emit(nil, OpReturn) - } + moduleCompiler.fixReturn(node) compiledFunc := moduleCompiler.Bytecode().MainFunction compiledFunc.NumLocals = symbolTable.MaxSymbols() diff --git a/compiler/instructions.go b/compiler/instructions.go index b04b282..80c88d1 100644 --- a/compiler/instructions.go +++ b/compiler/instructions.go @@ -57,3 +57,16 @@ func FormatInstructions(b []byte, posOffset int) []string { return out } + +func iterateInstructions(b []byte, fn func(pos int, opcode Opcode, operands []int) bool) { + for i := 0; i < len(b); i++ { + numOperands := OpcodeOperands[Opcode(b[i])] + operands, read := ReadOperands(numOperands, b[i+1:]) + + if !fn(i, b[i], operands) { + break + } + + i += read + } +} diff --git a/runtime/vm_function_test.go b/runtime/vm_function_test.go index f4c25f1..696ee4b 100644 --- a/runtime/vm_function_test.go +++ b/runtime/vm_function_test.go @@ -253,4 +253,15 @@ func() { }() }() }()`, nil, 15) + + // function skipping return + expect(t, `out = func() {}()`, nil, objects.UndefinedValue) + expect(t, `out = func(v) { if v { return true } }(1)`, nil, true) + expect(t, `out = func(v) { if v { return true } }(0)`, nil, objects.UndefinedValue) + expect(t, `out = func(v) { if v { } else { return true } }(1)`, nil, objects.UndefinedValue) + expect(t, `out = func(v) { if v { return } }(1)`, nil, objects.UndefinedValue) + expect(t, `out = func(v) { if v { return } }(0)`, nil, objects.UndefinedValue) + expect(t, `out = func(v) { if v { } else { return } }(1)`, nil, objects.UndefinedValue) + expect(t, `out = func(v) { for ;;v++ { if v == 3 { return true } } }(1)`, nil, true) + expect(t, `out = func(v) { for ;;v++ { if v == 3 { break } } }(1)`, nil, objects.UndefinedValue) } diff --git a/runtime/vm_if_test.go b/runtime/vm_if_test.go index 864d1de..5dd392d 100644 --- a/runtime/vm_if_test.go +++ b/runtime/vm_if_test.go @@ -61,4 +61,8 @@ func() { out = a }() `, nil, 3) + + // expression statement in init (should not leave objects on stack) + expect(t, `a := 1; if a; a { out = a }`, nil, 1) + expect(t, `a := 1; if a + 4; a { out = a }`, nil, 1) } diff --git a/runtime/vm_module_test.go b/runtime/vm_module_test.go index 6310ac1..97a355d 100644 --- a/runtime/vm_module_test.go +++ b/runtime/vm_module_test.go @@ -156,6 +156,14 @@ export func(a) { a() } `), "Runtime Error: not callable: int\n\tat mod1:3:4\n\tat test:4:1") + + // module skipping export + expect(t, `out = import("mod0")`, Opts().Module("mod0", ``), objects.UndefinedValue) + expect(t, `out = import("mod0")`, Opts().Module("mod0", `if 1 { export true }`), true) + expect(t, `out = import("mod0")`, Opts().Module("mod0", `if 0 { export true }`), objects.UndefinedValue) + expect(t, `out = import("mod0")`, Opts().Module("mod0", `if 1 { } else { export true }`), objects.UndefinedValue) + expect(t, `out = import("mod0")`, Opts().Module("mod0", `for v:=0;;v++ { if v == 3 { export true } } }`), true) + expect(t, `out = import("mod0")`, Opts().Module("mod0", `for v:=0;;v++ { if v == 3 { break } } }`), objects.UndefinedValue) } func TestModuleBlockScopes(t *testing.T) { diff --git a/script/compiled_test.go b/script/compiled_test.go index e4f3084..13d2d76 100644 --- a/script/compiled_test.go +++ b/script/compiled_test.go @@ -84,19 +84,9 @@ func TestCompiled_RunContext(t *testing.T) { assert.NoError(t, err) compiledGet(t, c, "a", int64(5)) - // cancelled - c = compile(t, `for true {}`, nil) - ctx, cancel := context.WithCancel(context.Background()) - go func() { - time.Sleep(1 * time.Millisecond) - cancel() - }() - err = c.RunContext(ctx) - assert.Equal(t, context.Canceled, err) - // timeout c = compile(t, `for true {}`, nil) - ctx, cancel = context.WithTimeout(context.Background(), 1*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) defer cancel() err = c.RunContext(ctx) assert.Equal(t, context.DeadlineExceeded, err)