From 306055ad65115f63cdfc81cd47e210f9a8a1204c Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 28 Feb 2019 18:41:29 -0800 Subject: [PATCH] add helper functions for builtin functions and builtin modules (#122) * add helper functions for builtin functions and builtin modules * fix a builtin function bug for modules --- compiler/compiler_module.go | 7 ++-- compiler/symbol_table.go | 28 ++++++++++++---- objects/builtins.go | 35 ++++++++++++++++++++ objects/builtins_test.go | 65 +++++++++++++++++++++++++++++++++++++ script/script_test.go | 16 +++++++++ stdlib/stdlib.go | 21 ++++++++++++ stdlib/stdlib_test.go | 32 ++++++++++++++++++ 7 files changed, 193 insertions(+), 11 deletions(-) create mode 100644 objects/builtins_test.go diff --git a/compiler/compiler_module.go b/compiler/compiler_module.go index 8f63abb..d069bfa 100644 --- a/compiler/compiler_module.go +++ b/compiler/compiler_module.go @@ -77,11 +77,8 @@ func (c *Compiler) doCompileModule(moduleName string, src []byte) (*objects.Comp symbolTable := NewSymbolTable() // inherit builtin functions - for idx, fn := range objects.Builtins { - s, _, ok := c.symbolTable.Resolve(fn.Name) - if ok && s.Scope == ScopeBuiltin { - symbolTable.DefineBuiltin(idx, fn.Name) - } + for _, sym := range c.symbolTable.BuiltinSymbols() { + symbolTable.DefineBuiltin(sym.Index, sym.Name) } // no global scope for the module diff --git a/compiler/symbol_table.go b/compiler/symbol_table.go index da55a82..fb029b3 100644 --- a/compiler/symbol_table.go +++ b/compiler/symbol_table.go @@ -2,12 +2,13 @@ package compiler // SymbolTable represents a symbol table. type SymbolTable struct { - parent *SymbolTable - block bool - store map[string]*Symbol - numDefinition int - maxDefinition int - freeSymbols []*Symbol + parent *SymbolTable + block bool + store map[string]*Symbol + numDefinition int + maxDefinition int + freeSymbols []*Symbol + builtinSymbols []*Symbol } // NewSymbolTable creates a SymbolTable. @@ -37,6 +38,10 @@ func (t *SymbolTable) Define(name string) *Symbol { // DefineBuiltin adds a symbol for builtin function. func (t *SymbolTable) DefineBuiltin(index int, name string) *Symbol { + if t.parent != nil { + return t.parent.DefineBuiltin(index, name) + } + symbol := &Symbol{ Name: name, Index: index, @@ -45,6 +50,8 @@ func (t *SymbolTable) DefineBuiltin(index int, name string) *Symbol { t.store[name] = symbol + t.builtinSymbols = append(t.builtinSymbols, symbol) + return symbol } @@ -101,6 +108,15 @@ func (t *SymbolTable) FreeSymbols() []*Symbol { return t.freeSymbols } +// BuiltinSymbols returns builtin symbols for the scope. +func (t *SymbolTable) BuiltinSymbols() []*Symbol { + if t.parent != nil { + return t.parent.BuiltinSymbols() + } + + return t.builtinSymbols +} + // Names returns the name of all the symbols. func (t *SymbolTable) Names() []string { var names []string diff --git a/objects/builtins.go b/objects/builtins.go index c8b2a37..42c1a75 100644 --- a/objects/builtins.go +++ b/objects/builtins.go @@ -1,6 +1,7 @@ package objects // Builtins contains all default builtin functions. +// Use GetBuiltinFunctions instead of accessing Builtins directly. var Builtins = []BuiltinFunction{ { Name: "print", @@ -127,3 +128,37 @@ var Builtins = []BuiltinFunction{ Value: builtinTypeName, }, } + +// AllBuiltinFunctionNames returns a list of all default builtin function names. +func AllBuiltinFunctionNames() []string { + var names []string + for _, bf := range Builtins { + names = append(names, bf.Name) + } + return names +} + +// GetBuiltinFunctions returns a slice of builtin function objects. +// GetBuiltinFunctions removes the duplicate names, and, the returned builtin functions +// are not guaranteed to be in the same order as names. +func GetBuiltinFunctions(names ...string) []*BuiltinFunction { + include := make(map[string]bool) + for _, name := range names { + include[name] = true + } + + var builtinFuncs []*BuiltinFunction + for _, bf := range Builtins { + if include[bf.Name] { + bf := bf + builtinFuncs = append(builtinFuncs, &bf) + } + } + + return builtinFuncs +} + +// GetAllBuiltinFunctions returns all builtin functions. +func GetAllBuiltinFunctions() []*BuiltinFunction { + return GetBuiltinFunctions(AllBuiltinFunctionNames()...) +} diff --git a/objects/builtins_test.go b/objects/builtins_test.go new file mode 100644 index 0000000..4dfe036 --- /dev/null +++ b/objects/builtins_test.go @@ -0,0 +1,65 @@ +package objects_test + +import ( + "testing" + + "github.com/d5/tengo/assert" + "github.com/d5/tengo/objects" +) + +func TestGetBuiltinFunctions(t *testing.T) { + testGetBuiltinFunctions(t) + testGetBuiltinFunctions(t, "print") + testGetBuiltinFunctions(t, "int", "float") + testGetBuiltinFunctions(t, "int", "float", "printf") + testGetBuiltinFunctions(t, "int", "int") // duplicate names ignored +} + +func TestGetAllBuiltinFunctions(t *testing.T) { + funcs := objects.GetAllBuiltinFunctions() + if !assert.Equal(t, len(objects.Builtins), len(funcs)) { + return + } + + namesM := make(map[string]bool) + for _, bf := range objects.Builtins { + namesM[bf.Name] = true + } + + for _, bf := range funcs { + assert.True(t, namesM[bf.Name], "name: %s", bf.Name) + } +} + +func TestAllBuiltinFunctionNames(t *testing.T) { + names := objects.AllBuiltinFunctionNames() + if !assert.Equal(t, len(objects.Builtins), len(names)) { + return + } + + namesM := make(map[string]bool) + for _, name := range names { + namesM[name] = true + } + + for _, bf := range objects.Builtins { + assert.True(t, namesM[bf.Name], "name: %s", bf.Name) + } +} + +func testGetBuiltinFunctions(t *testing.T, names ...string) { + // remove duplicates + namesM := make(map[string]bool) + for _, name := range names { + namesM[name] = true + } + + funcs := objects.GetBuiltinFunctions(names...) + if !assert.Equal(t, len(namesM), len(funcs)) { + return + } + + for _, bf := range funcs { + assert.True(t, namesM[bf.Name], "name: %s", bf.Name) + } +} diff --git a/script/script_test.go b/script/script_test.go index b9fa09c..5f254f1 100644 --- a/script/script_test.go +++ b/script/script_test.go @@ -1,6 +1,7 @@ package script_test import ( + "errors" "testing" "github.com/d5/tengo/assert" @@ -60,6 +61,21 @@ func TestScript_SetBuiltinFunctions(t *testing.T) { s.SetBuiltinFunctions(nil) _, err = s.Run() assert.Error(t, err) + + s = script.New([]byte(`a := import("b")`)) + s.SetUserModuleLoader(func(name string) ([]byte, error) { + if name == "b" { + return []byte(`export import("c")`), nil + } else if name == "c" { + return []byte("export len([1, 2, 3])"), nil + } + return nil, errors.New("module not found") + }) + s.SetBuiltinFunctions([]*objects.BuiltinFunction{&objects.Builtins[3]}) + c, err = s.Run() + assert.NoError(t, err) + assert.NotNil(t, c) + compiledGet(t, c, "a", int64(3)) } func TestScript_SetBuiltinModules(t *testing.T) { diff --git a/stdlib/stdlib.go b/stdlib/stdlib.go index d34fbc8..f5181a2 100644 --- a/stdlib/stdlib.go +++ b/stdlib/stdlib.go @@ -11,6 +11,27 @@ var Modules = map[string]*objects.Object{ "rand": objectPtr(&objects.ImmutableMap{Value: randModule}), } +// AllModuleNames returns a list of all default module names. +func AllModuleNames() []string { + var names []string + for name := range Modules { + names = append(names, name) + } + return names +} + +// GetModules returns the modules for the given names. +// Duplicate names and invalid names are ignore. +func GetModules(names ...string) map[string]*objects.ImmutableMap { + modules := make(map[string]*objects.ImmutableMap) + for _, name := range names { + if mod := Modules[name]; mod != nil { + modules[name] = (*mod).(*objects.ImmutableMap) + } + } + return modules +} + func objectPtr(o objects.Object) *objects.Object { return &o } diff --git a/stdlib/stdlib_test.go b/stdlib/stdlib_test.go index f0ba6a6..bd3ac4c 100644 --- a/stdlib/stdlib_test.go +++ b/stdlib/stdlib_test.go @@ -15,6 +15,38 @@ type MAP = map[string]interface{} type IARR []interface{} type IMAP map[string]interface{} +func TestAllModuleNames(t *testing.T) { + names := stdlib.AllModuleNames() + if !assert.Equal(t, len(stdlib.Modules), len(names)) { + return + } + for _, name := range names { + assert.NotNil(t, stdlib.Modules[name], "name: %s", name) + } +} + +func TestGetModules(t *testing.T) { + mods := stdlib.GetModules() + assert.Equal(t, 0, len(mods)) + + mods = stdlib.GetModules("os") + assert.Equal(t, 1, len(mods)) + assert.NotNil(t, mods["os"]) + + mods = stdlib.GetModules("os", "rand") + assert.Equal(t, 2, len(mods)) + assert.NotNil(t, mods["os"]) + assert.NotNil(t, mods["rand"]) + + mods = stdlib.GetModules("text", "text") + assert.Equal(t, 1, len(mods)) + assert.NotNil(t, mods["text"]) + + mods = stdlib.GetModules("nonexisting", "text") + assert.Equal(t, 1, len(mods)) + assert.NotNil(t, mods["text"]) +} + type callres struct { t *testing.T o objects.Object