fix a bug in tail-call optimization code

This commit is contained in:
Daniel Kang 2019-01-13 02:24:32 -08:00
parent adbbf419a0
commit 967ed03ccc
3 changed files with 23 additions and 3 deletions

View file

@ -894,6 +894,8 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje
// |--------|
// | *ARG1 | for next function (tail-call)
// |--------|
// | FUNC | function itself
// |--------|
// | LOCAL3 | for current function
// |--------|
// | LOCAL2 | for current function
@ -904,7 +906,7 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje
// |--------|
copy(v.stack[curFrame.basePointer:], v.stack[v.sp-numArgs:v.sp])
v.sp -= numArgs
v.sp -= numArgs + 1
curFrame.ip = -1
// stack after tail-call
@ -914,7 +916,9 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje
// |--------|
// | *ARG2 |
// |--------|
// | *ARG1 | <- SP current
// | *ARG1 |
// |--------|
// | FUNC | <- SP current
// |--------|
// | LOCAL3 | for current function
// |--------|

View file

@ -66,6 +66,19 @@ func TestTailCall(t *testing.T) {
return f2(5, 0)
}
out = f1()`, 15)
// tail-call replacing loop
// without tail-call optimization, this code will cause stack overflow
expect(t, `
iter := func(n, max) {
if n == max {
return n
}
return iter(n+1, max)
}
out = iter(0, 9999)
`, 9999)
}
// tail call with free vars

View file

@ -148,7 +148,7 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object) (res map
var ipstr string
if v != nil {
frameIdx, ip := v.FrameInfo()
ipstr = fmt.Sprintf("\n (Frame=%d, IP=%d)", frameIdx, ip)
ipstr = fmt.Sprintf("\n (Frame=%d, IP=%d)", frameIdx, ip+1)
}
trace = append(trace, fmt.Sprintf("[Panic]\n\n %v%s\n", e, ipstr))
trace = append(trace, fmt.Sprintf("[Error Trace]\n\n %s\n", strings.Join(stackTrace, "\n ")))
@ -216,6 +216,9 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object) (res map
}
}
trace = append(trace, fmt.Sprintf("\n[Globals]\n\n%s", strings.Join(globalsStr, "\n")))
frameIdx, ip := v.FrameInfo()
trace = append(trace, fmt.Sprintf("\n[IP]\n\nFrame=%d, IP=%d", frameIdx, ip+1))
}
if err != nil {
return