diff --git a/cmd/bench/main.go b/cmd/bench/main.go index e275608..17a8093 100644 --- a/cmd/bench/main.go +++ b/cmd/bench/main.go @@ -193,7 +193,7 @@ func compileFile(file *ast.File) (time.Duration, *compiler.Bytecode, error) { start := time.Now() - c := compiler.NewCompiler(symTable, nil) + c := compiler.NewCompiler(symTable, nil, nil) if err := c.Compile(file); err != nil { return time.Since(start), nil, err } diff --git a/cmd/tengo/main.go b/cmd/tengo/main.go index 777321c..5ebb054 100644 --- a/cmd/tengo/main.go +++ b/cmd/tengo/main.go @@ -171,7 +171,11 @@ func doRepl(in io.Reader, out io.Writer) { fileSet := source.NewFileSet() globals := make([]*objects.Object, runtime.GlobalsSize) + symbolTable := compiler.NewSymbolTable() + for idx, fn := range objects.Builtins { + symbolTable.DefineBuiltin(idx, fn.Name) + } for { _, _ = fmt.Fprintf(out, replPrompt) @@ -191,7 +195,7 @@ func doRepl(in io.Reader, out io.Writer) { file = addPrints(file) - c := compiler.NewCompiler(symbolTable, nil) + c := compiler.NewCompiler(symbolTable, nil, nil) if err := c.Compile(file); err != nil { _, _ = fmt.Fprintf(out, "Compilation error:\n %s\n", err.Error()) continue @@ -218,7 +222,7 @@ func compileSrc(src []byte, filename string) (*compiler.Bytecode, error) { return nil, err } - c := compiler.NewCompiler(nil, nil) + c := compiler.NewCompiler(nil, nil, nil) if err := c.Compile(file); err != nil { return nil, err } diff --git a/compiler/compiler.go b/compiler/compiler.go index ba5a26b..1b3b331 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -20,6 +20,7 @@ type Compiler struct { scopes []CompilationScope scopeIndex int moduleLoader ModuleLoader + stdModules map[string]*objects.ImmutableMap compiledModules map[string]*objects.CompiledModule loops []*Loop loopIndex int @@ -28,17 +29,28 @@ type Compiler struct { } // NewCompiler creates a Compiler. -func NewCompiler(symbolTable *SymbolTable, trace io.Writer) *Compiler { +// User can optionally provide the symbol table if one wants to add or remove +// some global- or builtin- scope symbols. If not (nil), Compile will create +// a new symbol table and use the default builtin functions. Likewise, standard +// modules can be explicitly provided if user wants to add or remove some modules. +// By default, Compile will use all the standard modules otherwise. +func NewCompiler(symbolTable *SymbolTable, stdModules map[string]*objects.ImmutableMap, trace io.Writer) *Compiler { mainScope := CompilationScope{ instructions: make([]byte, 0), } + // symbol table if symbolTable == nil { symbolTable = NewSymbolTable() + + for idx, fn := range objects.Builtins { + symbolTable.DefineBuiltin(idx, fn.Name) + } } - for idx, fn := range objects.Builtins { - symbolTable.DefineBuiltin(idx, fn.Name) + // standard modules + if stdModules == nil { + stdModules = stdlib.Modules } return &Compiler{ @@ -47,6 +59,7 @@ func NewCompiler(symbolTable *SymbolTable, trace io.Writer) *Compiler { scopeIndex: 0, loopIndex: -1, trace: trace, + stdModules: stdModules, compiledModules: make(map[string]*objects.CompiledModule), } } @@ -440,7 +453,7 @@ func (c *Compiler) Compile(node ast.Node) error { c.emit(OpCall, len(node.Args)) case *ast.ImportExpr: - stdMod, ok := stdlib.Modules[node.ModuleName] + stdMod, ok := c.stdModules[node.ModuleName] if ok { // standard modules contain only globals with no code. // so no need to compile anything @@ -474,12 +487,14 @@ func (c *Compiler) Bytecode() *Bytecode { } // SetModuleLoader sets or replaces the current module loader. +// Note that the module loader is used for user modules, +// not for the standard modules. func (c *Compiler) SetModuleLoader(moduleLoader ModuleLoader) { c.moduleLoader = moduleLoader } func (c *Compiler) fork(moduleName string, symbolTable *SymbolTable) *Compiler { - child := NewCompiler(symbolTable, c.trace) + child := NewCompiler(symbolTable, c.stdModules, c.trace) child.moduleName = moduleName // name of the module to compile child.parent = c // parent to set to current compiler child.moduleLoader = c.moduleLoader // share module loader diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index c63b053..5dfefe5 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -972,9 +972,12 @@ func traceCompile(input string, symbols map[string]objects.Object) (res *compile for name := range symbols { symTable.Define(name) } + for idx, fn := range objects.Builtins { + symTable.DefineBuiltin(idx, fn.Name) + } tr := &tracer{} - c := compiler.NewCompiler(symTable, tr) + c := compiler.NewCompiler(symTable, nil, tr) parsed, err := p.ParseFile() if err != nil { return diff --git a/objects/builtins.go b/objects/builtins.go index acf2419..fe7085e 100644 --- a/objects/builtins.go +++ b/objects/builtins.go @@ -1,10 +1,13 @@ package objects -// Builtins contains all known builtin functions. -var Builtins = []struct { +// NamedBuiltinFunc is a named builtin function. +type NamedBuiltinFunc struct { Name string Func CallableFunc -}{ +} + +// Builtins contains all default builtin functions. +var Builtins = []NamedBuiltinFunc{ { Name: "print", Func: builtinPrint, diff --git a/runtime/vm_test.go b/runtime/vm_test.go index d14eaef..f3f0722 100644 --- a/runtime/vm_test.go +++ b/runtime/vm_test.go @@ -176,9 +176,12 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu sym := symTable.Define(name) globals[sym.Index] = &value } + for idx, fn := range objects.Builtins { + symTable.DefineBuiltin(idx, fn.Name) + } tr := &tracer{} - c := compiler.NewCompiler(symTable, tr) + c := compiler.NewCompiler(symTable, nil, tr) c.SetModuleLoader(func(moduleName string) ([]byte, error) { if src, ok := userModules[moduleName]; ok { return []byte(src), nil diff --git a/script/script.go b/script/script.go index 32933eb..418b3f3 100644 --- a/script/script.go +++ b/script/script.go @@ -63,7 +63,7 @@ func (s *Script) Compile() (*Compiled, error) { return nil, fmt.Errorf("parse error: %s", err.Error()) } - c := compiler.NewCompiler(symbolTable, nil) + c := compiler.NewCompiler(symbolTable, nil, nil) if err := c.Compile(file); err != nil { return nil, err } @@ -94,6 +94,10 @@ func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, globals []*ob } symbolTable = compiler.NewSymbolTable() + for idx, fn := range objects.Builtins { + symbolTable.DefineBuiltin(idx, fn.Name) + } + globals = make([]*objects.Object, len(names), len(names)) for idx, name := range names {