diff --git a/vm.go b/vm.go index 783a54a..af8783f 100644 --- a/vm.go +++ b/vm.go @@ -80,14 +80,14 @@ func (v *VM) Run() (err error) { if err != nil { filePos := v.fileSet.Position( v.curFrame.fn.SourcePos(v.ip - 1)) - err = fmt.Errorf("Runtime Error: %s\n\tat %s", - err.Error(), filePos) + err = fmt.Errorf("Runtime Error: %w\n\tat %s", + err, filePos) for v.framesIndex > 1 { v.framesIndex-- v.curFrame = &v.frames[v.framesIndex-1] filePos = v.fileSet.Position( v.curFrame.fn.SourcePos(v.curFrame.ip - 1)) - err = fmt.Errorf("%s\n\tat %s", err.Error(), filePos) + err = fmt.Errorf("%w\n\tat %s", err, filePos) } return err } diff --git a/vm_test.go b/vm_test.go index 60e17b1..c26bf0c 100644 --- a/vm_test.go +++ b/vm_test.go @@ -91,6 +91,19 @@ func (o *testopts) Skip2ndPass() *testopts { return c } +type customError struct { + err error + str string +} + +func (c *customError) Error() string { + return c.str +} + +func (c *customError) Unwrap() error { + return c.err +} + func TestArray(t *testing.T) { expectRun(t, `out = [1, 2 * 2, 3 + 3]`, nil, ARR{1, 4, 6}) @@ -912,6 +925,82 @@ export func() { }`), "Runtime Error: invalid operation: int + string\n\tat mod2:4:9") } +func TestVMErrorUnwrap(t *testing.T) { + userErr := errors.New("user runtime error") + userFunc := func(err error) *tengo.UserFunction { + return &tengo.UserFunction{Name: "user_func", Value: func(args ...tengo.Object) (tengo.Object, error) { + return nil, err + }} + } + userModule := func(err error) *tengo.BuiltinModule { + return &tengo.BuiltinModule{ + Attrs: map[string]tengo.Object{ + "afunction": &tengo.UserFunction{ + Name: "afunction", + Value: func(a ...tengo.Object) (tengo.Object, error) { + return nil, err + }, + }, + }, + } + } + + expectError(t, `user_func()`, + Opts().Symbol("user_func", userFunc(userErr)), + "Runtime Error: "+userErr.Error(), + ) + expectErrorIs(t, `user_func()`, + Opts().Symbol("user_func", userFunc(userErr)), + userErr, + ) + + wrapUserErr := &customError{err: userErr, str: "custom error"} + + expectErrorIs(t, `user_func()`, + Opts().Symbol("user_func", userFunc(wrapUserErr)), + wrapUserErr, + ) + expectErrorIs(t, `user_func()`, + Opts().Symbol("user_func", userFunc(wrapUserErr)), + userErr, + ) + var asErr1 *customError + expectErrorAs(t, `user_func()`, + Opts().Symbol("user_func", userFunc(wrapUserErr)), + &asErr1, + ) + require.True(t, asErr1.Error() == wrapUserErr.Error(), + "expected error as:%v, got:%v", wrapUserErr, asErr1) + + expectError(t, `import("mod1").afunction()`, + Opts().Module("mod1", userModule(userErr)), + "Runtime Error: "+userErr.Error(), + ) + expectErrorIs(t, `import("mod1").afunction()`, + Opts().Module("mod1", userModule(userErr)), + userErr, + ) + expectError(t, `import("mod1").afunction()`, + Opts().Module("mod1", userModule(wrapUserErr)), + "Runtime Error: "+wrapUserErr.Error(), + ) + expectErrorIs(t, `import("mod1").afunction()`, + Opts().Module("mod1", userModule(wrapUserErr)), + wrapUserErr, + ) + expectErrorIs(t, `import("mod1").afunction()`, + Opts().Module("mod1", userModule(wrapUserErr)), + userErr, + ) + var asErr2 *customError + expectErrorAs(t, `import("mod1").afunction()`, + Opts().Module("mod1", userModule(wrapUserErr)), + &asErr2, + ) + require.True(t, asErr2.Error() == wrapUserErr.Error(), + "expected error as:%v, got:%v", wrapUserErr, asErr2) +} + func TestError(t *testing.T) { expectRun(t, `out = error(1)`, nil, errorObject(1)) expectRun(t, `out = error(1).value`, nil, 1) @@ -3266,6 +3355,60 @@ func expectError( expected, err.Error(), strings.Join(trace, "\n")) } +func expectErrorIs( + t *testing.T, + input string, + opts *testopts, + expected error, +) { + if opts == nil { + opts = Opts() + } + symbols := opts.symbols + modules := opts.modules + maxAllocs := opts.maxAllocs + + // parse + program := parse(t, input) + if program == nil { + return + } + + // compiler/VM + _, trace, err := traceCompileRun(program, symbols, modules, maxAllocs) + require.Error(t, err, "\n"+strings.Join(trace, "\n")) + require.True(t, errors.Is(err, expected), + "expected error is: %s, got: %s\n%s", + expected.Error(), err.Error(), strings.Join(trace, "\n")) +} + +func expectErrorAs( + t *testing.T, + input string, + opts *testopts, + expected interface{}, +) { + if opts == nil { + opts = Opts() + } + symbols := opts.symbols + modules := opts.modules + maxAllocs := opts.maxAllocs + + // parse + program := parse(t, input) + if program == nil { + return + } + + // compiler/VM + _, trace, err := traceCompileRun(program, symbols, modules, maxAllocs) + require.Error(t, err, "\n"+strings.Join(trace, "\n")) + require.True(t, errors.As(err, expected), + "expected error as: %v, got: %v\n%s", + expected, err, strings.Join(trace, "\n")) +} + type vmTracer struct { Out []string }