SetBuiltinFunctions and SetBuiltinModules (#120)

* `SetBuiltinFunctions` and `SetBuiltinModules`

* nil implies no built in functions.

* Additional tests.

* Cleanup

* Updated SetBuiltinFunctions

* Docs updated.
This commit is contained in:
earncef 2019-02-28 17:26:25 +01:00 committed by Daniel
parent 39112d226e
commit 7cc683e867
9 changed files with 185 additions and 135 deletions

View file

@ -207,7 +207,7 @@ func runVM(bytecode *compiler.Bytecode) (time.Duration, objects.Object, error) {
start := time.Now() start := time.Now()
v := runtime.NewVM(bytecode, globals, nil) v := runtime.NewVM(bytecode, globals, nil, nil)
if err := v.Run(); err != nil { if err := v.Run(); err != nil {
return time.Since(start), nil, err return time.Since(start), nil, err
} }

View file

@ -148,7 +148,7 @@ func compileAndRun(data []byte, inputFile string) (err error) {
return return
} }
machine := runtime.NewVM(bytecode, nil, nil) machine := runtime.NewVM(bytecode, nil, nil, nil)
err = machine.Run() err = machine.Run()
if err != nil { if err != nil {
@ -165,7 +165,7 @@ func runCompiled(data []byte) (err error) {
return return
} }
machine := runtime.NewVM(bytecode, nil, nil) machine := runtime.NewVM(bytecode, nil, nil, nil)
err = machine.Run() err = machine.Run()
if err != nil { if err != nil {
@ -216,7 +216,7 @@ func runREPL(in io.Reader, out io.Writer) {
bytecode := c.Bytecode() bytecode := c.Bytecode()
machine := runtime.NewVM(bytecode, globals, nil) machine := runtime.NewVM(bytecode, globals, nil, nil)
if err := machine.Run(); err != nil { if err := machine.Run(); err != nil {
_, _ = fmt.Fprintln(out, err.Error()) _, _ = fmt.Fprintln(out, err.Error())
continue continue

View file

@ -118,34 +118,38 @@ Users can add and use a custom user type in Tengo code by implementing [Object](
To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions. To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions.
#### Script.DisableBuiltinFunction(name string) #### Script.SetBuiltinFunctions(funcs []*objects.BuiltinFunction)
DisableBuiltinFunction disables and removes a builtin function from the compiler. Compiler will reports a compile-time error if the given name is referenced. SetBuiltinFunctions resets all builtin functions in the compiler to the ones provided in the input parameter. Compiler will report a compile-time error if the a function not set is referenced. All builtin functions are included by default unless `SetBuiltinFunctions` is called.
```golang ```golang
s := script.New([]byte(`print([1, 2, 3])`)) s := script.New([]byte(`print([1, 2, 3])`))
s.DisableBuiltinFunction("print") s.SetBuiltinFunctions(nil)
_, err := s.Run() // compile error _, err := s.Run() // compile error
s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[0]})
_, err := s.Run() // prints [1, 2, 3]
``` ```
Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled builtin function list from the main script. #### Script.SetBuiltinModules(modules map[string]*objects.ImmutableMap)
#### Script.DisableStdModule(name string) SetBuiltinModules resets all [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) modules with modules provided in the input parameter. Compile will report a compile-time error if the code tries to import a module that hasn't been included. All standard library modules are included by default unless `SetBuiltinModules` is called.
DisableStdModule disables a [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) module. Compile will report a compile-time error if the code tries to import the module with the given name.
```golang ```golang
s := script.New([]byte(`import("exec")`)) s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
s.DisableStdModule("exec") s.SetBuiltinModules(nil)
_, err := s.Run() // compile error _, err := s.Run() // compile error
s.SetBuiltinModules(map[string]*objects.ImmutableMap{"math": objectPtr(*stdlib.Modules["math"])})
_, err := s.Run() // a = 19.84
``` ```
Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled standard module list from the main script.
#### Script.SetUserModuleLoader(loader compiler.ModuleLoader) #### Script.SetUserModuleLoader(loader compiler.ModuleLoader)
SetUserModuleLoader replaces the default user-module loader of the compiler, which tries to read the source from a local file. SetUserModuleLoader replaces the default user-module loader of the compiler, which tries to read the source from a local file.

View file

@ -1,135 +1,129 @@
package objects package objects
// NamedBuiltinFunc is a named builtin function.
type NamedBuiltinFunc struct {
Name string
Func CallableFunc
}
// Builtins contains all default builtin functions. // Builtins contains all default builtin functions.
var Builtins = []NamedBuiltinFunc{ var Builtins = []BuiltinFunction{
{ {
Name: "print", Name: "print",
Func: builtinPrint, Value: builtinPrint,
}, },
{ {
Name: "printf", Name: "printf",
Func: builtinPrintf, Value: builtinPrintf,
}, },
{ {
Name: "sprintf", Name: "sprintf",
Func: builtinSprintf, Value: builtinSprintf,
}, },
{ {
Name: "len", Name: "len",
Func: builtinLen, Value: builtinLen,
}, },
{ {
Name: "copy", Name: "copy",
Func: builtinCopy, Value: builtinCopy,
}, },
{ {
Name: "append", Name: "append",
Func: builtinAppend, Value: builtinAppend,
}, },
{ {
Name: "string", Name: "string",
Func: builtinString, Value: builtinString,
}, },
{ {
Name: "int", Name: "int",
Func: builtinInt, Value: builtinInt,
}, },
{ {
Name: "bool", Name: "bool",
Func: builtinBool, Value: builtinBool,
}, },
{ {
Name: "float", Name: "float",
Func: builtinFloat, Value: builtinFloat,
}, },
{ {
Name: "char", Name: "char",
Func: builtinChar, Value: builtinChar,
}, },
{ {
Name: "bytes", Name: "bytes",
Func: builtinBytes, Value: builtinBytes,
}, },
{ {
Name: "time", Name: "time",
Func: builtinTime, Value: builtinTime,
}, },
{ {
Name: "is_int", Name: "is_int",
Func: builtinIsInt, Value: builtinIsInt,
}, },
{ {
Name: "is_float", Name: "is_float",
Func: builtinIsFloat, Value: builtinIsFloat,
}, },
{ {
Name: "is_string", Name: "is_string",
Func: builtinIsString, Value: builtinIsString,
}, },
{ {
Name: "is_bool", Name: "is_bool",
Func: builtinIsBool, Value: builtinIsBool,
}, },
{ {
Name: "is_char", Name: "is_char",
Func: builtinIsChar, Value: builtinIsChar,
}, },
{ {
Name: "is_bytes", Name: "is_bytes",
Func: builtinIsBytes, Value: builtinIsBytes,
}, },
{ {
Name: "is_array", Name: "is_array",
Func: builtinIsArray, Value: builtinIsArray,
}, },
{ {
Name: "is_immutable_array", Name: "is_immutable_array",
Func: builtinIsImmutableArray, Value: builtinIsImmutableArray,
}, },
{ {
Name: "is_map", Name: "is_map",
Func: builtinIsMap, Value: builtinIsMap,
}, },
{ {
Name: "is_immutable_map", Name: "is_immutable_map",
Func: builtinIsImmutableMap, Value: builtinIsImmutableMap,
}, },
{ {
Name: "is_time", Name: "is_time",
Func: builtinIsTime, Value: builtinIsTime,
}, },
{ {
Name: "is_error", Name: "is_error",
Func: builtinIsError, Value: builtinIsError,
}, },
{ {
Name: "is_undefined", Name: "is_undefined",
Func: builtinIsUndefined, Value: builtinIsUndefined,
}, },
{ {
Name: "is_function", Name: "is_function",
Func: builtinIsFunction, Value: builtinIsFunction,
}, },
{ {
Name: "is_callable", Name: "is_callable",
Func: builtinIsCallable, Value: builtinIsCallable,
}, },
{ {
Name: "to_json", Name: "to_json",
Func: builtinToJSON, Value: builtinToJSON,
}, },
{ {
Name: "from_json", Name: "from_json",
Func: builtinFromJSON, Value: builtinFromJSON,
}, },
{ {
Name: "type_name", Name: "type_name",
Func: builtinTypeName, Value: builtinTypeName,
}, },
} }

View file

@ -26,7 +26,6 @@ var (
truePtr = &objects.TrueValue truePtr = &objects.TrueValue
falsePtr = &objects.FalseValue falsePtr = &objects.FalseValue
undefinedPtr = &objects.UndefinedValue undefinedPtr = &objects.UndefinedValue
builtinFuncs []objects.Object
) )
// VM is a virtual machine that executes the bytecode compiled by Compiler. // VM is a virtual machine that executes the bytecode compiled by Compiler.
@ -43,11 +42,12 @@ type VM struct {
curIPLimit int curIPLimit int
ip int ip int
aborting int64 aborting int64
builtinFuncs []objects.Object
builtinModules map[string]*objects.Object builtinModules map[string]*objects.Object
} }
// NewVM creates a VM. // NewVM creates a VM.
func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModules map[string]*objects.Object) *VM { func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinFuncs []objects.Object, builtinModules map[string]*objects.Object) *VM {
if globals == nil { if globals == nil {
globals = make([]*objects.Object, GlobalsSize) globals = make([]*objects.Object, GlobalsSize)
} }
@ -56,6 +56,16 @@ func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModule
builtinModules = stdlib.Modules builtinModules = stdlib.Modules
} }
if builtinFuncs == nil {
builtinFuncs = make([]objects.Object, len(objects.Builtins))
for idx, fn := range objects.Builtins {
builtinFuncs[idx] = &objects.BuiltinFunction{
Name: fn.Name,
Value: fn.Value,
}
}
}
frames := make([]Frame, MaxFrames) frames := make([]Frame, MaxFrames)
frames[0].fn = bytecode.MainFunction frames[0].fn = bytecode.MainFunction
frames[0].freeVars = nil frames[0].freeVars = nil
@ -74,6 +84,7 @@ func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModule
curInsts: frames[0].fn.Instructions, curInsts: frames[0].fn.Instructions,
curIPLimit: len(frames[0].fn.Instructions) - 1, curIPLimit: len(frames[0].fn.Instructions) - 1,
ip: -1, ip: -1,
builtinFuncs: builtinFuncs,
builtinModules: builtinModules, builtinModules: builtinModules,
} }
} }
@ -1183,7 +1194,7 @@ mainloop:
break mainloop break mainloop
} }
v.stack[v.sp] = &builtinFuncs[builtinIndex] v.stack[v.sp] = &v.builtinFuncs[builtinIndex]
v.sp++ v.sp++
case compiler.OpGetBuiltinModule: case compiler.OpGetBuiltinModule:
@ -1412,13 +1423,3 @@ func indexAssign(dst, src *objects.Object, selectors []*objects.Object) error {
return nil return nil
} }
func init() {
builtinFuncs = make([]objects.Object, len(objects.Builtins))
for i, b := range objects.Builtins {
builtinFuncs[i] = &objects.BuiltinFunction{
Name: b.Name,
Value: b.Func,
}
}
}

