diff --git a/assert/assert.go b/assert/assert.go index 5b16289..ae4ff4e 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -121,8 +121,8 @@ func Equal(t *testing.T, expected, actual interface{}, msg ...interface{}) bool if expected != actual.(rune) { return failExpectedActual(t, expected, actual, msg...) } - case compiler.Symbol: - if !equalSymbol(expected, actual.(compiler.Symbol)) { + case *compiler.Symbol: + if !equalSymbol(expected, actual.(*compiler.Symbol)) { return failExpectedActual(t, expected, actual, msg...) } case source.Pos: @@ -238,7 +238,7 @@ func equalIntSlice(a, b []int) bool { return true } -func equalSymbol(a, b compiler.Symbol) bool { +func equalSymbol(a, b *compiler.Symbol) bool { return a.Name == b.Name && a.Index == b.Index && a.Scope == b.Scope diff --git a/compiler/ast/export_stmt.go b/compiler/ast/export_stmt.go new file mode 100644 index 0000000..64eb760 --- /dev/null +++ b/compiler/ast/export_stmt.go @@ -0,0 +1,27 @@ +package ast + +import ( + "github.com/d5/tengo/compiler/source" +) + +// ExportStmt represents an export statement. +type ExportStmt struct { + ExportPos source.Pos + Result Expr +} + +func (s *ExportStmt) stmtNode() {} + +// Pos returns the position of first character belonging to the node. +func (s *ExportStmt) Pos() source.Pos { + return s.ExportPos +} + +// End returns the position of first character immediately after the node. +func (s *ExportStmt) End() source.Pos { + return s.Result.End() +} + +func (s *ExportStmt) String() string { + return "export " + s.Result.String() +} diff --git a/compiler/compilation_scope.go b/compiler/compilation_scope.go index 03f86de..b7ee7b2 100644 --- a/compiler/compilation_scope.go +++ b/compiler/compilation_scope.go @@ -5,4 +5,5 @@ package compiler type CompilationScope struct { instructions []byte lastInstructions [2]EmittedInstruction + symbolInit map[string]bool } diff --git a/compiler/compiler.go b/compiler/compiler.go index 9fd2a49..17a6d3a 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -21,7 +21,7 @@ type Compiler struct { scopeIndex int moduleLoader ModuleLoader stdModules map[string]*objects.ImmutableMap - compiledModules map[string]*objects.CompiledModule + compiledModules map[string]*objects.CompiledFunction loops []*Loop loopIndex int trace io.Writer @@ -60,7 +60,7 @@ func NewCompiler(symbolTable *SymbolTable, stdModules map[string]*objects.Immuta loopIndex: -1, trace: trace, stdModules: stdModules, - compiledModules: make(map[string]*objects.CompiledModule), + compiledModules: make(map[string]*objects.CompiledFunction), } } @@ -383,7 +383,10 @@ func (c *Compiler) Compile(node ast.Node) error { c.enterScope() for _, p := range node.Type.Params.List { - c.symbolTable.Define(p.Name) + s := c.symbolTable.Define(p.Name) + + // function arguments is not assigned directly. + s.LocalAssigned = true } if err := c.Compile(node.Body); err != nil { @@ -402,6 +405,50 @@ func (c *Compiler) Compile(node ast.Node) error { for _, s := range freeSymbols { switch s.Scope { case ScopeLocal: + if !s.LocalAssigned { + // Here, the closure is capturing a local variable that's not yet assigned its value. + // One example is a local recursive function: + // + // func() { + // foo := func(x) { + // // .. + // return foo(x-1) + // } + // } + // + // which translate into + // + // 0000 GETL 0 + // 0002 CLOSURE ? 1 + // 0006 DEFL 0 + // + // . So the local variable (0) is being captured before it's assigned the value. + // + // Solution is to transform the code into something like this: + // + // func() { + // foo := undefined + // foo = func(x) { + // // .. + // return foo(x-1) + // } + // } + // + // that is equivalent to + // + // 0000 NULL + // 0001 DEFL 0 + // 0003 GETL 0 + // 0005 CLOSURE ? 1 + // 0009 SETL 0 + // + + c.emit(OpNull) + c.emit(OpDefineLocal, s.Index) + + s.LocalAssigned = true + } + c.emit(OpGetLocal, s.Index) case ScopeFree: c.emit(OpGetFree, s.Index) @@ -461,9 +508,28 @@ func (c *Compiler) Compile(node ast.Node) error { return err } - c.emit(OpModule, c.addConstant(userMod)) + c.emit(OpConstant, c.addConstant(userMod)) + c.emit(OpCall, 0) } + case *ast.ExportStmt: + // export statement must be in top-level scope + if c.scopeIndex != 0 { + return fmt.Errorf("cannot use 'export' inside function") + } + + // export statement is simply ignore when compiling non-module code + if c.parent == nil { + break + } + + if err := c.Compile(node.Result); err != nil { + return err + } + + c.emit(OpImmutable) + c.emit(OpReturnValue) + case *ast.ErrorExpr: if err := c.Compile(node.Expr); err != nil { return err diff --git a/compiler/compiler_assign.go b/compiler/compiler_assign.go index 3efc376..50e6f42 100644 --- a/compiler/compiler_assign.go +++ b/compiler/compiler_assign.go @@ -105,12 +105,15 @@ func (c *Compiler) compileAssign(lhs, rhs []ast.Expr, op token.Token) error { if numSel > 0 { c.emit(OpSetSelLocal, symbol.Index, numSel) } else { - if op == token.Define { + if op == token.Define && !symbol.LocalAssigned { c.emit(OpDefineLocal, symbol.Index) } else { c.emit(OpSetLocal, symbol.Index) } } + + // mark the symbol as local-assigned + symbol.LocalAssigned = true case ScopeFree: if numSel > 0 { c.emit(OpSetSelFree, symbol.Index, numSel) diff --git a/compiler/compiler_module.go b/compiler/compiler_module.go index eda394e..5930c32 100644 --- a/compiler/compiler_module.go +++ b/compiler/compiler_module.go @@ -14,7 +14,7 @@ var ( fileSet = source.NewFileSet() ) -func (c *Compiler) compileModule(moduleName string) (*objects.CompiledModule, error) { +func (c *Compiler) compileModule(moduleName string) (*objects.CompiledFunction, error) { compiledModule, exists := c.loadCompiledModule(moduleName) if exists { return compiledModule, nil @@ -69,7 +69,7 @@ func (c *Compiler) checkCyclicImports(moduleName string) error { return nil } -func (c *Compiler) doCompileModule(moduleName string, src []byte) (*objects.CompiledModule, error) { +func (c *Compiler) doCompileModule(moduleName string, src []byte) (*objects.CompiledFunction, error) { p := parser.NewParser(fileSet.AddFile(moduleName, -1, len(src)), src, nil) file, err := p.ParseFile() if err != nil { @@ -77,27 +77,36 @@ func (c *Compiler) doCompileModule(moduleName string, src []byte) (*objects.Comp } symbolTable := NewSymbolTable() - globals := make(map[string]int) + // inherit builtin functions + for idx, fn := range objects.Builtins { + s, _, ok := c.symbolTable.Resolve(fn.Name) + if ok && s.Scope == ScopeBuiltin { + symbolTable.DefineBuiltin(idx, fn.Name) + } + } + + // no global scope for the module + symbolTable = symbolTable.Fork(false) + + // compile module moduleCompiler := c.fork(moduleName, symbolTable) if err := moduleCompiler.Compile(file); err != nil { return nil, err } - for _, name := range symbolTable.Names() { - symbol, _, _ := symbolTable.Resolve(name) - if symbol.Scope == ScopeGlobal { - globals[name] = symbol.Index - } + // add OpReturn (== export undefined) if export is missing + if !moduleCompiler.lastInstructionIs(OpReturnValue) { + moduleCompiler.emit(OpReturn) } - return &objects.CompiledModule{ + return &objects.CompiledFunction{ Instructions: moduleCompiler.Bytecode().Instructions, - Globals: globals, + NumLocals: symbolTable.MaxSymbols(), }, nil } -func (c *Compiler) loadCompiledModule(moduleName string) (mod *objects.CompiledModule, ok bool) { +func (c *Compiler) loadCompiledModule(moduleName string) (mod *objects.CompiledFunction, ok bool) { if c.parent != nil { return c.parent.loadCompiledModule(moduleName) } @@ -107,7 +116,7 @@ func (c *Compiler) loadCompiledModule(moduleName string) (mod *objects.CompiledM return } -func (c *Compiler) storeCompiledModule(moduleName string, module *objects.CompiledModule) { +func (c *Compiler) storeCompiledModule(moduleName string, module *objects.CompiledFunction) { if c.parent != nil { c.parent.storeCompiledModule(moduleName, module) } diff --git a/compiler/compiler_scopes.go b/compiler/compiler_scopes.go index f0a9061..fdc68b0 100644 --- a/compiler/compiler_scopes.go +++ b/compiler/compiler_scopes.go @@ -7,6 +7,7 @@ func (c *Compiler) currentInstructions() []byte { func (c *Compiler) enterScope() { scope := CompilationScope{ instructions: make([]byte, 0), + symbolInit: make(map[string]bool), } c.scopes = append(c.scopes, scope) diff --git a/compiler/opcodes.go b/compiler/opcodes.go index d64a030..e8f06b8 100644 --- a/compiler/opcodes.go +++ b/compiler/opcodes.go @@ -41,6 +41,7 @@ const ( OpCall // Call function OpReturn // Return OpReturnValue // Return value + OpExport // Export OpGetGlobal // Get global variable OpSetGlobal // Set global variable OpSetSelGlobal // Set global variable using selectors @@ -57,7 +58,6 @@ const ( OpIteratorNext // Iterator next OpIteratorKey // Iterator key OpIteratorValue // Iterator value - OpModule // Module ) // OpcodeNames is opcode names. @@ -101,6 +101,7 @@ var OpcodeNames = [...]string{ OpCall: "CALL", OpReturn: "RET", OpReturnValue: "RETVAL", + OpExport: "EXPORT", OpGetLocal: "GETL", OpSetLocal: "SETL", OpDefineLocal: "DEFL", @@ -114,7 +115,6 @@ var OpcodeNames = [...]string{ OpIteratorNext: "ITNXT", OpIteratorKey: "ITKEY", OpIteratorValue: "ITVAL", - OpModule: "MODULE", } // OpcodeOperands is the number of operands. @@ -158,6 +158,7 @@ var OpcodeOperands = [...][]int{ OpCall: {1}, OpReturn: {}, OpReturnValue: {}, + OpExport: {}, OpGetLocal: {1}, OpSetLocal: {1}, OpDefineLocal: {1}, @@ -171,7 +172,6 @@ var OpcodeOperands = [...][]int{ OpIteratorNext: {}, OpIteratorKey: {}, OpIteratorValue: {}, - OpModule: {2}, } // ReadOperands reads operands from the bytecode. diff --git a/compiler/parser/parser.go b/compiler/parser/parser.go index 7df8bc3..5af2666 100644 --- a/compiler/parser/parser.go +++ b/compiler/parser/parser.go @@ -628,6 +628,8 @@ func (p *Parser) parseStmt() (stmt ast.Stmt) { return s case token.Return: return p.parseReturnStmt() + case token.Export: + return p.parseExportStmt() case token.If: return p.parseIfStmt() case token.For: @@ -874,6 +876,23 @@ func (p *Parser) parseReturnStmt() ast.Stmt { } } +func (p *Parser) parseExportStmt() ast.Stmt { + if p.trace { + defer un(trace(p, "ExportStmt")) + } + + pos := p.pos + p.expect(token.Export) + + x := p.parseExpr() + p.expectSemi() + + return &ast.ExportStmt{ + ExportPos: pos, + Result: x, + } +} + func (p *Parser) parseSimpleStmt(forIn bool) ast.Stmt { if p.trace { defer un(trace(p, "SimpleStmt")) diff --git a/compiler/parser/parser_function_test.go b/compiler/parser/parser_function_test.go index dd6f74e..1ec1ebb 100644 --- a/compiler/parser/parser_function_test.go +++ b/compiler/parser/parser_function_test.go @@ -8,23 +8,6 @@ import ( ) func TestFunction(t *testing.T) { - // TODO: function declaration currently not parsed. - // All functions are parsed as function literal instead. - // In Go, function declaration is parsed only at the top level. - //expect(t, "func a(b, c, d) {}", func(p pfn) []ast.Stmt { - // return stmts( - // declStmt( - // funcDecl( - // ident("a", p(1, 6)), - // funcType( - // identList(p(1, 7), p(1, 15), - // ident("b", p(1, 8)), - // ident("c", p(1, 11)), - // ident("d", p(1, 14))), - // p(1, 12)), - // blockStmt(p(1, 17), p(1, 18))))) - //}) - expect(t, "a = func(b, c, d) { return d }", func(p pfn) []ast.Stmt { return stmts( assignStmt( diff --git a/compiler/parser/sync.go b/compiler/parser/sync.go index 83e5e62..e68d623 100644 --- a/compiler/parser/sync.go +++ b/compiler/parser/sync.go @@ -8,4 +8,5 @@ var stmtStart = map[token.Token]bool{ token.For: true, token.If: true, token.Return: true, + token.Export: true, } diff --git a/compiler/scanner/scanner.go b/compiler/scanner/scanner.go index ae9156b..11e71b9 100644 --- a/compiler/scanner/scanner.go +++ b/compiler/scanner/scanner.go @@ -77,7 +77,7 @@ func (s *Scanner) Scan() (tok token.Token, literal string, pos source.Pos) { literal = s.scanIdentifier() tok = token.Lookup(literal) switch tok { - case token.Ident, token.Break, token.Continue, token.Return, token.True, token.False, token.Undefined: + case token.Ident, token.Break, token.Continue, token.Return, token.Export, token.True, token.False, token.Undefined: insertSemi = true } case '0' <= ch && ch <= '9': diff --git a/compiler/scanner/scanner_test.go b/compiler/scanner/scanner_test.go index 5a67369..c975386 100644 --- a/compiler/scanner/scanner_test.go +++ b/compiler/scanner/scanner_test.go @@ -116,6 +116,7 @@ func TestScanner_Scan(t *testing.T) { {token.Func, "func"}, {token.If, "if"}, {token.Return, "return"}, + {token.Export, "export"}, } // combine diff --git a/compiler/symbol.go b/compiler/symbol.go index 21b9508..bcd5323 100644 --- a/compiler/symbol.go +++ b/compiler/symbol.go @@ -2,7 +2,8 @@ package compiler // Symbol represents a symbol in the symbol table. type Symbol struct { - Name string - Scope SymbolScope - Index int + Name string + Scope SymbolScope + Index int + LocalAssigned bool // if the local symbol is assigned at least once } diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go index eb83f8c..5983cc9 100644 --- a/compiler/symbol_table.go +++ b/compiler/symbol_table.go @@ -4,22 +4,22 @@ package compiler type SymbolTable struct { parent *SymbolTable block bool - store map[string]Symbol + store map[string]*Symbol numDefinition int maxDefinition int - freeSymbols []Symbol + freeSymbols []*Symbol } // NewSymbolTable creates a SymbolTable. func NewSymbolTable() *SymbolTable { return &SymbolTable{ - store: make(map[string]Symbol), + store: make(map[string]*Symbol), } } // Define adds a new symbol in the current scope. -func (t *SymbolTable) Define(name string) Symbol { - symbol := Symbol{Name: name, Index: t.nextIndex()} +func (t *SymbolTable) Define(name string) *Symbol { + symbol := &Symbol{Name: name, Index: t.nextIndex()} t.numDefinition++ if t.Parent(true) == nil { @@ -36,8 +36,8 @@ func (t *SymbolTable) Define(name string) Symbol { } // DefineBuiltin adds a symbol for builtin function. -func (t *SymbolTable) DefineBuiltin(index int, name string) Symbol { - symbol := Symbol{ +func (t *SymbolTable) DefineBuiltin(index int, name string) *Symbol { + symbol := &Symbol{ Name: name, Index: index, Scope: ScopeBuiltin, @@ -49,7 +49,7 @@ func (t *SymbolTable) DefineBuiltin(index int, name string) Symbol { } // Resolve resolves a symbol with a given name. -func (t *SymbolTable) Resolve(name string) (symbol Symbol, depth int, ok bool) { +func (t *SymbolTable) Resolve(name string) (symbol *Symbol, depth int, ok bool) { symbol, ok = t.store[name] if !ok && t.parent != nil { symbol, depth, ok = t.parent.Resolve(name) @@ -76,7 +76,7 @@ func (t *SymbolTable) Resolve(name string) (symbol Symbol, depth int, ok bool) { // Fork creates a new symbol table for a new scope. func (t *SymbolTable) Fork(block bool) *SymbolTable { return &SymbolTable{ - store: make(map[string]Symbol), + store: make(map[string]*Symbol), parent: t, block: block, } @@ -97,7 +97,7 @@ func (t *SymbolTable) MaxSymbols() int { } // FreeSymbols returns free symbols for the scope. -func (t *SymbolTable) FreeSymbols() []Symbol { +func (t *SymbolTable) FreeSymbols() []*Symbol { return t.freeSymbols } @@ -128,12 +128,12 @@ func (t *SymbolTable) updateMaxDefs(numDefs int) { } } -func (t *SymbolTable) defineFree(original Symbol) Symbol { +func (t *SymbolTable) defineFree(original *Symbol) *Symbol { // TODO: should we check duplicates? t.freeSymbols = append(t.freeSymbols, original) - symbol := Symbol{ + symbol := &Symbol{ Name: original.Name, Index: len(t.freeSymbols) - 1, Scope: ScopeFree, diff --git a/compiler/symbol_table_test.go b/compiler/symbol_table_test.go index bf1bfb2..888c9f0 100644 --- a/compiler/symbol_table_test.go +++ b/compiler/symbol_table_test.go @@ -91,23 +91,23 @@ func TestSymbolTable(t *testing.T) { resolveExpect(t, local2Block2, "b", globalSymbol("b", 1), 2) } -func symbol(name string, scope compiler.SymbolScope, index int) compiler.Symbol { - return compiler.Symbol{ +func symbol(name string, scope compiler.SymbolScope, index int) *compiler.Symbol { + return &compiler.Symbol{ Name: name, Scope: scope, Index: index, } } -func globalSymbol(name string, index int) compiler.Symbol { +func globalSymbol(name string, index int) *compiler.Symbol { return symbol(name, compiler.ScopeGlobal, index) } -func localSymbol(name string, index int) compiler.Symbol { +func localSymbol(name string, index int) *compiler.Symbol { return symbol(name, compiler.ScopeLocal, index) } -func freeSymbol(name string, index int) compiler.Symbol { +func freeSymbol(name string, index int) *compiler.Symbol { return symbol(name, compiler.ScopeFree, index) } @@ -115,7 +115,7 @@ func symbolTable() *compiler.SymbolTable { return compiler.NewSymbolTable() } -func resolveExpect(t *testing.T, symbolTable *compiler.SymbolTable, name string, expectedSymbol compiler.Symbol, expectedDepth int) { +func resolveExpect(t *testing.T, symbolTable *compiler.SymbolTable, name string, expectedSymbol *compiler.Symbol, expectedDepth int) { actualSymbol, actualDepth, ok := symbolTable.Resolve(name) assert.True(t, ok) assert.Equal(t, expectedSymbol, actualSymbol) diff --git a/compiler/token/tokens.go b/compiler/token/tokens.go index 210bf82..b32d36e 100644 --- a/compiler/token/tokens.go +++ b/compiler/token/tokens.go @@ -76,6 +76,7 @@ const ( Immutable If Return + Export True False In @@ -149,6 +150,7 @@ var tokens = [...]string{ Immutable: "immutable", If: "if", Return: "return", + Export: "export", True: "true", False: "false", In: "in", diff --git a/docs/interoperability.md b/docs/interoperability.md index 3cf7007..547f93f 100644 --- a/docs/interoperability.md +++ b/docs/interoperability.md @@ -5,7 +5,6 @@ - [Using Scripts](#using-scripts) - [Type Conversion Table](#type-conversion-table) - [User Types](#user-types) - - [Importing Scripts](#importing-scripts) - [Sandbox Environments](#sandbox-environments) - [Compiler and VM](#compiler-and-vm) @@ -115,20 +114,6 @@ When adding a Variable _([Script.Add](https://godoc.org/github.com/d5/tengo/scri Users can add and use a custom user type in Tengo code by implementing [Object](https://godoc.org/github.com/d5/tengo/objects#Object) interface. Tengo runtime will treat the user types in the same way it does to the runtime types with no performance overhead. See [Object Types](https://github.com/d5/tengo/blob/master/docs/objects.md) for more details. -### Importing Scripts - -A script can import and use another script in the same way it can load the standard library or the user module. `Script.AddModule` function adds another script as a named module. - -```golang -mod1Script := script.New([]byte(`a := 5`)) // mod1 script - -mainScript := script.New([]byte(`print(import("mod1").a)`)) // main script -mainScript.AddModule("mod1", mod1Script) // add mod1 using name "mod1" -mainScript.Run() // prints "5" -``` - -Note that the script modules added using `Script.AddModule` will be compiled and run right before the main script is compiled. - ## Sandbox Environments To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions. @@ -145,6 +130,8 @@ s.DisableBuiltinFunction("print") _, err := s.Run() // compile error ``` +Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled builtin function list from the main script. + #### Script.DisableStdModule(name string) DisableStdModule disables a [standard library](https://github.com/d5/tengo/blob/master/docs/stdlib.md) module. Compile will report a compile-time error if the code tries to import the module with the given name. @@ -157,6 +144,8 @@ s.DisableStdModule("exec") _, err := s.Run() // compile error ``` +Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the disabled standard module list from the main script. + #### Script.SetUserModuleLoader(loader compiler.ModuleLoader) SetUserModuleLoader replaces the default user-module loader of the compiler, which tries to read the source from a local file. @@ -173,6 +162,8 @@ s.SetUserModuleLoader(func(moduleName string) ([]byte, error) { }) ``` +Note that when a script is being added to another script as a module (via `Script.AddModule`), it does not inherit the module loader from the main script. + ## Compiler and VM Although it's not recommended, you can directly create and run the Tengo [Parser](https://godoc.org/github.com/d5/tengo/compiler/parser#Parser), [Compiler](https://godoc.org/github.com/d5/tengo/compiler#Compiler), and [VM](https://godoc.org/github.com/d5/tengo/runtime#VM) for yourself instead of using Scripts and Script Variables. It's a bit more involved as you have to manage the symbol tables and global variables between them, but, basically that's what Script and Script Variable is doing internally. diff --git a/docs/tutorial.md b/docs/tutorial.md index 56088f8..5f446c8 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -217,22 +217,32 @@ if is_error(err1) { // 'is_error' builtin function You can load other scripts as modules using `import` expression. Main script: -```golang -mod1 := import("./mod1") // assuming mod1.tengo file exists in the current directory - // same as 'import("./mod1.tengo")' or 'import("mod1")' -mod1.func1(a) // module function -a += mod1.foo // module variable -//mod1.foo = 5 // error: module variables are read-only -``` - -`mod1.tengo` file: ```golang -func1 := func(x) { print(x) } -foo := 2 +sum := import("./sum") // assuming sum.tengo file exists in the current directory + // same as 'import("./sum.tengo")' or 'import("sum")' +print(sum(10)) // module function ``` -Basically, `import` expression returns all the global variables defined in the module as an ImmutableMap value. +`sum.tengo` file: + +```golang +base := 5 + +export func(x) { + return x + base +} +``` + +In Tengo, modules are very similar to functions. + +- `import` expression loads the module and execute like a function. +- Module should return a value using `export` statement. + - Module can return `export` any Tengo objects: int, string, map, array, function, etc. + - `export` in a module is like `return` in a function: it stops execution and return a value to the importing code. + - `export`-ed values are always immutable. + - If the module does not have any `export` statement, `import` expression simply returns `undefined`. _(Just like the function that has no `return`.)_ + - Note that `export` statement is completely ignored and not evaluated if the code is executed as a regular script. Also, you can use `import` to load the [Standard Library](https://github.com/d5/tengo/blob/master/docs/stdlib.md). diff --git a/objects/builtin_len.go b/objects/builtin_len.go index fd6529a..f749363 100644 --- a/objects/builtin_len.go +++ b/objects/builtin_len.go @@ -12,6 +12,8 @@ func builtinLen(args ...Object) (Object, error) { switch arg := args[0].(type) { case *Array: return &Int{Value: int64(len(arg.Value))}, nil + case *ImmutableArray: + return &Int{Value: int64(len(arg.Value))}, nil case *String: return &Int{Value: int64(len(arg.Value))}, nil case *Bytes: diff --git a/objects/compiled_module.go b/objects/compiled_module.go index d531500..76e0ea6 100644 --- a/objects/compiled_module.go +++ b/objects/compiled_module.go @@ -6,8 +6,8 @@ import ( // CompiledModule represents a compiled module. type CompiledModule struct { - Instructions []byte // compiled instructions - Globals map[string]int // global variable name-to-index map + Instructions []byte // compiled instructions + NumGlobals int } // TypeName returns the name of the type. @@ -27,14 +27,9 @@ func (o *CompiledModule) BinaryOp(op token.Token, rhs Object) (Object, error) { // Copy returns a copy of the type. func (o *CompiledModule) Copy() Object { - globals := make(map[string]int, len(o.Globals)) - for name, index := range o.Globals { - globals[name] = index - } - return &CompiledModule{ Instructions: append([]byte{}, o.Instructions...), - Globals: globals, + NumGlobals: o.NumGlobals, } } diff --git a/runtime/vm.go b/runtime/vm.go index ba0057a..cc1990b 100644 --- a/runtime/vm.go +++ b/runtime/vm.go @@ -47,10 +47,6 @@ type VM struct { func NewVM(bytecode *compiler.Bytecode, globals []*objects.Object) *VM { if globals == nil { globals = make([]*objects.Object, GlobalsSize) - } else if len(globals) < GlobalsSize { - g := make([]*objects.Object, GlobalsSize) - copy(g, globals) - globals = g } frames := make([]Frame, MaxFrames) @@ -648,7 +644,7 @@ func (v *VM) Run() error { case *objects.Array: numElements := int64(len(left.Value)) - if lowIdx < 0 || lowIdx >= numElements { + if lowIdx < 0 || lowIdx > numElements { return fmt.Errorf("index out of bounds: %d", lowIdx) } if highIdx < 0 { @@ -673,7 +669,7 @@ func (v *VM) Run() error { case *objects.ImmutableArray: numElements := int64(len(left.Value)) - if lowIdx < 0 || lowIdx >= numElements { + if lowIdx < 0 || lowIdx > numElements { return fmt.Errorf("index out of bounds: %d", lowIdx) } if highIdx < 0 { @@ -698,7 +694,7 @@ func (v *VM) Run() error { case *objects.String: numElements := int64(len(left.Value)) - if lowIdx < 0 || lowIdx >= numElements { + if lowIdx < 0 || lowIdx > numElements { return fmt.Errorf("index out of bounds: %d", lowIdx) } if highIdx < 0 { @@ -822,14 +818,15 @@ func (v *VM) Run() error { v.curIPLimit = len(v.curInsts) - 1 v.ip = v.curFrame.ip - v.sp = lastFrame.basePointer - 1 + //v.sp = lastFrame.basePointer - 1 + v.sp = lastFrame.basePointer - if v.sp >= StackSize { + if v.sp-1 >= StackSize { return ErrStackOverflow } - v.stack[v.sp] = undefinedPtr - v.sp++ + v.stack[v.sp-1] = undefinedPtr + //v.sp++ case compiler.OpDefineLocal: localIndex := int(v.curInsts[v.ip+1]) @@ -1005,14 +1002,6 @@ func (v *VM) Run() error { v.stack[v.sp] = &val v.sp++ - case compiler.OpModule: - cidx := int(v.curInsts[v.ip+2]) | int(v.curInsts[v.ip+1])<<8 - v.ip += 2 - - if err := v.importModule(v.constants[cidx].(*objects.CompiledModule)); err != nil { - return err - } - default: return fmt.Errorf("unknown opcode: %d", v.curInsts[v.ip]) } @@ -1020,7 +1009,7 @@ func (v *VM) Run() error { // check if stack still has some objects left if v.sp > 0 && atomic.LoadInt64(&v.aborting) == 0 { - return fmt.Errorf("non empty stack after execution") + return fmt.Errorf("non empty stack after execution: %d", v.sp) } return nil @@ -1033,7 +1022,7 @@ func (v *VM) Globals() []*objects.Object { // FrameInfo returns the current function call frame information. func (v *VM) FrameInfo() (frameIndex, ip int) { - return v.framesIndex - 1, v.frames[v.framesIndex-1].ip + return v.framesIndex - 1, v.ip } func (v *VM) pushClosure(constIndex, numFree int) error { @@ -1160,35 +1149,6 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje return nil } -// TODO: should reuse *objects.ImmutableMap for the same imports? -func (v *VM) importModule(compiledModule *objects.CompiledModule) error { - // import module is basically to create a new instance of VM - // and run the module code and retrieve all global variables after execution. - moduleVM := NewVM(&compiler.Bytecode{ - Instructions: compiledModule.Instructions, - Constants: v.constants, - }, nil) - if err := moduleVM.Run(); err != nil { - return err - } - - mmValue := make(map[string]objects.Object) - for name, index := range compiledModule.Globals { - mmValue[name] = *moduleVM.globals[index] - } - - var mm objects.Object = &objects.ImmutableMap{Value: mmValue} - - if v.sp >= StackSize { - return ErrStackOverflow - } - - v.stack[v.sp] = &mm - v.sp++ - - return nil -} - func indexAssign(dst, src *objects.Object, selectors []*objects.Object) error { numSel := len(selectors) diff --git a/runtime/vm_builtin_test.go b/runtime/vm_builtin_test.go index aebbe78..0eef432 100644 --- a/runtime/vm_builtin_test.go +++ b/runtime/vm_builtin_test.go @@ -10,6 +10,14 @@ func TestBuiltinFunction(t *testing.T) { expect(t, `out = len("")`, 0) expect(t, `out = len("four")`, 4) expect(t, `out = len("hello world")`, 11) + expect(t, `out = len([])`, 0) + expect(t, `out = len([1, 2, 3])`, 3) + expect(t, `out = len({})`, 0) + expect(t, `out = len({a:1, b:2})`, 2) + expect(t, `out = len(immutable([]))`, 0) + expect(t, `out = len(immutable([1, 2, 3]))`, 3) + expect(t, `out = len(immutable({}))`, 0) + expect(t, `out = len(immutable({a:1, b:2}))`, 2) expectError(t, `len(1)`) expectError(t, `len("one", "two")`) diff --git a/runtime/vm_function_test.go b/runtime/vm_function_test.go index 4d4d115..158bdc3 100644 --- a/runtime/vm_function_test.go +++ b/runtime/vm_function_test.go @@ -20,35 +20,35 @@ func TestFunction(t *testing.T) { expect(t, `x := 10; f := func(x) { return x; }; f(5); out = x;`, 10) expect(t, ` -f2 := func(a) { - f1 := func(a) { - return a * 2; + f2 := func(a) { + f1 := func(a) { + return a * 2; + }; + + return f1(a) * 3; }; - - return f1(a) * 3; -}; - -out = f2(10); -`, 60) + + out = f2(10); + `, 60) // closures expect(t, ` - newAdder := func(x) { - return func(y) { return x + y }; - }; + newAdder := func(x) { + return func(y) { return x + y }; + }; - add2 := newAdder(2); - out = add2(5); - `, 7) + add2 := newAdder(2); + out = add2(5); + `, 7) // function as a argument expect(t, ` -add := func(a, b) { return a + b }; -sub := func(a, b) { return a - b }; -applyFunc := func(a, b, f) { return f(a, b) }; - -out = applyFunc(applyFunc(2, 2, add), 3, sub); -`, 1) + add := func(a, b) { return a + b }; + sub := func(a, b) { return a - b }; + applyFunc := func(a, b, f) { return f(a, b) }; + + out = applyFunc(applyFunc(2, 2, add), 3, sub); + `, 1) expect(t, `f1 := func() { return 5 + 10; }; out = f1();`, 15) expect(t, `f1 := func() { return 1 }; f2 := func() { return 2 }; out = f1() + f2()`, 3) @@ -60,170 +60,163 @@ out = applyFunc(applyFunc(2, 2, add), 3, sub); expect(t, `three := func() { one := 1; two := 2; return one + two }; out = three()`, 3) expect(t, `three := func() { one := 1; two := 2; return one + two }; seven := func() { three := 3; four := 4; return three + four }; out = three() + seven()`, 10) expect(t, ` -foo1 := func() { - foo := 50 - return foo -} -foo2 := func() { - foo := 100 - return foo -} -out = foo1() + foo2()`, 150) + foo1 := func() { + foo := 50 + return foo + } + foo2 := func() { + foo := 100 + return foo + } + out = foo1() + foo2()`, 150) expect(t, ` -g := 50; -minusOne := func() { - n := 1; - return g - n; -}; -minusTwo := func() { - n := 2; - return g - n; -}; -out = minusOne() + minusTwo() -`, 97) + g := 50; + minusOne := func() { + n := 1; + return g - n; + }; + minusTwo := func() { + n := 2; + return g - n; + }; + out = minusOne() + minusTwo() + `, 97) expect(t, ` -f1 := func() { - f2 := func() { return 1; } - return f2 -}; -out = f1()() -`, 1) + f1 := func() { + f2 := func() { return 1; } + return f2 + }; + out = f1()() + `, 1) expect(t, ` -f1 := func(a) { return a; }; -out = f1(4)`, 4) + f1 := func(a) { return a; }; + out = f1(4)`, 4) expect(t, ` -f1 := func(a, b) { return a + b; }; -out = f1(1, 2)`, 3) + f1 := func(a, b) { return a + b; }; + out = f1(1, 2)`, 3) expect(t, ` -sum := func(a, b) { - c := a + b; - return c; -}; -out = sum(1, 2);`, 3) + sum := func(a, b) { + c := a + b; + return c; + }; + out = sum(1, 2);`, 3) expect(t, ` -sum := func(a, b) { - c := a + b; - return c; -}; -out = sum(1, 2) + sum(3, 4);`, 10) + sum := func(a, b) { + c := a + b; + return c; + }; + out = sum(1, 2) + sum(3, 4);`, 10) expect(t, ` -sum := func(a, b) { - c := a + b - return c -}; -outer := func() { - return sum(1, 2) + sum(3, 4) -}; -out = outer();`, 10) + sum := func(a, b) { + c := a + b + return c + }; + outer := func() { + return sum(1, 2) + sum(3, 4) + }; + out = outer();`, 10) expect(t, ` -g := 10; - -sum := func(a, b) { - c := a + b; - return c + g; -} - -outer := func() { - return sum(1, 2) + sum(3, 4) + g; -} - -out = outer() + g -`, 50) + g := 10; + + sum := func(a, b) { + c := a + b; + return c + g; + } + + outer := func() { + return sum(1, 2) + sum(3, 4) + g; + } + + out = outer() + g + `, 50) expectError(t, `func() { return 1; }(1)`) expectError(t, `func(a) { return a; }()`) expectError(t, `func(a, b) { return a + b; }(1)`) expect(t, ` - f1 := func(a) { - return func() { return a; }; - }; - f2 := f1(99); - out = f2() - `, 99) - - expect(t, ` - f1 := func(a, b) { - return func(c) { return a + b + c }; - }; - - f2 := f1(1, 2); - out = f2(8); - `, 11) - expect(t, ` - f1 := func(a, b) { - c := a + b; - return func(d) { return c + d }; - }; - f2 := f1(1, 2); - out = f2(8); - `, 11) - expect(t, ` - f1 := func(a, b) { - c := a + b; - return func(d) { - e := d + c; - return func(f) { return e + f }; - } - }; - f2 := f1(1, 2); - f3 := f2(3); - out = f3(8); - `, 14) - expect(t, ` - a := 1; - f1 := func(b) { - return func(c) { - return func(d) { return a + b + c + d } + f1 := func(a) { + return func() { return a; }; }; - }; - f2 := f1(2); - f3 := f2(3); - out = f3(8); - `, 14) - expect(t, ` - f1 := func(a, b) { - one := func() { return a; }; - two := func() { return b; }; - return func() { return one() + two(); } - }; - f2 := f1(9, 90); - out = f2(); - `, 99) + f2 := f1(99); + out = f2() + `, 99) - // recursion expect(t, ` - fib := func(x) { - if x == 0 { - return 0 - } else if x == 1 { - return 1 - } else { - return fib(x-1) + fib(x-2) - } - } - out = fib(15)`, 610) + f1 := func(a, b) { + return func(c) { return a + b + c }; + }; + + f2 := f1(1, 2); + out = f2(8); + `, 11) + expect(t, ` + f1 := func(a, b) { + c := a + b; + return func(d) { return c + d }; + }; + f2 := f1(1, 2); + out = f2(8); + `, 11) + expect(t, ` + f1 := func(a, b) { + c := a + b; + return func(d) { + e := d + c; + return func(f) { return e + f }; + } + }; + f2 := f1(1, 2); + f3 := f2(3); + out = f3(8); + `, 14) + expect(t, ` + a := 1; + f1 := func(b) { + return func(c) { + return func(d) { return a + b + c + d } + }; + }; + f2 := f1(2); + f3 := f2(3); + out = f3(8); + `, 14) + expect(t, ` + f1 := func(a, b) { + one := func() { return a; }; + two := func() { return b; }; + return func() { return one() + two(); } + }; + f2 := f1(9, 90); + out = f2(); + `, 99) - // TODO: currently recursion inside the local scope function definition is not supported. - // Workaround is to define the identifier first then assign the function like below. - // Want to fix this. + // global function recursion expect(t, ` -func() { - fib := 0 - fib = func(x) { - if x == 0 { - return 0 - } else if x == 1 { - return 1 - } else { - return fib(x-1) + fib(x-2) + fib := func(x) { + if x == 0 { + return 0 + } else if x == 1 { + return 1 + } else { + return fib(x-1) + fib(x-2) + } } + out = fib(15)`, 610) + + // local function recursion + expect(t, ` +out = func() { + sum := func(x) { + return x == 0 ? 0 : x + sum(x-1) } - out = fib(15) -}()`, 610) + return sum(5) +}()`, 15) + + expectError(t, `return 5`) } diff --git a/runtime/vm_module_test.go b/runtime/vm_module_test.go index e059968..f500ae1 100644 --- a/runtime/vm_module_test.go +++ b/runtime/vm_module_test.go @@ -1,8 +1,12 @@ package runtime_test -import "testing" +import ( + "testing" -func TestModule(t *testing.T) { + "github.com/d5/tengo/objects" +) + +func TestStdLib(t *testing.T) { // stdlib expect(t, `math := import("math"); out = math.abs(1)`, 1.0) expect(t, `math := import("math"); out = math.abs(-1)`, 1.0) @@ -54,27 +58,101 @@ if !is_error(cmd) { } `, []byte("foo bar\n")) +} + +func TestUserModules(t *testing.T) { // user modules - expectWithUserModules(t, `out = import("mod1").bar()`, 5.0, map[string]string{ - "mod1": `bar := func() { return 5.0 }`, + + // export none + expectWithUserModules(t, `out = import("mod1")`, objects.UndefinedValue, map[string]string{ + "mod1": `fn := func() { return 5.0 }; a := 2`, }) + + // export values + expectWithUserModules(t, `out = import("mod1")`, 5, map[string]string{ + "mod1": `export 5`, + }) + expectWithUserModules(t, `out = import("mod1")`, "foo", map[string]string{ + "mod1": `export "foo"`, + }) + + // export composite types + expectWithUserModules(t, `out = import("mod1")`, IARR{1, 2, 3}, map[string]string{ + "mod1": `export [1, 2, 3]`, + }) + expectWithUserModules(t, `out = import("mod1")`, IMAP{"a": 1, "b": 2}, map[string]string{ + "mod1": `export {a: 1, b: 2}`, + }) + // export value is immutable + expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a = 5`, map[string]string{ + "mod1": `export {a: 1, b: 2}`, + }) + expectErrorWithUserModules(t, `m1 := import("mod1"); m1[1] = 5`, map[string]string{ + "mod1": `export [1, 2, 3]`, + }) + + // code after export statement will not be executed + expectWithUserModules(t, `out = import("mod1")`, 10, map[string]string{ + "mod1": `a := 10; export a; a = 20`, + }) + expectWithUserModules(t, `out = import("mod1")`, 10, map[string]string{ + "mod1": `a := 10; export a; a = 20; export a`, + }) + + // export function + expectWithUserModules(t, `out = import("mod1")()`, 5.0, map[string]string{ + "mod1": `export func() { return 5.0 }`, + }) + // export function that reads module-global variable + expectWithUserModules(t, `out = import("mod1")()`, 6.5, map[string]string{ + "mod1": `a := 1.5; export func() { return a + 5.0 }`, + }) + // export function that read local variable + expectWithUserModules(t, `out = import("mod1")()`, 6.5, map[string]string{ + "mod1": `export func() { a := 1.5; return a + 5.0 }`, + }) + // export function that read free variables + expectWithUserModules(t, `out = import("mod1")()`, 6.5, map[string]string{ + "mod1": `export func() { a := 1.5; return func() { return a + 5.0 }() }`, + }) + + // recursive function in module + expectWithUserModules(t, `out = import("mod1")`, 15, map[string]string{ + "mod1": ` +a := func(x) { + return x == 0 ? 0 : x + a(x-1) +} + +export a(5) +`}) + expectWithUserModules(t, `out = import("mod1")`, 15, map[string]string{ + "mod1": ` +export func() { + a := func(x) { + return x == 0 ? 0 : x + a(x-1) + } + + return a(5) +}() +`}) + // (main) -> mod1 -> mod2 - expectWithUserModules(t, `out = import("mod1").mod2.bar()`, 5.0, map[string]string{ - "mod1": `mod2 := import("mod2")`, - "mod2": `bar := func() { return 5.0 }`, + expectWithUserModules(t, `out = import("mod1")()`, 5.0, map[string]string{ + "mod1": `export import("mod2")`, + "mod2": `export func() { return 5.0 }`, }) // (main) -> mod1 -> mod2 // -> mod2 - expectWithUserModules(t, `import("mod1"); out = import("mod2").bar()`, 5.0, map[string]string{ - "mod1": `mod2 := import("mod2")`, - "mod2": `bar := func() { return 5.0 }`, + expectWithUserModules(t, `import("mod1"); out = import("mod2")()`, 5.0, map[string]string{ + "mod1": `export import("mod2")`, + "mod2": `export func() { return 5.0 }`, }) // (main) -> mod1 -> mod2 -> mod3 // -> mod2 -> mod3 - expectWithUserModules(t, `import("mod1"); out = import("mod2").mod3.bar()`, 5.0, map[string]string{ - "mod1": `mod2 := import("mod2")`, - "mod2": `mod3 := import("mod3")`, - "mod3": `bar := func() { return 5.0 }`, + expectWithUserModules(t, `import("mod1"); out = import("mod2")()`, 5.0, map[string]string{ + "mod1": `export import("mod2")`, + "mod2": `export import("mod3")`, + "mod3": `export func() { return 5.0 }`, }) // cyclic imports @@ -104,26 +182,29 @@ if !is_error(cmd) { "mod1": `import("mod2")`, }) - // for-in - expectWithUserModules(t, `for _, n in import("mod1") { out += n }`, 6, map[string]string{ - "mod1": `a := 1; b := 2; c := 3`, - }) - expectWithUserModules(t, `for k, _ in import("mod1") { out += k }`, "a", map[string]string{ - "mod1": `a := 1`, // only 1 global variable because module map does not sort the keys - }) - - // mutating global variables inside the module does not affect exported values - expectWithUserModules(t, `m1 := import("mod1"); m1.mutate(); out = m1.a`, 3, map[string]string{ - "mod1": `a := 3; mutate := func() { a = 10 }`, - }) - - // module map is immutable - expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a = 5`, map[string]string{ - "mod1": `a := 3`, - }) - // module is immutable but its variables is not necessarily immutable. expectWithUserModules(t, `m1 := import("mod1"); m1.a.b = 5; out = m1.a.b`, 5, map[string]string{ - "mod1": `a := {b: 3}`, + "mod1": `export {a: {b: 3}}`, + }) + + // make sure module has same builtin functions + expectWithUserModules(t, `out = import("mod1")`, "int", map[string]string{ + "mod1": `export func() { return type_name(0) }()`, + }) + + // 'export' statement is ignored outside module + expect(t, `a := 5; export func() { a = 10 }(); out = a`, 5) + + // 'export' must be in the top-level + expectErrorWithUserModules(t, `import("mod1")`, map[string]string{ + "mod1": `func() { export 5 }()`, + }) + expectErrorWithUserModules(t, `import("mod1")`, map[string]string{ + "mod1": `func() { func() { export 5 }() }()`, + }) + + // module cannot access outer scope + expectErrorWithUserModules(t, `a := 5; import("mod")`, map[string]string{ + "mod1": `export a`, }) } diff --git a/runtime/vm_string_test.go b/runtime/vm_string_test.go index d3e7985..76611ff 100644 --- a/runtime/vm_string_test.go +++ b/runtime/vm_string_test.go @@ -28,7 +28,7 @@ func TestString(t *testing.T) { expectError(t, fmt.Sprintf("%s[%d]", strStr, strLen)) // slice operator - for low := 0; low < strLen; low++ { + for low := 0; low <= strLen; low++ { for high := low; high <= strLen; high++ { expect(t, fmt.Sprintf("out = %s[%d:%d]", strStr, low, high), str[low:high]) expect(t, fmt.Sprintf("out = %s[0 + %d : 0 + %d]", strStr, low, high), str[low:high]) diff --git a/runtime/vm_test.go b/runtime/vm_test.go index 4b21e82..103e785 100644 --- a/runtime/vm_test.go +++ b/runtime/vm_test.go @@ -247,12 +247,18 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu bytecode := c.Bytecode() var constStr []string for cidx, cn := range bytecode.Constants { - if cmFn, ok := cn.(*objects.CompiledFunction); ok { + switch cn := cn.(type) { + case *objects.CompiledFunction: constStr = append(constStr, fmt.Sprintf("[% 3d] (Compiled Function|%p)", cidx, &cn)) - for _, l := range compiler.FormatInstructions(cmFn.Instructions, 0) { + for _, l := range compiler.FormatInstructions(cn.Instructions, 0) { constStr = append(constStr, fmt.Sprintf(" %s", l)) } - } else { + case *objects.CompiledModule: + constStr = append(constStr, fmt.Sprintf("[% 3d] (Compiled Module|%p)", cidx, &cn)) + for _, l := range compiler.FormatInstructions(cn.Instructions, 0) { + constStr = append(constStr, fmt.Sprintf(" %s", l)) + } + default: constStr = append(constStr, fmt.Sprintf("[% 3d] %s (%s|%p)", cidx, cn, reflect.TypeOf(cn).Elem().Name(), &cn)) } } @@ -273,22 +279,7 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu res[name] = *globals[sym.Index] } - var globalsStr []string - for gidx, g := range globals { - if g == nil { - break - } - - if cmFn, ok := (*g).(*objects.Closure); ok { - globalsStr = append(globalsStr, fmt.Sprintf("[% 3d] (Closure|%p)", gidx, g)) - for _, l := range compiler.FormatInstructions(cmFn.Fn.Instructions, 0) { - globalsStr = append(globalsStr, fmt.Sprintf(" %s", l)) - } - } else { - globalsStr = append(globalsStr, fmt.Sprintf("[% 3d] %s (%s|%p)", gidx, (*g).String(), reflect.TypeOf(*g).Elem().Name(), g)) - } - } - trace = append(trace, fmt.Sprintf("\n[Globals]\n\n%s", strings.Join(globalsStr, "\n"))) + trace = append(trace, fmt.Sprintf("\n[Globals]\n\n%s", strings.Join(formatGlobals(globals), "\n"))) frameIdx, ip := v.FrameInfo() trace = append(trace, fmt.Sprintf("\n[IP]\n\nFrame=%d, IP=%d", frameIdx, ip+1)) @@ -300,6 +291,26 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu return } +func formatGlobals(globals []*objects.Object) (formatted []string) { + for idx, global := range globals { + if global == nil { + return + } + + switch global := (*global).(type) { + case *objects.Closure: + formatted = append(formatted, fmt.Sprintf("[% 3d] (Closure|%p)", idx, global)) + for _, l := range compiler.FormatInstructions(global.Fn.Instructions, 0) { + formatted = append(formatted, fmt.Sprintf(" %s", l)) + } + default: + formatted = append(formatted, fmt.Sprintf("[% 3d] %s (%s|%p)", idx, global.String(), reflect.TypeOf(global).Elem().Name(), global)) + } + } + + return +} + func parse(t *testing.T, input string) *ast.File { testFileSet := source.NewFileSet() testFile := testFileSet.AddFile("", -1, len(input)) diff --git a/script/script.go b/script/script.go index 4bb1492..4b9ade9 100644 --- a/script/script.go +++ b/script/script.go @@ -17,7 +17,6 @@ type Script struct { variables map[string]*Variable removedBuiltins map[string]bool removedStdModules map[string]bool - scriptModules map[string]*Script userModuleLoader compiler.ModuleLoader input []byte } @@ -80,16 +79,6 @@ func (s *Script) SetUserModuleLoader(loader compiler.ModuleLoader) { s.userModuleLoader = loader } -// AddModule adds another script as a module. Script module will be -// compiled and run right before the main script s is compiled. -func (s *Script) AddModule(name string, scriptModule *Script) { - if s.scriptModules == nil { - s.scriptModules = make(map[string]*Script) - } - - s.scriptModules[name] = scriptModule -} - // Compile compiles the script with all the defined variables, and, returns Compiled object. func (s *Script) Compile() (*Compiled, error) { symbolTable, stdModules, globals, err := s.prepCompile() @@ -165,40 +154,8 @@ func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, stdModules ma stdModules[name] = mod } } - for name, scriptModule := range s.scriptModules { - if scriptModule == nil { - err = fmt.Errorf("script module must not be nil: %s", name) - } - var compiledModule *Compiled - compiledModule, err = scriptModule.Compile() - if err != nil { - return - } - - err = compiledModule.Run() - if err != nil { - return - } - - mod := &objects.ImmutableMap{ - Value: make(map[string]objects.Object), - } - - for _, symbolName := range compiledModule.symbolTable.Names() { - symbol, _, ok := compiledModule.symbolTable.Resolve(symbolName) - if ok && symbol.Scope == compiler.ScopeGlobal { - value := compiledModule.machine.Globals()[symbol.Index] - if value != nil { - mod.Value[symbolName] = *value - } - } - } - - stdModules[name] = mod - } - - globals = make([]*objects.Object, len(names), len(names)) + globals = make([]*objects.Object, runtime.GlobalsSize, runtime.GlobalsSize) for idx, name := range names { symbol := symbolTable.Define(name) diff --git a/script/script_module_test.go b/script/script_module_test.go index 3203a15..b979334 100644 --- a/script/script_module_test.go +++ b/script/script_module_test.go @@ -1,30 +1,56 @@ package script_test import ( + "errors" "testing" "github.com/d5/tengo/assert" "github.com/d5/tengo/script" ) -func TestScript_AddModule(t *testing.T) { - // mod1 module - mod1 := script.New([]byte(`a := 5`)) - +func TestScript_SetUserModuleLoader(t *testing.T) { // script1 imports "mod1" - scr1 := script.New([]byte(`mod1 := import("mod1"); out := mod1.a`)) - scr1.AddModule("mod1", mod1) - c, err := scr1.Run() + scr := script.New([]byte(`out := import("mod")`)) + scr.SetUserModuleLoader(func(name string) ([]byte, error) { + return []byte(`export 5`), nil + }) + c, err := scr.Run() assert.Equal(t, int64(5), c.Get("out").Value()) - // mod2 module imports "mod1" - mod2 := script.New([]byte(`mod1 := import("mod1"); b := mod1.a * 2`)) - mod2.AddModule("mod1", mod1) - - // script2 imports "mod2" (which imports "mod1") - scr2 := script.New([]byte(`mod2 := import("mod2"); out := mod2.b`)) - scr2.AddModule("mod2", mod2) - c, err = scr2.Run() + // executing module function + scr = script.New([]byte(`fn := import("mod"); out := fn()`)) + scr.SetUserModuleLoader(func(name string) ([]byte, error) { + return []byte(`a := 3; export func() { return a + 5 }`), nil + }) + c, err = scr.Run() assert.NoError(t, err) - assert.Equal(t, int64(10), c.Get("out").Value()) + assert.Equal(t, int64(8), c.Get("out").Value()) + + // disabled builtin function + scr = script.New([]byte(`out := import("mod")`)) + scr.SetUserModuleLoader(func(name string) ([]byte, error) { + return []byte(`export len([1, 2, 3])`), nil + }) + c, err = scr.Run() + assert.NoError(t, err) + assert.Equal(t, int64(3), c.Get("out").Value()) + scr.DisableBuiltinFunction("len") + _, err = scr.Run() + assert.Error(t, err) + + // disabled stdlib + scr = script.New([]byte(`out := import("mod")`)) + scr.SetUserModuleLoader(func(name string) ([]byte, error) { + if name == "mod" { + return []byte(`text := import("text"); export text.title("foo")`), nil + } + return nil, errors.New("module not found") + }) + c, err = scr.Run() + assert.NoError(t, err) + assert.Equal(t, "Foo", c.Get("out").Value()) + scr.DisableStdModule("text") + _, err = scr.Run() + assert.Error(t, err) + } diff --git a/script/script_test.go b/script/script_test.go index 142a9d9..ee2b7a9 100644 --- a/script/script_test.go +++ b/script/script_test.go @@ -1,7 +1,6 @@ package script_test import ( - "errors" "testing" "github.com/d5/tengo/assert" @@ -59,22 +58,3 @@ func TestScript_DisableStdModule(t *testing.T) { _, err = s.Run() assert.Error(t, err) } - -func TestScript_SetUserModuleLoader(t *testing.T) { - s := script.New([]byte(`math := import("mod1"); a := math.foo()`)) - _, err := s.Run() - assert.Error(t, err) - s.SetUserModuleLoader(func(moduleName string) (res []byte, err error) { - if moduleName == "mod1" { - res = []byte(`foo := func() { return 5 }`) - return - } - - err = errors.New("module not found") - return - }) - c, err := s.Run() - assert.NoError(t, err) - assert.NotNil(t, c) - compiledGet(t, c, "a", int64(5)) -}