SetBuiltinFunctions and SetBuiltinModules (#120)

* `SetBuiltinFunctions` and `SetBuiltinModules`

* nil implies no built in functions.

* Additional tests.

* Cleanup

* Updated SetBuiltinFunctions

* Docs updated.
This commit is contained in:
earncef 2019-02-28 17:26:25 +01:00 committed by Daniel
parent 39112d226e
commit 7cc683e867
9 changed files with 185 additions and 135 deletions

View file

@ -207,7 +207,7 @@ func runVM(bytecode *compiler.Bytecode) (time.Duration, objects.Object, error) {
start := time.Now()
v := runtime.NewVM(bytecode, globals, nil)
v := runtime.NewVM(bytecode, globals, nil, nil)
if err := v.Run(); err != nil {
return time.Since(start), nil, err
}

View file

@ -148,7 +148,7 @@ func compileAndRun(data []byte, inputFile string) (err error) {
return
}
machine := runtime.NewVM(bytecode, nil, nil)
machine := runtime.NewVM(bytecode, nil, nil, nil)
err = machine.Run()
if err != nil {
@ -165,7 +165,7 @@ func runCompiled(data []byte) (err error) {
return
}
machine := runtime.NewVM(bytecode, nil, nil)
machine := runtime.NewVM(bytecode, nil, nil, nil)
err = machine.Run()
if err != nil {
@ -216,7 +216,7 @@ func runREPL(in io.Reader, out io.Writer) {
bytecode := c.Bytecode()
machine := runtime.NewVM(bytecode, globals, nil)
machine := runtime.NewVM(bytecode, globals, nil, nil)
if err := machine.Run(); err != nil {
_, _ = fmt.Fprintln(out, err.Error())
continue

View file

@ -118,34 +118,38 @@ Users can add and use a custom user type in Tengo code by implementing [Object](
To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions.
#### Script.DisableBuiltinFunction(name string)
#### Script.SetBuiltinFunctions(funcs []*objects.BuiltinFunction)
DisableBuiltinFunction disables and removes a builtin function from the compiler. Compiler will reports a compile-time error if the given name is referenced.
SetBuiltinFunctions resets all builtin functions in the compiler to the ones provided in the input parameter. Compiler will report a compile-time error if the a function not set is referenced. All builtin functions are included by default unless `SetBuiltinFunctions` is called.
```golang
s := script.New([]byte(`print([1, 2, 3])`))
s.DisableBuiltinFunction("print")
s.SetBuiltinFunctions(nil)
_, err := s.Run() // compile error
_, err := s.Run() // compile error
s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[0]})
_, err := s.Run() // prints [1, 2, 3]
```
Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled builtin function list from the main script.
#### Script.SetBuiltinModules(modules map[string]*objects.ImmutableMap)
#### Script.DisableStdModule(name string)
DisableStdModule disables a [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) module. Compile will report a compile-time error if the code tries to import the module with the given name.
SetBuiltinModules resets all [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) modules with modules provided in the input parameter. Compile will report a compile-time error if the code tries to import a module that hasn't been included. All standard library modules are included by default unless `SetBuiltinModules` is called.
```golang
s := script.New([]byte(`import("exec")`))
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
s.DisableStdModule("exec")
s.SetBuiltinModules(nil)
_, err := s.Run() // compile error
_, err := s.Run() // compile error
s.SetBuiltinModules(map[string]*objects.ImmutableMap{"math": objectPtr(*stdlib.Modules["math"])})
_, err := s.Run() // a = 19.84
```
Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled standard module list from the main script.
#### Script.SetUserModuleLoader(loader compiler.ModuleLoader)
SetUserModuleLoader replaces the default user-module loader of the compiler, which tries to read the source from a local file.

View file

@ -1,135 +1,129 @@
package objects
// NamedBuiltinFunc is a named builtin function.
type NamedBuiltinFunc struct {
Name string
Func CallableFunc
}
// Builtins contains all default builtin functions.
var Builtins = []NamedBuiltinFunc{
var Builtins = []BuiltinFunction{
{
Name: "print",
Func: builtinPrint,
Name: "print",
Value: builtinPrint,
},
{
Name: "printf",
Func: builtinPrintf,
Name: "printf",
Value: builtinPrintf,
},
{
Name: "sprintf",
Func: builtinSprintf,
Name: "sprintf",
Value: builtinSprintf,
},
{
Name: "len",
Func: builtinLen,
Name: "len",
Value: builtinLen,
},
{
Name: "copy",
Func: builtinCopy,
Name: "copy",
Value: builtinCopy,
},
{
Name: "append",
Func: builtinAppend,
Name: "append",
Value: builtinAppend,
},
{
Name: "string",
Func: builtinString,
Name: "string",
Value: builtinString,
},
{
Name: "int",
Func: builtinInt,
Name: "int",
Value: builtinInt,
},
{
Name: "bool",
Func: builtinBool,
Name: "bool",
Value: builtinBool,
},
{
Name: "float",
Func: builtinFloat,
Name: "float",
Value: builtinFloat,
},
{
Name: "char",
Func: builtinChar,
Name: "char",
Value: builtinChar,
},
{
Name: "bytes",
Func: builtinBytes,
Name: "bytes",
Value: builtinBytes,
},
{
Name: "time",
Func: builtinTime,
Name: "time",
Value: builtinTime,
},
{
Name: "is_int",
Func: builtinIsInt,
Name: "is_int",
Value: builtinIsInt,
},
{
Name: "is_float",
Func: builtinIsFloat,
Name: "is_float",
Value: builtinIsFloat,
},
{
Name: "is_string",
Func: builtinIsString,
Name: "is_string",
Value: builtinIsString,
},
{
Name: "is_bool",
Func: builtinIsBool,
Name: "is_bool",
Value: builtinIsBool,
},
{
Name: "is_char",
Func: builtinIsChar,
Name: "is_char",
Value: builtinIsChar,
},
{
Name: "is_bytes",
Func: builtinIsBytes,
Name: "is_bytes",
Value: builtinIsBytes,
},
{
Name: "is_array",
Func: builtinIsArray,
Name: "is_array",
Value: builtinIsArray,
},
{
Name: "is_immutable_array",
Func: builtinIsImmutableArray,
Name: "is_immutable_array",
Value: builtinIsImmutableArray,
},
{
Name: "is_map",
Func: builtinIsMap,
Name: "is_map",
Value: builtinIsMap,
},
{
Name: "is_immutable_map",
Func: builtinIsImmutableMap,
Name: "is_immutable_map",
Value: builtinIsImmutableMap,
},
{
Name: "is_time",
Func: builtinIsTime,
Name: "is_time",
Value: builtinIsTime,
},
{
Name: "is_error",
Func: builtinIsError,
Name: "is_error",
Value: builtinIsError,
},
{
Name: "is_undefined",
Func: builtinIsUndefined,
Name: "is_undefined",
Value: builtinIsUndefined,
},
{
Name: "is_function",
Func: builtinIsFunction,
Name: "is_function",
Value: builtinIsFunction,
},
{
Name: "is_callable",
Func: builtinIsCallable,
Name: "is_callable",
Value: builtinIsCallable,
},
{
Name: "to_json",
Func: builtinToJSON,
Name: "to_json",
Value: builtinToJSON,
},
{
Name: "from_json",
Func: builtinFromJSON,
Name: "from_json",
Value: builtinFromJSON,
},
{
Name: "type_name",
Func: builtinTypeName,
Name: "type_name",
Value: builtinTypeName,
},
}

View file

@ -26,7 +26,6 @@ var (
truePtr = &objects.TrueValue
falsePtr = &objects.FalseValue
undefinedPtr = &objects.UndefinedValue
builtinFuncs []objects.Object
)
// VM is a virtual machine that executes the bytecode compiled by Compiler.
@ -43,11 +42,12 @@ type VM struct {
curIPLimit int
ip int
aborting int64
builtinFuncs []objects.Object
builtinModules map[string]*objects.Object
}
// NewVM creates a VM.
func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModules map[string]*objects.Object) *VM {
func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinFuncs []objects.Object, builtinModules map[string]*objects.Object) *VM {
if globals == nil {
globals = make([]*objects.Object, GlobalsSize)
}
@ -56,6 +56,16 @@ func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModule
builtinModules = stdlib.Modules
}
if builtinFuncs == nil {
builtinFuncs = make([]objects.Object, len(objects.Builtins))
for idx, fn := range objects.Builtins {
builtinFuncs[idx] = &objects.BuiltinFunction{
Name: fn.Name,
Value: fn.Value,
}
}
}
frames := make([]Frame, MaxFrames)
frames[0].fn = bytecode.MainFunction
frames[0].freeVars = nil
@ -74,6 +84,7 @@ func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object, builtinModule
curInsts: frames[0].fn.Instructions,
curIPLimit: len(frames[0].fn.Instructions) - 1,
ip: -1,
builtinFuncs: builtinFuncs,
builtinModules: builtinModules,
}
}
@ -1183,7 +1194,7 @@ mainloop:
break mainloop
}
v.stack[v.sp] = &builtinFuncs[builtinIndex]
v.stack[v.sp] = &v.builtinFuncs[builtinIndex]
v.sp++
case compiler.OpGetBuiltinModule:
@ -1412,13 +1423,3 @@ func indexAssign(dst, src *objects.Object, selectors []*objects.Object) error {
return nil
}
func init() {
builtinFuncs = make([]objects.Object, len(objects.Builtins))
for i, b := range objects.Builtins {
builtinFuncs[i] = &objects.BuiltinFunction{
Name: b.Name,
Value: b.Func,
}
}
}

View file

@ -240,7 +240,7 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu
trace = append(trace, fmt.Sprintf("\n[Compiled Constants]\n\n%s", strings.Join(bytecode.FormatConstants(), "\n")))
trace = append(trace, fmt.Sprintf("\n[Compiled Instructions]\n\n%s\n", strings.Join(bytecode.FormatInstructions(), "\n")))
v = runtime.NewVM(bytecode, globals, nil)
v = runtime.NewVM(bytecode, globals, nil, nil)
err = v.Run()
{

View file

@ -14,11 +14,11 @@ import (
// Script can simplify compilation and execution of embedded scripts.
type Script struct {
variables map[string]*Variable
removedBuiltins map[string]bool
removedStdModules map[string]bool
userModuleLoader compiler.ModuleLoader
input []byte
variables map[string]*Variable
builtinFuncs []objects.Object
builtinModules map[string]*objects.Object
userModuleLoader compiler.ModuleLoader
input []byte
}
// New creates a Script instance with an input script.
@ -56,22 +56,28 @@ func (s *Script) Remove(name string) bool {
return true
}
// DisableBuiltinFunction disables a builtin function.
func (s *Script) DisableBuiltinFunction(name string) {
if s.removedBuiltins == nil {
s.removedBuiltins = make(map[string]bool)
// SetBuiltinFunctions allows to define builtin functions.
func (s *Script) SetBuiltinFunctions(funcs []*objects.BuiltinFunction) {
if funcs != nil {
s.builtinFuncs = make([]objects.Object, len(funcs))
for idx, fn := range funcs {
s.builtinFuncs[idx] = fn
}
} else {
s.builtinFuncs = []objects.Object{}
}
s.removedBuiltins[name] = true
}
// DisableStdModule disables a standard library module.
func (s *Script) DisableStdModule(name string) {
if s.removedStdModules == nil {
s.removedStdModules = make(map[string]bool)
// SetBuiltinModules allows to define builtin modules.
func (s *Script) SetBuiltinModules(modules map[string]*objects.ImmutableMap) {
if modules != nil {
s.builtinModules = make(map[string]*objects.Object, len(modules))
for k, mod := range modules {
s.builtinModules[k] = objectPtr(mod)
}
} else {
s.builtinModules = map[string]*objects.Object{}
}
s.removedStdModules[name] = true
}
// SetUserModuleLoader sets the user module loader for the compiler.
@ -81,7 +87,7 @@ func (s *Script) SetUserModuleLoader(loader compiler.ModuleLoader) {
// Compile compiles the script with all the defined variables, and, returns Compiled object.
func (s *Script) Compile() (*Compiled, error) {
symbolTable, stdModules, globals, err := s.prepCompile()
symbolTable, builtinModules, globals, err := s.prepCompile()
if err != nil {
return nil, err
}
@ -95,7 +101,7 @@ func (s *Script) Compile() (*Compiled, error) {
return nil, err
}
c := compiler.NewCompiler(srcFile, symbolTable, nil, stdModules, nil)
c := compiler.NewCompiler(srcFile, symbolTable, nil, builtinModules, nil)
if s.userModuleLoader != nil {
c.SetModuleLoader(s.userModuleLoader)
@ -107,7 +113,7 @@ func (s *Script) Compile() (*Compiled, error) {
return &Compiled{
symbolTable: symbolTable,
machine: runtime.NewVM(c.Bytecode(), globals, nil),
machine: runtime.NewVM(c.Bytecode(), globals, s.builtinFuncs, s.builtinModules),
}, nil
}
@ -136,24 +142,36 @@ func (s *Script) RunContext(ctx context.Context) (compiled *Compiled, err error)
return
}
func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, stdModules map[string]bool, globals []*objects.Object, err error) {
func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, builtinModules map[string]bool, globals []*objects.Object, err error) {
var names []string
for name := range s.variables {
names = append(names, name)
}
symbolTable = compiler.NewSymbolTable()
for idx, fn := range objects.Builtins {
if !s.removedBuiltins[fn.Name] {
symbolTable.DefineBuiltin(idx, fn.Name)
if s.builtinFuncs == nil {
s.builtinFuncs = make([]objects.Object, len(objects.Builtins))
for idx, fn := range objects.Builtins {
s.builtinFuncs[idx] = &objects.BuiltinFunction{
Name: fn.Name,
Value: fn.Value,
}
}
}
stdModules = make(map[string]bool)
for name := range stdlib.Modules {
if !s.removedStdModules[name] {
stdModules[name] = true
}
if s.builtinModules == nil {
s.builtinModules = stdlib.Modules
}
for idx, fn := range s.builtinFuncs {
f := fn.(*objects.BuiltinFunction)
symbolTable.DefineBuiltin(idx, f.Name)
}
builtinModules = make(map[string]bool)
for name := range s.builtinModules {
builtinModules[name] = true
}
globals = make([]*objects.Object, runtime.GlobalsSize, runtime.GlobalsSize)
@ -178,3 +196,7 @@ func (s *Script) copyVariables() map[string]*Variable {
return vars
}
func objectPtr(o objects.Object) *objects.Object {
return &o
}

View file

@ -34,7 +34,7 @@ func TestScript_SetUserModuleLoader(t *testing.T) {
c, err = scr.Run()
assert.NoError(t, err)
assert.Equal(t, int64(3), c.Get("out").Value())
scr.DisableBuiltinFunction("len")
scr.SetBuiltinFunctions(nil)
_, err = scr.Run()
assert.Error(t, err)
@ -49,7 +49,7 @@ func TestScript_SetUserModuleLoader(t *testing.T) {
c, err = scr.Run()
assert.NoError(t, err)
assert.Equal(t, "Foo", c.Get("out").Value())
scr.DisableStdModule("text")
scr.SetBuiltinModules(nil)
_, err = scr.Run()
assert.Error(t, err)

View file

@ -4,7 +4,9 @@ import (
"testing"
"github.com/d5/tengo/assert"
"github.com/d5/tengo/objects"
"github.com/d5/tengo/script"
"github.com/d5/tengo/stdlib"
)
func TestScript_Add(t *testing.T) {
@ -37,24 +39,51 @@ func TestScript_Run(t *testing.T) {
compiledGet(t, c, "a", int64(5))
}
func TestScript_DisableBuiltinFunction(t *testing.T) {
func TestScript_SetBuiltinFunctions(t *testing.T) {
s := script.New([]byte(`a := len([1, 2, 3])`))
c, err := s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
compiledGet(t, c, "a", int64(3))
s.DisableBuiltinFunction("len")
s = script.New([]byte(`a := len([1, 2, 3])`))
s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[3]})
c, err = s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
compiledGet(t, c, "a", int64(3))
s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[0]})
_, err = s.Run()
assert.Error(t, err)
s.SetBuiltinFunctions(nil)
_, err = s.Run()
assert.Error(t, err)
}
func TestScript_DisableStdModule(t *testing.T) {
func TestScript_SetBuiltinModules(t *testing.T) {
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
c, err := s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
compiledGet(t, c, "a", 19.84)
s.DisableStdModule("math")
s.SetBuiltinModules(map[string]*objects.ImmutableMap{"math": objectPtr(*stdlib.Modules["math"])})
c, err = s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
compiledGet(t, c, "a", 19.84)
s.SetBuiltinModules(map[string]*objects.ImmutableMap{"os": objectPtr(*stdlib.Modules["os"])})
_, err = s.Run()
assert.Error(t, err)
s.SetBuiltinModules(nil)
_, err = s.Run()
assert.Error(t, err)
}
func objectPtr(o objects.Object) *objects.ImmutableMap {
return o.(*objects.ImmutableMap)
}