View file

@ -240,7 +240,7 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu
trace = append(trace, fmt.Sprintf("\n[Compiled Constants]\n\n%s", strings.Join(bytecode.FormatConstants(), "\n"))) trace = append(trace, fmt.Sprintf("\n[Compiled Constants]\n\n%s", strings.Join(bytecode.FormatConstants(), "\n")))
trace = append(trace, fmt.Sprintf("\n[Compiled Instructions]\n\n%s\n", strings.Join(bytecode.FormatInstructions(), "\n"))) trace = append(trace, fmt.Sprintf("\n[Compiled Instructions]\n\n%s\n", strings.Join(bytecode.FormatInstructions(), "\n")))
v = runtime.NewVM(bytecode, globals, nil) v = runtime.NewVM(bytecode, globals, nil, nil)
err = v.Run() err = v.Run()
{ {

View file

@ -14,11 +14,11 @@ import (
// Script can simplify compilation and execution of embedded scripts. // Script can simplify compilation and execution of embedded scripts.
type Script struct { type Script struct {
variables map[string]*Variable variables map[string]*Variable
removedBuiltins map[string]bool builtinFuncs []objects.Object
removedStdModules map[string]bool builtinModules map[string]*objects.Object
userModuleLoader compiler.ModuleLoader userModuleLoader compiler.ModuleLoader
input []byte input []byte
} }
// New creates a Script instance with an input script. // New creates a Script instance with an input script.
@ -56,22 +56,28 @@ func (s *Script) Remove(name string) bool {
return true return true
} }
// DisableBuiltinFunction disables a builtin function. // SetBuiltinFunctions allows to define builtin functions.
func (s *Script) DisableBuiltinFunction(name string) { func (s *Script) SetBuiltinFunctions(funcs []*objects.BuiltinFunction) {
if s.removedBuiltins == nil { if funcs != nil {
s.removedBuiltins = make(map[string]bool) s.builtinFuncs = make([]objects.Object, len(funcs))
for idx, fn := range funcs {
s.builtinFuncs[idx] = fn
}
} else {
s.builtinFuncs = []objects.Object{}
} }
s.removedBuiltins[name] = true
} }
// DisableStdModule disables a standard library module. // SetBuiltinModules allows to define builtin modules.
func (s *Script) DisableStdModule(name string) { func (s *Script) SetBuiltinModules(modules map[string]*objects.ImmutableMap) {
if s.removedStdModules == nil { if modules != nil {
s.removedStdModules = make(map[string]bool) s.builtinModules = make(map[string]*objects.Object, len(modules))
for k, mod := range modules {
s.builtinModules[k] = objectPtr(mod)
}
} else {
s.builtinModules = map[string]*objects.Object{}
} }
s.removedStdModules[name] = true
} }
// SetUserModuleLoader sets the user module loader for the compiler. // SetUserModuleLoader sets the user module loader for the compiler.
@ -81,7 +87,7 @@ func (s *Script) SetUserModuleLoader(loader compiler.ModuleLoader) {
// Compile compiles the script with all the defined variables, and, returns Compiled object. // Compile compiles the script with all the defined variables, and, returns Compiled object.
func (s *Script) Compile() (*Compiled, error) { func (s *Script) Compile() (*Compiled, error) {
symbolTable, stdModules, globals, err := s.prepCompile() symbolTable, builtinModules, globals, err := s.prepCompile()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -95,7 +101,7 @@ func (s *Script) Compile() (*Compiled, error) {
return nil, err return nil, err
} }
c := compiler.NewCompiler(srcFile, symbolTable, nil, stdModules, nil) c := compiler.NewCompiler(srcFile, symbolTable, nil, builtinModules, nil)
if s.userModuleLoader != nil { if s.userModuleLoader != nil {
c.SetModuleLoader(s.userModuleLoader) c.SetModuleLoader(s.userModuleLoader)
@ -107,7 +113,7 @@ func (s *Script) Compile() (*Compiled, error) {
return &Compiled{ return &Compiled{
symbolTable: symbolTable, symbolTable: symbolTable,
machine: runtime.NewVM(c.Bytecode(), globals, nil), machine: runtime.NewVM(c.Bytecode(), globals, s.builtinFuncs, s.builtinModules),
}, nil }, nil
} }
@ -136,24 +142,36 @@ func (s *Script) RunContext(ctx context.Context) (compiled *Compiled, err error)
return return
} }
func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, stdModules map[string]bool, globals []*objects.Object, err error) { func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, builtinModules map[string]bool, globals []*objects.Object, err error) {
var names []string var names []string
for name := range s.variables { for name := range s.variables {
names = append(names, name) names = append(names, name)
} }
symbolTable = compiler.NewSymbolTable() symbolTable = compiler.NewSymbolTable()
for idx, fn := range objects.Builtins {
if !s.removedBuiltins[fn.Name] { if s.builtinFuncs == nil {
symbolTable.DefineBuiltin(idx, fn.Name) s.builtinFuncs = make([]objects.Object, len(objects.Builtins))
for idx, fn := range objects.Builtins {
s.builtinFuncs[idx] = &objects.BuiltinFunction{
Name: fn.Name,
Value: fn.Value,
}
} }
} }
stdModules = make(map[string]bool) if s.builtinModules == nil {
for name := range stdlib.Modules { s.builtinModules = stdlib.Modules
if !s.removedStdModules[name] { }
stdModules[name] = true
} for idx, fn := range s.builtinFuncs {
f := fn.(*objects.BuiltinFunction)
symbolTable.DefineBuiltin(idx, f.Name)
}
builtinModules = make(map[string]bool)
for name := range s.builtinModules {
builtinModules[name] = true
} }
globals = make([]*objects.Object, runtime.GlobalsSize, runtime.GlobalsSize) globals = make([]*objects.Object, runtime.GlobalsSize, runtime.GlobalsSize)
@ -178,3 +196,7 @@ func (s *Script) copyVariables() map[string]*Variable {
return vars return vars
} }
func objectPtr(o objects.Object) *objects.Object {
return &o
}

View file

@ -34,7 +34,7 @@ func TestScript_SetUserModuleLoader(t *testing.T) {
c, err = scr.Run() c, err = scr.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(3), c.Get("out").Value()) assert.Equal(t, int64(3), c.Get("out").Value())
scr.DisableBuiltinFunction("len") scr.SetBuiltinFunctions(nil)
_, err = scr.Run() _, err = scr.Run()
assert.Error(t, err) assert.Error(t, err)
@ -49,7 +49,7 @@ func TestScript_SetUserModuleLoader(t *testing.T) {
c, err = scr.Run() c, err = scr.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "Foo", c.Get("out").Value()) assert.Equal(t, "Foo", c.Get("out").Value())
scr.DisableStdModule("text") scr.SetBuiltinModules(nil)
_, err = scr.Run() _, err = scr.Run()
assert.Error(t, err) assert.Error(t, err)

View file

@ -4,7 +4,9 @@ import (
"testing" "testing"
"github.com/d5/tengo/assert" "github.com/d5/tengo/assert"
"github.com/d5/tengo/objects"
"github.com/d5/tengo/script" "github.com/d5/tengo/script"
"github.com/d5/tengo/stdlib"
) )
func TestScript_Add(t *testing.T) { func TestScript_Add(t *testing.T) {
@ -37,24 +39,51 @@ func TestScript_Run(t *testing.T) {
compiledGet(t, c, "a", int64(5)) compiledGet(t, c, "a", int64(5))
} }
func TestScript_DisableBuiltinFunction(t *testing.T) { func TestScript_SetBuiltinFunctions(t *testing.T) {
s := script.New([]byte(`a := len([1, 2, 3])`)) s := script.New([]byte(`a := len([1, 2, 3])`))
c, err := s.Run() c, err := s.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
compiledGet(t, c, "a", int64(3)) compiledGet(t, c, "a", int64(3))
s.DisableBuiltinFunction("len")
s = script.New([]byte(`a := len([1, 2, 3])`))
s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[3]})
c, err = s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
compiledGet(t, c, "a", int64(3))
s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[0]})
_, err = s.Run()
assert.Error(t, err)
s.SetBuiltinFunctions(nil)
_, err = s.Run() _, err = s.Run()
assert.Error(t, err) assert.Error(t, err)
} }
func TestScript_DisableStdModule(t *testing.T) { func TestScript_SetBuiltinModules(t *testing.T) {
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`)) s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
c, err := s.Run() c, err := s.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
compiledGet(t, c, "a", 19.84) compiledGet(t, c, "a", 19.84)
s.DisableStdModule("math")
s.SetBuiltinModules(map[string]*objects.ImmutableMap{"math": objectPtr(*stdlib.Modules["math"])})
c, err = s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
compiledGet(t, c, "a", 19.84)
s.SetBuiltinModules(map[string]*objects.ImmutableMap{"os": objectPtr(*stdlib.Modules["os"])})
_, err = s.Run()
assert.Error(t, err)
s.SetBuiltinModules(nil)
_, err = s.Run() _, err = s.Run()
assert.Error(t, err) assert.Error(t, err)
} }
func objectPtr(o objects.Object) *objects.ImmutableMap {
return o.(*objects.ImmutableMap)
}