add Script.DisableBuiltinFunction, Script.DisableStdModule, Script.SetUserModuleLoader functions
This commit is contained in:
parent
c816b705c1
commit
88ba20da4e
2 changed files with 89 additions and 6 deletions
script
|
@ -7,14 +7,18 @@ import (
|
|||
"github.com/d5/tengo/compiler"
|
||||
"github.com/d5/tengo/compiler/parser"
|
||||
"github.com/d5/tengo/compiler/source"
|
||||
"github.com/d5/tengo/compiler/stdlib"
|
||||
"github.com/d5/tengo/objects"
|
||||
"github.com/d5/tengo/runtime"
|
||||
)
|
||||
|
||||
// Script can simplify compilation and execution of embedded scripts.
|
||||
type Script struct {
|
||||
variables map[string]*Variable
|
||||
input []byte
|
||||
variables map[string]*Variable
|
||||
removedBuiltins map[string]bool
|
||||
removedStdModules map[string]bool
|
||||
userModuleLoader compiler.ModuleLoader
|
||||
input []byte
|
||||
}
|
||||
|
||||
// New creates a Script instance with an input script.
|
||||
|
@ -52,9 +56,32 @@ func (s *Script) Remove(name string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// DisableBuiltinFunction disables a builtin function.
|
||||
func (s *Script) DisableBuiltinFunction(name string) {
|
||||
if s.removedBuiltins == nil {
|
||||
s.removedBuiltins = make(map[string]bool)
|
||||
}
|
||||
|
||||
s.removedBuiltins[name] = true
|
||||
}
|
||||
|
||||
// DisableStdModule disables a standard library module.
|
||||
func (s *Script) DisableStdModule(name string) {
|
||||
if s.removedStdModules == nil {
|
||||
s.removedStdModules = make(map[string]bool)
|
||||
}
|
||||
|
||||
s.removedStdModules[name] = true
|
||||
}
|
||||
|
||||
// SetUserModuleLoader sets the user module loader for the compiler.
|
||||
func (s *Script) SetUserModuleLoader(loader compiler.ModuleLoader) {
|
||||
s.userModuleLoader = loader
|
||||
}
|
||||
|
||||
// Compile compiles the script with all the defined variables, and, returns Compiled object.
|
||||
func (s *Script) Compile() (*Compiled, error) {
|
||||
symbolTable, globals := s.prepCompile()
|
||||
symbolTable, stdModules, globals := s.prepCompile()
|
||||
|
||||
fileSet := source.NewFileSet()
|
||||
|
||||
|
@ -64,7 +91,12 @@ func (s *Script) Compile() (*Compiled, error) {
|
|||
return nil, fmt.Errorf("parse error: %s", err.Error())
|
||||
}
|
||||
|
||||
c := compiler.NewCompiler(symbolTable, nil, nil)
|
||||
c := compiler.NewCompiler(symbolTable, stdModules, nil)
|
||||
|
||||
if s.userModuleLoader != nil {
|
||||
c.SetModuleLoader(s.userModuleLoader)
|
||||
}
|
||||
|
||||
if err := c.Compile(file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -100,7 +132,7 @@ func (s *Script) RunContext(ctx context.Context) (compiled *Compiled, err error)
|
|||
return
|
||||
}
|
||||
|
||||
func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, globals []*objects.Object) {
|
||||
func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, stdModules map[string]*objects.ImmutableMap, globals []*objects.Object) {
|
||||
var names []string
|
||||
for name := range s.variables {
|
||||
names = append(names, name)
|
||||
|
@ -108,7 +140,16 @@ func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, globals []*ob
|
|||
|
||||
symbolTable = compiler.NewSymbolTable()
|
||||
for idx, fn := range objects.Builtins {
|
||||
symbolTable.DefineBuiltin(idx, fn.Name)
|
||||
if !s.removedBuiltins[fn.Name] {
|
||||
symbolTable.DefineBuiltin(idx, fn.Name)
|
||||
}
|
||||
}
|
||||
|
||||
stdModules = make(map[string]*objects.ImmutableMap)
|
||||
for name, mod := range stdlib.Modules {
|
||||
if !s.removedStdModules[name] {
|
||||
stdModules[name] = mod
|
||||
}
|
||||
}
|
||||
|
||||
globals = make([]*objects.Object, len(names), len(names))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package script_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/d5/tengo/assert"
|
||||
|
@ -36,3 +37,44 @@ func TestScript_Run(t *testing.T) {
|
|||
assert.NotNil(t, c)
|
||||
compiledGet(t, c, "a", int64(5))
|
||||
}
|
||||
|
||||
func TestScript_DisableBuiltinFunction(t *testing.T) {
|
||||
s := script.New([]byte(`a := len([1, 2, 3])`))
|
||||
c, err := s.Run()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, c)
|
||||
compiledGet(t, c, "a", int64(3))
|
||||
s.DisableBuiltinFunction("len")
|
||||
_, err = s.Run()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestScript_DisableStdModule(t *testing.T) {
|
||||
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
|
||||
c, err := s.Run()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, c)
|
||||
compiledGet(t, c, "a", 19.84)
|
||||
s.DisableStdModule("math")
|
||||
_, err = s.Run()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestScript_SetUserModuleLoader(t *testing.T) {
|
||||
s := script.New([]byte(`math := import("mod1"); a := math.foo()`))
|
||||
_, err := s.Run()
|
||||
assert.Error(t, err)
|
||||
s.SetUserModuleLoader(func(moduleName string) (res []byte, err error) {
|
||||
if moduleName == "mod1" {
|
||||
res = []byte(`foo := func() { return 5 }`)
|
||||
return
|
||||
}
|
||||
|
||||
err = errors.New("module not found")
|
||||
return
|
||||
})
|
||||
c, err := s.Run()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, c)
|
||||
compiledGet(t, c, "a", int64(5))
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue