fix bytecode encoding/decoding of builtin modules (#154)

* fix bytecode encoding/decoding of builtin modules

* Bytecode.Decode() to take map[string]objects.Importable

* add objects.ModuleMap

* update docs

* stdlib.GetModuleMap()
This commit is contained in:
Daniel 2019-03-20 01:28:40 -07:00 committed by GitHub
parent e785e38bf8
commit 3c30109cd0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 308 additions and 140 deletions

View file

@ -41,7 +41,7 @@ type Options struct {
Version string Version string
// Import modules // Import modules
Modules map[string]objects.Importable Modules *objects.ModuleMap
} }
// Run CLI // Run CLI
@ -56,7 +56,7 @@ func Run(options *Options) {
if options.InputFile == "" { if options.InputFile == "" {
// REPL // REPL
runREPL(options.Modules, os.Stdin, os.Stdout) RunREPL(options.Modules, os.Stdin, os.Stdout)
return return
} }
@ -67,17 +67,17 @@ func Run(options *Options) {
} }
if options.CompileOutput != "" { if options.CompileOutput != "" {
if err := compileOnly(options.Modules, inputData, options.InputFile, options.CompileOutput); err != nil { if err := CompileOnly(options.Modules, inputData, options.InputFile, options.CompileOutput); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error()) _, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
} else if filepath.Ext(options.InputFile) == sourceFileExt { } else if filepath.Ext(options.InputFile) == sourceFileExt {
if err := compileAndRun(options.Modules, inputData, options.InputFile); err != nil { if err := CompileAndRun(options.Modules, inputData, options.InputFile); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error()) _, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
} else { } else {
if err := runCompiled(inputData); err != nil { if err := RunCompiled(options.Modules, inputData); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error()) _, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1) os.Exit(1)
} }
@ -116,7 +116,8 @@ func doHelp() {
fmt.Println() fmt.Println()
} }
func compileOnly(modules map[string]objects.Importable, data []byte, inputFile, outputFile string) (err error) { // CompileOnly compiles the source code and writes the compiled binary into outputFile.
func CompileOnly(modules *objects.ModuleMap, data []byte, inputFile, outputFile string) (err error) {
bytecode, err := compileSrc(modules, data, filepath.Base(inputFile)) bytecode, err := compileSrc(modules, data, filepath.Base(inputFile))
if err != nil { if err != nil {
return return
@ -148,7 +149,8 @@ func compileOnly(modules map[string]objects.Importable, data []byte, inputFile,
return return
} }
func compileAndRun(modules map[string]objects.Importable, data []byte, inputFile string) (err error) { // CompileAndRun compiles the source code and executes it.
func CompileAndRun(modules *objects.ModuleMap, data []byte, inputFile string) (err error) {
bytecode, err := compileSrc(modules, data, filepath.Base(inputFile)) bytecode, err := compileSrc(modules, data, filepath.Base(inputFile))
if err != nil { if err != nil {
return return
@ -164,9 +166,10 @@ func compileAndRun(modules map[string]objects.Importable, data []byte, inputFile
return return
} }
func runCompiled(data []byte) (err error) { // RunCompiled reads the compiled binary from file and executes it.
func RunCompiled(modules *objects.ModuleMap, data []byte) (err error) {
bytecode := &compiler.Bytecode{} bytecode := &compiler.Bytecode{}
err = bytecode.Decode(bytes.NewReader(data)) err = bytecode.Decode(bytes.NewReader(data), modules)
if err != nil { if err != nil {
return return
} }
@ -181,7 +184,8 @@ func runCompiled(data []byte) (err error) {
return return
} }
func runREPL(modules map[string]objects.Importable, in io.Reader, out io.Writer) { // RunREPL starts REPL.
func RunREPL(modules *objects.ModuleMap, in io.Reader, out io.Writer) {
stdin := bufio.NewScanner(in) stdin := bufio.NewScanner(in)
fileSet := source.NewFileSet() fileSet := source.NewFileSet()
@ -254,7 +258,7 @@ func runREPL(modules map[string]objects.Importable, in io.Reader, out io.Writer)
} }
} }
func compileSrc(modules map[string]objects.Importable, src []byte, filename string) (*compiler.Bytecode, error) { func compileSrc(modules *objects.ModuleMap, src []byte, filename string) (*compiler.Bytecode, error) {
fileSet := source.NewFileSet() fileSet := source.NewFileSet()
srcFile := fileSet.AddFile(filename, -1, len(src)) srcFile := fileSet.AddFile(filename, -1, len(src))

64
cli/cli_test.go Normal file
View file

@ -0,0 +1,64 @@
package cli_test
import (
"io/ioutil"
"os"
"path/filepath"
"regexp"
"testing"
"github.com/d5/tengo/assert"
"github.com/d5/tengo/cli"
"github.com/d5/tengo/stdlib"
)
func TestCLICompileAndRun(t *testing.T) {
tempDir := filepath.Join(os.TempDir(), "tengo_tests")
_ = os.MkdirAll(tempDir, os.ModePerm)
binFile := filepath.Join(tempDir, "cli_bin")
outFile := filepath.Join(tempDir, "cli_out")
defer func() {
_ = os.RemoveAll(tempDir)
}()
src := []byte(`
os := import("os")
rand := import("rand")
times := import("times")
rand.seed(times.time_nanosecond(times.now()))
rand_num := func() {
return rand.intn(100)
}
file := os.create("` + outFile + `")
file.write_string("random number is " + rand_num())
file.close()
`)
mods := stdlib.GetModuleMap(stdlib.AllModuleNames()...)
err := cli.CompileOnly(mods, src, "src", binFile)
if !assert.NoError(t, err) {
return
}
compiledBin, err := ioutil.ReadFile(binFile)
if !assert.NoError(t, err) {
return
}
err = cli.RunCompiled(mods, compiledBin)
if !assert.NoError(t, err) {
return
}
read, err := ioutil.ReadFile(outFile)
if !assert.NoError(t, err) {
return
}
ok, err := regexp.Match(`^random number is \d+$`, read)
assert.NoError(t, err)
assert.True(t, ok, string(read))
}

View file

@ -27,7 +27,7 @@ func main() {
ShowVersion: showVersion, ShowVersion: showVersion,
Version: version, Version: version,
CompileOutput: compileOutput, CompileOutput: compileOutput,
Modules: stdlib.GetModules(stdlib.AllModuleNames()...), Modules: stdlib.GetModuleMap(stdlib.AllModuleNames()...),
InputFile: flag.Arg(0), InputFile: flag.Arg(0),
}) })
} }

View file

@ -2,13 +2,18 @@ package compiler
import ( import (
"encoding/gob" "encoding/gob"
"fmt"
"io" "io"
"github.com/d5/tengo/objects" "github.com/d5/tengo/objects"
) )
// Decode reads Bytecode data from the reader. // Decode reads Bytecode data from the reader.
func (b *Bytecode) Decode(r io.Reader) error { func (b *Bytecode) Decode(r io.Reader, modules *objects.ModuleMap) error {
if modules == nil {
modules = objects.NewModuleMap()
}
dec := gob.NewDecoder(r) dec := gob.NewDecoder(r)
if err := dec.Decode(&b.FileSet); err != nil { if err := dec.Decode(&b.FileSet); err != nil {
@ -25,38 +30,68 @@ func (b *Bytecode) Decode(r io.Reader) error {
return err return err
} }
for i, v := range b.Constants { for i, v := range b.Constants {
b.Constants[i] = fixDecoded(v) fv, err := fixDecoded(v, modules)
if err != nil {
return err
}
b.Constants[i] = fv
} }
return nil return nil
} }
func fixDecoded(o objects.Object) objects.Object { func fixDecoded(o objects.Object, modules *objects.ModuleMap) (objects.Object, error) {
switch o := o.(type) { switch o := o.(type) {
case *objects.Bool: case *objects.Bool:
if o.IsFalsy() { if o.IsFalsy() {
return objects.FalseValue return objects.FalseValue, nil
} }
return objects.TrueValue return objects.TrueValue, nil
case *objects.Undefined: case *objects.Undefined:
return objects.UndefinedValue return objects.UndefinedValue, nil
case *objects.Array: case *objects.Array:
for i, v := range o.Value { for i, v := range o.Value {
o.Value[i] = fixDecoded(v) fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[i] = fv
} }
case *objects.ImmutableArray: case *objects.ImmutableArray:
for i, v := range o.Value { for i, v := range o.Value {
o.Value[i] = fixDecoded(v) fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[i] = fv
} }
case *objects.Map: case *objects.Map:
for k, v := range o.Value { for k, v := range o.Value {
o.Value[k] = fixDecoded(v) fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[k] = fv
} }
case *objects.ImmutableMap: case *objects.ImmutableMap:
modName := moduleName(o)
if mod := modules.GetBuiltinModule(modName); mod != nil {
return mod.AsImmutableMap(modName), nil
}
for k, v := range o.Value { for k, v := range o.Value {
o.Value[k] = fixDecoded(v) // encoding of user function not supported
if _, isUserFunction := v.(*objects.UserFunction); isUserFunction {
return nil, fmt.Errorf("user function not decodable")
}
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[k] = fv
} }
} }
return o return o, nil
} }

View file

@ -280,7 +280,7 @@ func testBytecodeSerialization(t *testing.T, b *compiler.Bytecode) {
assert.NoError(t, err) assert.NoError(t, err)
r := &compiler.Bytecode{} r := &compiler.Bytecode{}
err = r.Decode(bytes.NewReader(buf.Bytes())) err = r.Decode(bytes.NewReader(buf.Bytes()), nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, b.FileSet, r.FileSet) assert.Equal(t, b.FileSet, r.FileSet)

View file

@ -24,7 +24,7 @@ type Compiler struct {
symbolTable *SymbolTable symbolTable *SymbolTable
scopes []CompilationScope scopes []CompilationScope
scopeIndex int scopeIndex int
importModules map[string]objects.Importable modules *objects.ModuleMap
compiledModules map[string]*objects.CompiledFunction compiledModules map[string]*objects.CompiledFunction
allowFileImport bool allowFileImport bool
loops []*Loop loops []*Loop
@ -34,7 +34,7 @@ type Compiler struct {
} }
// NewCompiler creates a Compiler. // NewCompiler creates a Compiler.
func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, importModules map[string]objects.Importable, trace io.Writer) *Compiler { func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, modules *objects.ModuleMap, trace io.Writer) *Compiler {
mainScope := CompilationScope{ mainScope := CompilationScope{
symbolInit: make(map[string]bool), symbolInit: make(map[string]bool),
sourceMap: make(map[int]source.Pos), sourceMap: make(map[int]source.Pos),
@ -51,8 +51,8 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object
} }
// builtin modules // builtin modules
if importModules == nil { if modules == nil {
importModules = make(map[string]objects.Importable) modules = objects.NewModuleMap()
} }
return &Compiler{ return &Compiler{
@ -63,7 +63,7 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object
scopeIndex: 0, scopeIndex: 0,
loopIndex: -1, loopIndex: -1,
trace: trace, trace: trace,
importModules: importModules, modules: modules,
compiledModules: make(map[string]*objects.CompiledFunction), compiledModules: make(map[string]*objects.CompiledFunction),
} }
} }
@ -513,7 +513,7 @@ func (c *Compiler) Compile(node ast.Node) error {
return c.errorf(node, "empty module name") return c.errorf(node, "empty module name")
} }
if mod, ok := c.importModules[node.ModuleName]; ok { if mod := c.modules.Get(node.ModuleName); mod != nil {
v, err := mod.Import(node.ModuleName) v, err := mod.Import(node.ModuleName)
if err != nil { if err != nil {
return err return err
@ -644,7 +644,7 @@ func (c *Compiler) EnableFileImport(enable bool) {
} }
func (c *Compiler) fork(file *source.File, modulePath string, symbolTable *SymbolTable) *Compiler { func (c *Compiler) fork(file *source.File, modulePath string, symbolTable *SymbolTable) *Compiler {
child := NewCompiler(file, symbolTable, nil, c.importModules, c.trace) child := NewCompiler(file, symbolTable, nil, c.modules, c.trace)
child.modulePath = modulePath // module file path child.modulePath = modulePath // module file path
child.parent = c // parent to set to current compiler child.parent = c // parent to set to current compiler

View file

@ -119,17 +119,13 @@ Users can add and use a custom user type in Tengo code by implementing [Object](
To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions. To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions.
#### Script.SetImports(modules map[string]objects.Importable) #### Script.SetImports(modules *objects.ModuleMap)
SetImports sets the import modules with corresponding names. Script **does not** include any modules by default. You can use this function to include the [Standard Library](https://github.com/d5/tengo/blob/master/docs/stdlib.md). SetImports sets the import modules with corresponding names. Script **does not** include any modules by default. You can use this function to include the [Standard Library](https://github.com/d5/tengo/blob/master/docs/stdlib.md).
```golang ```golang
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`)) s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
s.SetImports(map[string]objects.Importable{
"math": stdlib.BuiltinModules["math"],
})
// or
s.SetImports(stdlib.GetModules("math")) s.SetImports(stdlib.GetModules("math"))
// or, to include all stdlib at once // or, to include all stdlib at once
s.SetImports(stdlib.GetModules(stdlib.AllModuleNames()...)) s.SetImports(stdlib.GetModules(stdlib.AllModuleNames()...))
@ -140,9 +136,9 @@ You can also include Tengo's written module using `objects.SourceModule` (which
```golang ```golang
s := script.New([]byte(`double := import("double"); a := double(20)`)) s := script.New([]byte(`double := import("double"); a := double(20)`))
s.SetImports(map[string]objects.Importable{ mods := objects.NewModuleMap()
"double": &objects.SourceModule{Src: []byte(`export func(x) { return x * 2 }`)}, mods.AddSourceModule("double", []byte(`export func(x) { return x * 2 }`))
}) s.SetImports(mods)
``` ```

View file

@ -6,11 +6,18 @@ type BuiltinModule struct {
} }
// Import returns an immutable map for the module. // Import returns an immutable map for the module.
func (m *BuiltinModule) Import(name string) (interface{}, error) { func (m *BuiltinModule) Import(moduleName string) (interface{}, error) {
return m.AsImmutableMap(moduleName), nil
}
// AsImmutableMap converts builtin module into an immutable map.
func (m *BuiltinModule) AsImmutableMap(moduleName string) *ImmutableMap {
attrs := make(map[string]Object, len(m.Attrs)) attrs := make(map[string]Object, len(m.Attrs))
for k, v := range m.Attrs { for k, v := range m.Attrs {
attrs[k] = v.Copy() attrs[k] = v.Copy()
} }
attrs["__module_name__"] = &String{Value: name}
return &ImmutableMap{Value: attrs}, nil attrs["__module_name__"] = &String{Value: moduleName}
return &ImmutableMap{Value: attrs}
} }

View file

@ -3,5 +3,5 @@ package objects
// Importable interface represents importable module instance. // Importable interface represents importable module instance.
type Importable interface { type Importable interface {
// Import should return either an Object or module source code ([]byte). // Import should return either an Object or module source code ([]byte).
Import(name string) (interface{}, error) Import(moduleName string) (interface{}, error)
} }

65
objects/module_map.go Normal file
View file

@ -0,0 +1,65 @@
package objects
// ModuleMap represents a set of named modules.
// Use NewModuleMap to create a new module map.
type ModuleMap struct {
m map[string]Importable
}
// NewModuleMap creates a new module map.
func NewModuleMap() *ModuleMap {
return &ModuleMap{
m: make(map[string]Importable),
}
}
// Add adds an import module.
func (m *ModuleMap) Add(name string, module Importable) {
m.m[name] = module
}
// AddBuiltinModule adds a builtin module.
func (m *ModuleMap) AddBuiltinModule(name string, attrs map[string]Object) {
m.m[name] = &BuiltinModule{Attrs: attrs}
}
// AddSourceModule adds a source module.
func (m *ModuleMap) AddSourceModule(name string, src []byte) {
m.m[name] = &SourceModule{Src: src}
}
// Get returns an import module identified by name.
// It returns if the name is not found.
func (m *ModuleMap) Get(name string) Importable {
return m.m[name]
}
// GetBuiltinModule returns a builtin module identified by name.
// It returns if the name is not found or the module is not a builtin module.
func (m *ModuleMap) GetBuiltinModule(name string) *BuiltinModule {
mod, _ := m.m[name].(*BuiltinModule)
return mod
}
// GetSourceModule returns a source module identified by name.
// It returns if the name is not found or the module is not a source module.
func (m *ModuleMap) GetSourceModule(name string) *SourceModule {
mod, _ := m.m[name].(*SourceModule)
return mod
}
// Copy creates a copy of the module map.
func (m *ModuleMap) Copy() *ModuleMap {
c := &ModuleMap{
m: make(map[string]Importable),
}
for name, mod := range m.m {
c.m[name] = mod
}
return c
}
// Len returns the number of named modules.
func (m *ModuleMap) Len() int {
return len(m.m)
}

View file

@ -8,6 +8,7 @@ import (
type UserFunction struct { type UserFunction struct {
Name string Name string
Value CallableFunc Value CallableFunc
EncodingID string
} }
// TypeName returns the name of the type. // TypeName returns the name of the type.

View file

@ -24,7 +24,7 @@ type MAP = map[string]interface{}
type ARR = []interface{} type ARR = []interface{}
type testopts struct { type testopts struct {
modules map[string]objects.Importable modules *objects.ModuleMap
symbols map[string]objects.Object symbols map[string]objects.Object
maxAllocs int64 maxAllocs int64
skip2ndPass bool skip2ndPass bool
@ -32,7 +32,7 @@ type testopts struct {
func Opts() *testopts { func Opts() *testopts {
return &testopts{ return &testopts{
modules: make(map[string]objects.Importable), modules: objects.NewModuleMap(),
symbols: make(map[string]objects.Object), symbols: make(map[string]objects.Object),
maxAllocs: -1, maxAllocs: -1,
skip2ndPass: false, skip2ndPass: false,
@ -41,14 +41,11 @@ func Opts() *testopts {
func (o *testopts) copy() *testopts { func (o *testopts) copy() *testopts {
c := &testopts{ c := &testopts{
modules: make(map[string]objects.Importable), modules: o.modules.Copy(),
symbols: make(map[string]objects.Object), symbols: make(map[string]objects.Object),
maxAllocs: o.maxAllocs, maxAllocs: o.maxAllocs,
skip2ndPass: o.skip2ndPass, skip2ndPass: o.skip2ndPass,
} }
for k, v := range o.modules {
c.modules[k] = v
}
for k, v := range o.symbols { for k, v := range o.symbols {
c.symbols[k] = v c.symbols[k] = v
} }
@ -59,11 +56,11 @@ func (o *testopts) Module(name string, mod interface{}) *testopts {
c := o.copy() c := o.copy()
switch mod := mod.(type) { switch mod := mod.(type) {
case objects.Importable: case objects.Importable:
c.modules[name] = mod c.modules.Add(name, mod)
case string: case string:
c.modules[name] = &objects.SourceModule{Src: []byte(mod)} c.modules.AddSourceModule(name, []byte(mod))
case []byte: case []byte:
c.modules[name] = &objects.SourceModule{Src: mod} c.modules.AddSourceModule(name, mod)
default: default:
panic(fmt.Errorf("invalid module type: %T", mod)) panic(fmt.Errorf("invalid module type: %T", mod))
} }
@ -135,9 +132,7 @@ func expect(t *testing.T, input string, opts *testopts, expected interface{}) {
expectedObj = &objects.ImmutableMap{Value: eo.Value} expectedObj = &objects.ImmutableMap{Value: eo.Value}
} }
modules["__code__"] = &objects.SourceModule{ modules.AddSourceModule("__code__", []byte(fmt.Sprintf("out := undefined; %s; export out", input)))
Src: []byte(fmt.Sprintf("out := undefined; %s; export out", input)),
}
res, trace, err := traceCompileRun(file, symbols, modules, maxAllocs) res, trace, err := traceCompileRun(file, symbols, modules, maxAllocs)
if !assert.NoError(t, err) || if !assert.NoError(t, err) ||
@ -184,7 +179,7 @@ func (o *tracer) Write(p []byte) (n int, err error) {
return len(p), nil return len(p), nil
} }
func traceCompileRun(file *ast.File, symbols map[string]objects.Object, modules map[string]objects.Importable, maxAllocs int64) (res map[string]objects.Object, trace []string, err error) { func traceCompileRun(file *ast.File, symbols map[string]objects.Object, modules *objects.ModuleMap, maxAllocs int64) (res map[string]objects.Object, trace []string, err error) {
var v *runtime.VM var v *runtime.VM
defer func() { defer func() {

View file

@ -14,7 +14,7 @@ import (
// Script can simplify compilation and execution of embedded scripts. // Script can simplify compilation and execution of embedded scripts.
type Script struct { type Script struct {
variables map[string]*Variable variables map[string]*Variable
importModules map[string]objects.Importable modules *objects.ModuleMap
input []byte input []byte
maxAllocs int64 maxAllocs int64
maxConstObjects int maxConstObjects int
@ -59,8 +59,8 @@ func (s *Script) Remove(name string) bool {
} }
// SetImports sets import modules. // SetImports sets import modules.
func (s *Script) SetImports(modules map[string]objects.Importable) { func (s *Script) SetImports(modules *objects.ModuleMap) {
s.importModules = modules s.modules = modules
} }
// SetMaxAllocs sets the maximum number of objects allocations during the run time. // SetMaxAllocs sets the maximum number of objects allocations during the run time.
@ -96,7 +96,7 @@ func (s *Script) Compile() (*Compiled, error) {
return nil, err return nil, err
} }
c := compiler.NewCompiler(srcFile, symbolTable, nil, s.importModules, nil) c := compiler.NewCompiler(srcFile, symbolTable, nil, s.modules, nil)
c.EnableFileImport(s.enableFileImport) c.EnableFileImport(s.enableFileImport)
if err := c.Compile(file); err != nil { if err := c.Compile(file); err != nil {
return nil, err return nil, err

View file

@ -45,8 +45,7 @@ for i:=1; i<=d; i++ {
e := mod1.double(s) e := mod1.double(s)
`) `)
mod1 := &objects.BuiltinModule{ mod1 := map[string]objects.Object{
Attrs: map[string]objects.Object{
"double": &objects.UserFunction{ "double": &objects.UserFunction{
Value: func(args ...objects.Object) (ret objects.Object, err error) { Value: func(args ...objects.Object) (ret objects.Object, err error) {
arg0, _ := objects.ToInt64(args[0]) arg0, _ := objects.ToInt64(args[0])
@ -54,16 +53,15 @@ e := mod1.double(s)
return return
}, },
}, },
},
} }
scr := script.New(code) scr := script.New(code)
_ = scr.Add("a", 0) _ = scr.Add("a", 0)
_ = scr.Add("b", 0) _ = scr.Add("b", 0)
_ = scr.Add("c", 0) _ = scr.Add("c", 0)
scr.SetImports(map[string]objects.Importable{ mods := objects.NewModuleMap()
"mod1": mod1, mods.AddBuiltinModule("mod1", mod1)
}) scr.SetImports(mods)
compiled, err := scr.Compile() compiled, err := scr.Compile()
assert.NoError(t, err) assert.NoError(t, err)

View file

@ -12,38 +12,44 @@ import (
func TestScriptSourceModule(t *testing.T) { func TestScriptSourceModule(t *testing.T) {
// script1 imports "mod1" // script1 imports "mod1"
scr := script.New([]byte(`out := import("mod")`)) scr := script.New([]byte(`out := import("mod")`))
scr.SetImports(map[string]objects.Importable{ mods := objects.NewModuleMap()
"mod": &objects.SourceModule{Src: []byte(`export 5`)}, mods.AddSourceModule("mod", []byte(`export 5`))
}) scr.SetImports(mods)
c, err := scr.Run() c, err := scr.Run()
if !assert.NoError(t, err) {
return
}
assert.Equal(t, int64(5), c.Get("out").Value()) assert.Equal(t, int64(5), c.Get("out").Value())
// executing module function // executing module function
scr = script.New([]byte(`fn := import("mod"); out := fn()`)) scr = script.New([]byte(`fn := import("mod"); out := fn()`))
scr.SetImports(map[string]objects.Importable{ mods = objects.NewModuleMap()
"mod": &objects.SourceModule{Src: []byte(`a := 3; export func() { return a + 5 }`)}, mods.AddSourceModule("mod", []byte(`a := 3; export func() { return a + 5 }`))
}) scr.SetImports(mods)
c, err = scr.Run() c, err = scr.Run()
assert.NoError(t, err) if !assert.NoError(t, err) {
return
}
assert.Equal(t, int64(8), c.Get("out").Value()) assert.Equal(t, int64(8), c.Get("out").Value())
scr = script.New([]byte(`out := import("mod")`)) scr = script.New([]byte(`out := import("mod")`))
scr.SetImports(map[string]objects.Importable{ mods = objects.NewModuleMap()
"text": &objects.BuiltinModule{ mods.AddSourceModule("mod", []byte(`text := import("text"); export text.title("foo")`))
Attrs: map[string]objects.Object{ mods.AddBuiltinModule("text", map[string]objects.Object{
"title": &objects.UserFunction{Name: "title", Value: func(args ...objects.Object) (ret objects.Object, err error) { "title": &objects.UserFunction{Name: "title", Value: func(args ...objects.Object) (ret objects.Object, err error) {
s, _ := objects.ToString(args[0]) s, _ := objects.ToString(args[0])
return &objects.String{Value: strings.Title(s)}, nil return &objects.String{Value: strings.Title(s)}, nil
}}, }},
},
},
"mod": &objects.SourceModule{Src: []byte(`text := import("text"); export text.title("foo")`)},
}) })
scr.SetImports(mods)
c, err = scr.Run() c, err = scr.Run()
assert.NoError(t, err) if !assert.NoError(t, err) {
return
}
assert.Equal(t, "Foo", c.Get("out").Value()) assert.Equal(t, "Foo", c.Get("out").Value())
scr.SetImports(nil) scr.SetImports(nil)
_, err = scr.Run() _, err = scr.Run()
assert.Error(t, err) if !assert.Error(t, err) {
return
}
} }

View file

@ -53,7 +53,7 @@ func TestScript_Run(t *testing.T) {
func TestScript_BuiltinModules(t *testing.T) { func TestScript_BuiltinModules(t *testing.T) {
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`)) s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
s.SetImports(map[string]objects.Importable{"math": stdlib.BuiltinModules["math"]}) s.SetImports(stdlib.GetModuleMap("math"))
c, err := s.Run() c, err := s.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
@ -64,7 +64,7 @@ func TestScript_BuiltinModules(t *testing.T) {
assert.NotNil(t, c) assert.NotNil(t, c)
compiledGet(t, c, "a", 19.84) compiledGet(t, c, "a", 19.84)
s.SetImports(map[string]objects.Importable{"os": &objects.BuiltinModule{Attrs: map[string]objects.Object{}}}) s.SetImports(stdlib.GetModuleMap("os"))
_, err = s.Run() _, err = s.Run()
assert.Error(t, err) assert.Error(t, err)
@ -80,7 +80,7 @@ a := enum.all([1,2,3], func(_, v) {
return v > 0 return v > 0
}) })
`)) `))
s.SetImports(stdlib.GetModules("enum")) s.SetImports(stdlib.GetModuleMap("enum"))
c, err := s.Run() c, err := s.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)

View file

@ -3,12 +3,12 @@ package stdlib
import "github.com/d5/tengo/objects" import "github.com/d5/tengo/objects"
// BuiltinModules are builtin type standard library modules. // BuiltinModules are builtin type standard library modules.
var BuiltinModules = map[string]*objects.BuiltinModule{ var BuiltinModules = map[string]map[string]objects.Object{
"math": {Attrs: mathModule}, "math": mathModule,
"os": {Attrs: osModule}, "os": osModule,
"text": {Attrs: textModule}, "text": textModule,
"times": {Attrs: timesModule}, "times": timesModule,
"rand": {Attrs: randModule}, "rand": randModule,
"fmt": {Attrs: fmtModule}, "fmt": fmtModule,
"json": {Attrs: jsonModule}, "json": jsonModule,
} }

View file

@ -38,12 +38,10 @@ func main() {
package stdlib package stdlib
import "github.com/d5/tengo/objects"
// SourceModules are source type standard library modules. // SourceModules are source type standard library modules.
var SourceModules = map[string]*objects.SourceModule{` + "\n") var SourceModules = map[string]string{` + "\n")
for modName, modSrc := range modules { for modName, modSrc := range modules {
out.WriteString("\t\"" + modName + "\": {Src: []byte(`" + modSrc + "`)},\n") out.WriteString("\t\"" + modName + "\": `" + modSrc + "`,\n")
} }
out.WriteString("}\n") out.WriteString("}\n")

View file

@ -2,11 +2,9 @@
package stdlib package stdlib
import "github.com/d5/tengo/objects"
// SourceModules are source type standard library modules. // SourceModules are source type standard library modules.
var SourceModules = map[string]*objects.SourceModule{ var SourceModules = map[string]string{
"enum": {Src: []byte(`export { "enum": `export {
// all returns true if the given function fn evaluates to a truthy value on // all returns true if the given function fn evaluates to a truthy value on
// all of the items in the enumerable. // all of the items in the enumerable.
all: func(enumerable, fn) { all: func(enumerable, fn) {
@ -46,5 +44,5 @@ var SourceModules = map[string]*objects.SourceModule{
return res return res
} }
} }
`)}, `,
} }

View file

@ -16,16 +16,17 @@ func AllModuleNames() []string {
return names return names
} }
// GetModules returns the modules for the given names. // GetModuleMap returns the module map that includes all modules
// Duplicate names and invalid names are ignore. // for the given module names.
func GetModules(names ...string) map[string]objects.Importable { func GetModuleMap(names ...string) *objects.ModuleMap {
modules := make(map[string]objects.Importable) modules := objects.NewModuleMap()
for _, name := range names { for _, name := range names {
if mod := BuiltinModules[name]; mod != nil { if mod := BuiltinModules[name]; mod != nil {
modules[name] = mod modules.AddBuiltinModule(name, mod)
} }
if mod := SourceModules[name]; mod != nil { if mod := SourceModules[name]; mod != "" {
modules[name] = mod modules.AddSourceModule(name, []byte(mod))
} }
} }

View file

@ -74,25 +74,25 @@ if !is_error(cmd) {
} }
func TestGetModules(t *testing.T) { func TestGetModules(t *testing.T) {
mods := stdlib.GetModules() mods := stdlib.GetModuleMap()
assert.Equal(t, 0, len(mods)) assert.Equal(t, 0, mods.Len())
mods = stdlib.GetModules("os") mods = stdlib.GetModuleMap("os")
assert.Equal(t, 1, len(mods)) assert.Equal(t, 1, mods.Len())
assert.NotNil(t, mods["os"]) assert.NotNil(t, mods.Get("os"))
mods = stdlib.GetModules("os", "rand") mods = stdlib.GetModuleMap("os", "rand")
assert.Equal(t, 2, len(mods)) assert.Equal(t, 2, mods.Len())
assert.NotNil(t, mods["os"]) assert.NotNil(t, mods.Get("os"))
assert.NotNil(t, mods["rand"]) assert.NotNil(t, mods.Get("rand"))
mods = stdlib.GetModules("text", "text") mods = stdlib.GetModuleMap("text", "text")
assert.Equal(t, 1, len(mods)) assert.Equal(t, 1, mods.Len())
assert.NotNil(t, mods["text"]) assert.NotNil(t, mods.Get("text"))
mods = stdlib.GetModules("nonexisting", "text") mods = stdlib.GetModuleMap("nonexisting", "text")
assert.Equal(t, 1, len(mods)) assert.Equal(t, 1, mods.Len())
assert.NotNil(t, mods["text"]) assert.NotNil(t, mods.Get("text"))
} }
type callres struct { type callres struct {
@ -156,8 +156,8 @@ func (c callres) expectError() bool {
} }
func module(t *testing.T, moduleName string) callres { func module(t *testing.T, moduleName string) callres {
mod, ok := stdlib.BuiltinModules[moduleName] mod := stdlib.GetModuleMap(moduleName).GetBuiltinModule(moduleName)
if !ok { if mod == nil {
return callres{t: t, e: fmt.Errorf("module not found: %s", moduleName)} return callres{t: t, e: fmt.Errorf("module not found: %s", moduleName)}
} }
@ -231,7 +231,7 @@ func object(v interface{}) objects.Object {
func expect(t *testing.T, input string, expected interface{}) { func expect(t *testing.T, input string, expected interface{}) {
s := script.New([]byte(input)) s := script.New([]byte(input))
s.SetImports(stdlib.GetModules(stdlib.AllModuleNames()...)) s.SetImports(stdlib.GetModuleMap(stdlib.AllModuleNames()...))
c, err := s.Run() c, err := s.Run()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)