fix bytecode encoding/decoding of builtin modules (#154)

* fix bytecode encoding/decoding of builtin modules

* Bytecode.Decode() to take map[string]objects.Importable

* add objects.ModuleMap

* update docs

* stdlib.GetModuleMap()
This commit is contained in:
Daniel 2019-03-20 01:28:40 -07:00 committed by GitHub
parent e785e38bf8
commit 3c30109cd0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 308 additions and 140 deletions

View file

@ -41,7 +41,7 @@ type Options struct {
Version string
// Import modules
Modules map[string]objects.Importable
Modules *objects.ModuleMap
}
// Run CLI
@ -56,7 +56,7 @@ func Run(options *Options) {
if options.InputFile == "" {
// REPL
runREPL(options.Modules, os.Stdin, os.Stdout)
RunREPL(options.Modules, os.Stdin, os.Stdout)
return
}
@ -67,17 +67,17 @@ func Run(options *Options) {
}
if options.CompileOutput != "" {
if err := compileOnly(options.Modules, inputData, options.InputFile, options.CompileOutput); err != nil {
if err := CompileOnly(options.Modules, inputData, options.InputFile, options.CompileOutput); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
} else if filepath.Ext(options.InputFile) == sourceFileExt {
if err := compileAndRun(options.Modules, inputData, options.InputFile); err != nil {
if err := CompileAndRun(options.Modules, inputData, options.InputFile); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
} else {
if err := runCompiled(inputData); err != nil {
if err := RunCompiled(options.Modules, inputData); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
@ -116,7 +116,8 @@ func doHelp() {
fmt.Println()
}
func compileOnly(modules map[string]objects.Importable, data []byte, inputFile, outputFile string) (err error) {
// CompileOnly compiles the source code and writes the compiled binary into outputFile.
func CompileOnly(modules *objects.ModuleMap, data []byte, inputFile, outputFile string) (err error) {
bytecode, err := compileSrc(modules, data, filepath.Base(inputFile))
if err != nil {
return
@ -148,7 +149,8 @@ func compileOnly(modules map[string]objects.Importable, data []byte, inputFile,
return
}
func compileAndRun(modules map[string]objects.Importable, data []byte, inputFile string) (err error) {
// CompileAndRun compiles the source code and executes it.
func CompileAndRun(modules *objects.ModuleMap, data []byte, inputFile string) (err error) {
bytecode, err := compileSrc(modules, data, filepath.Base(inputFile))
if err != nil {
return
@ -164,9 +166,10 @@ func compileAndRun(modules map[string]objects.Importable, data []byte, inputFile
return
}
func runCompiled(data []byte) (err error) {
// RunCompiled reads the compiled binary from file and executes it.
func RunCompiled(modules *objects.ModuleMap, data []byte) (err error) {
bytecode := &compiler.Bytecode{}
err = bytecode.Decode(bytes.NewReader(data))
err = bytecode.Decode(bytes.NewReader(data), modules)
if err != nil {
return
}
@ -181,7 +184,8 @@ func runCompiled(data []byte) (err error) {
return
}
func runREPL(modules map[string]objects.Importable, in io.Reader, out io.Writer) {
// RunREPL starts REPL.
func RunREPL(modules *objects.ModuleMap, in io.Reader, out io.Writer) {
stdin := bufio.NewScanner(in)
fileSet := source.NewFileSet()
@ -254,7 +258,7 @@ func runREPL(modules map[string]objects.Importable, in io.Reader, out io.Writer)
}
}
func compileSrc(modules map[string]objects.Importable, src []byte, filename string) (*compiler.Bytecode, error) {
func compileSrc(modules *objects.ModuleMap, src []byte, filename string) (*compiler.Bytecode, error) {
fileSet := source.NewFileSet()
srcFile := fileSet.AddFile(filename, -1, len(src))

64
cli/cli_test.go Normal file
View file

@ -0,0 +1,64 @@
package cli_test
import (
"io/ioutil"
"os"
"path/filepath"
"regexp"
"testing"
"github.com/d5/tengo/assert"
"github.com/d5/tengo/cli"
"github.com/d5/tengo/stdlib"
)
func TestCLICompileAndRun(t *testing.T) {
tempDir := filepath.Join(os.TempDir(), "tengo_tests")
_ = os.MkdirAll(tempDir, os.ModePerm)
binFile := filepath.Join(tempDir, "cli_bin")
outFile := filepath.Join(tempDir, "cli_out")
defer func() {
_ = os.RemoveAll(tempDir)
}()
src := []byte(`
os := import("os")
rand := import("rand")
times := import("times")
rand.seed(times.time_nanosecond(times.now()))
rand_num := func() {
return rand.intn(100)
}
file := os.create("` + outFile + `")
file.write_string("random number is " + rand_num())
file.close()
`)
mods := stdlib.GetModuleMap(stdlib.AllModuleNames()...)
err := cli.CompileOnly(mods, src, "src", binFile)
if !assert.NoError(t, err) {
return
}
compiledBin, err := ioutil.ReadFile(binFile)
if !assert.NoError(t, err) {
return
}
err = cli.RunCompiled(mods, compiledBin)
if !assert.NoError(t, err) {
return
}
read, err := ioutil.ReadFile(outFile)
if !assert.NoError(t, err) {
return
}
ok, err := regexp.Match(`^random number is \d+$`, read)
assert.NoError(t, err)
assert.True(t, ok, string(read))
}

View file

@ -27,7 +27,7 @@ func main() {
ShowVersion: showVersion,
Version: version,
CompileOutput: compileOutput,
Modules: stdlib.GetModules(stdlib.AllModuleNames()...),
Modules: stdlib.GetModuleMap(stdlib.AllModuleNames()...),
InputFile: flag.Arg(0),
})
}

View file

@ -2,13 +2,18 @@ package compiler
import (
"encoding/gob"
"fmt"
"io"
"github.com/d5/tengo/objects"
)
// Decode reads Bytecode data from the reader.
func (b *Bytecode) Decode(r io.Reader) error {
func (b *Bytecode) Decode(r io.Reader, modules *objects.ModuleMap) error {
if modules == nil {
modules = objects.NewModuleMap()
}
dec := gob.NewDecoder(r)
if err := dec.Decode(&b.FileSet); err != nil {
@ -25,38 +30,68 @@ func (b *Bytecode) Decode(r io.Reader) error {
return err
}
for i, v := range b.Constants {
b.Constants[i] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return err
}
b.Constants[i] = fv
}
return nil
}
func fixDecoded(o objects.Object) objects.Object {
func fixDecoded(o objects.Object, modules *objects.ModuleMap) (objects.Object, error) {
switch o := o.(type) {
case *objects.Bool:
if o.IsFalsy() {
return objects.FalseValue
return objects.FalseValue, nil
}
return objects.TrueValue
return objects.TrueValue, nil
case *objects.Undefined:
return objects.UndefinedValue
return objects.UndefinedValue, nil
case *objects.Array:
for i, v := range o.Value {
o.Value[i] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[i] = fv
}
case *objects.ImmutableArray:
for i, v := range o.Value {
o.Value[i] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[i] = fv
}
case *objects.Map:
for k, v := range o.Value {
o.Value[k] = fixDecoded(v)
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[k] = fv
}
case *objects.ImmutableMap:
modName := moduleName(o)
if mod := modules.GetBuiltinModule(modName); mod != nil {
return mod.AsImmutableMap(modName), nil
}
for k, v := range o.Value {
o.Value[k] = fixDecoded(v)
// encoding of user function not supported
if _, isUserFunction := v.(*objects.UserFunction); isUserFunction {
return nil, fmt.Errorf("user function not decodable")
}
fv, err := fixDecoded(v, modules)
if err != nil {
return nil, err
}
o.Value[k] = fv
}
}
return o
return o, nil
}

View file

@ -280,7 +280,7 @@ func testBytecodeSerialization(t *testing.T, b *compiler.Bytecode) {
assert.NoError(t, err)
r := &compiler.Bytecode{}
err = r.Decode(bytes.NewReader(buf.Bytes()))
err = r.Decode(bytes.NewReader(buf.Bytes()), nil)
assert.NoError(t, err)
assert.Equal(t, b.FileSet, r.FileSet)

View file

@ -24,7 +24,7 @@ type Compiler struct {
symbolTable *SymbolTable
scopes []CompilationScope
scopeIndex int
importModules map[string]objects.Importable
modules *objects.ModuleMap
compiledModules map[string]*objects.CompiledFunction
allowFileImport bool
loops []*Loop
@ -34,7 +34,7 @@ type Compiler struct {
}
// NewCompiler creates a Compiler.
func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, importModules map[string]objects.Importable, trace io.Writer) *Compiler {
func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []objects.Object, modules *objects.ModuleMap, trace io.Writer) *Compiler {
mainScope := CompilationScope{
symbolInit: make(map[string]bool),
sourceMap: make(map[int]source.Pos),
@ -51,8 +51,8 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object
}
// builtin modules
if importModules == nil {
importModules = make(map[string]objects.Importable)
if modules == nil {
modules = objects.NewModuleMap()
}
return &Compiler{
@ -63,7 +63,7 @@ func NewCompiler(file *source.File, symbolTable *SymbolTable, constants []object
scopeIndex: 0,
loopIndex: -1,
trace: trace,
importModules: importModules,
modules: modules,
compiledModules: make(map[string]*objects.CompiledFunction),
}
}
@ -513,7 +513,7 @@ func (c *Compiler) Compile(node ast.Node) error {
return c.errorf(node, "empty module name")
}
if mod, ok := c.importModules[node.ModuleName]; ok {
if mod := c.modules.Get(node.ModuleName); mod != nil {
v, err := mod.Import(node.ModuleName)
if err != nil {
return err
@ -644,7 +644,7 @@ func (c *Compiler) EnableFileImport(enable bool) {
}
func (c *Compiler) fork(file *source.File, modulePath string, symbolTable *SymbolTable) *Compiler {
child := NewCompiler(file, symbolTable, nil, c.importModules, c.trace)
child := NewCompiler(file, symbolTable, nil, c.modules, c.trace)
child.modulePath = modulePath // module file path
child.parent = c // parent to set to current compiler

View file

@ -119,17 +119,13 @@ Users can add and use a custom user type in Tengo code by implementing [Object](
To securely compile and execute _potentially_ unsafe script code, you can use the following Script functions.
#### Script.SetImports(modules map[string]objects.Importable)
#### Script.SetImports(modules *objects.ModuleMap)
SetImports sets the import modules with corresponding names. Script **does not** include any modules by default. You can use this function to include the [Standard Library](https://github.com/d5/tengo/blob/master/docs/stdlib.md).
```golang
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
s.SetImports(map[string]objects.Importable{
"math": stdlib.BuiltinModules["math"],
})
// or
s.SetImports(stdlib.GetModules("math"))
// or, to include all stdlib at once
s.SetImports(stdlib.GetModules(stdlib.AllModuleNames()...))
@ -140,9 +136,9 @@ You can also include Tengo's written module using `objects.SourceModule` (which
```golang
s := script.New([]byte(`double := import("double"); a := double(20)`))
s.SetImports(map[string]objects.Importable{
"double": &objects.SourceModule{Src: []byte(`export func(x) { return x * 2 }`)},
})
mods := objects.NewModuleMap()
mods.AddSourceModule("double", []byte(`export func(x) { return x * 2 }`))
s.SetImports(mods)
```

View file

@ -6,11 +6,18 @@ type BuiltinModule struct {
}
// Import returns an immutable map for the module.
func (m *BuiltinModule) Import(name string) (interface{}, error) {
func (m *BuiltinModule) Import(moduleName string) (interface{}, error) {
return m.AsImmutableMap(moduleName), nil
}
// AsImmutableMap converts builtin module into an immutable map.
func (m *BuiltinModule) AsImmutableMap(moduleName string) *ImmutableMap {
attrs := make(map[string]Object, len(m.Attrs))
for k, v := range m.Attrs {
attrs[k] = v.Copy()
}
attrs["__module_name__"] = &String{Value: name}
return &ImmutableMap{Value: attrs}, nil
attrs["__module_name__"] = &String{Value: moduleName}
return &ImmutableMap{Value: attrs}
}

View file

@ -3,5 +3,5 @@ package objects
// Importable interface represents importable module instance.
type Importable interface {
// Import should return either an Object or module source code ([]byte).
Import(name string) (interface{}, error)
Import(moduleName string) (interface{}, error)
}

65
objects/module_map.go Normal file
View file

@ -0,0 +1,65 @@
package objects
// ModuleMap represents a set of named modules.
// Use NewModuleMap to create a new module map.
type ModuleMap struct {
m map[string]Importable
}
// NewModuleMap creates a new module map.
func NewModuleMap() *ModuleMap {
return &ModuleMap{
m: make(map[string]Importable),
}
}
// Add adds an import module.
func (m *ModuleMap) Add(name string, module Importable) {
m.m[name] = module
}
// AddBuiltinModule adds a builtin module.
func (m *ModuleMap) AddBuiltinModule(name string, attrs map[string]Object) {
m.m[name] = &BuiltinModule{Attrs: attrs}
}
// AddSourceModule adds a source module.
func (m *ModuleMap) AddSourceModule(name string, src []byte) {
m.m[name] = &SourceModule{Src: src}
}
// Get returns an import module identified by name.
// It returns if the name is not found.
func (m *ModuleMap) Get(name string) Importable {
return m.m[name]
}
// GetBuiltinModule returns a builtin module identified by name.
// It returns if the name is not found or the module is not a builtin module.
func (m *ModuleMap) GetBuiltinModule(name string) *BuiltinModule {
mod, _ := m.m[name].(*BuiltinModule)
return mod
}
// GetSourceModule returns a source module identified by name.
// It returns if the name is not found or the module is not a source module.
func (m *ModuleMap) GetSourceModule(name string) *SourceModule {
mod, _ := m.m[name].(*SourceModule)
return mod
}
// Copy creates a copy of the module map.
func (m *ModuleMap) Copy() *ModuleMap {
c := &ModuleMap{
m: make(map[string]Importable),
}
for name, mod := range m.m {
c.m[name] = mod
}
return c
}
// Len returns the number of named modules.
func (m *ModuleMap) Len() int {
return len(m.m)
}

View file

@ -8,6 +8,7 @@ import (
type UserFunction struct {
Name string
Value CallableFunc
EncodingID string
}
// TypeName returns the name of the type.

View file

@ -24,7 +24,7 @@ type MAP = map[string]interface{}
type ARR = []interface{}
type testopts struct {
modules map[string]objects.Importable
modules *objects.ModuleMap
symbols map[string]objects.Object
maxAllocs int64
skip2ndPass bool
@ -32,7 +32,7 @@ type testopts struct {
func Opts() *testopts {
return &testopts{
modules: make(map[string]objects.Importable),
modules: objects.NewModuleMap(),
symbols: make(map[string]objects.Object),
maxAllocs: -1,
skip2ndPass: false,
@ -41,14 +41,11 @@ func Opts() *testopts {
func (o *testopts) copy() *testopts {
c := &testopts{
modules: make(map[string]objects.Importable),
modules: o.modules.Copy(),
symbols: make(map[string]objects.Object),
maxAllocs: o.maxAllocs,
skip2ndPass: o.skip2ndPass,
}
for k, v := range o.modules {
c.modules[k] = v
}
for k, v := range o.symbols {
c.symbols[k] = v
}
@ -59,11 +56,11 @@ func (o *testopts) Module(name string, mod interface{}) *testopts {
c := o.copy()
switch mod := mod.(type) {
case objects.Importable:
c.modules[name] = mod
c.modules.Add(name, mod)
case string:
c.modules[name] = &objects.SourceModule{Src: []byte(mod)}
c.modules.AddSourceModule(name, []byte(mod))
case []byte:
c.modules[name] = &objects.SourceModule{Src: mod}
c.modules.AddSourceModule(name, mod)
default:
panic(fmt.Errorf("invalid module type: %T", mod))
}
@ -135,9 +132,7 @@ func expect(t *testing.T, input string, opts *testopts, expected interface{}) {
expectedObj = &objects.ImmutableMap{Value: eo.Value}
}
modules["__code__"] = &objects.SourceModule{
Src: []byte(fmt.Sprintf("out := undefined; %s; export out", input)),
}
modules.AddSourceModule("__code__", []byte(fmt.Sprintf("out := undefined; %s; export out", input)))
res, trace, err := traceCompileRun(file, symbols, modules, maxAllocs)
if !assert.NoError(t, err) ||
@ -184,7 +179,7 @@ func (o *tracer) Write(p []byte) (n int, err error) {
return len(p), nil
}
func traceCompileRun(file *ast.File, symbols map[string]objects.Object, modules map[string]objects.Importable, maxAllocs int64) (res map[string]objects.Object, trace []string, err error) {
func traceCompileRun(file *ast.File, symbols map[string]objects.Object, modules *objects.ModuleMap, maxAllocs int64) (res map[string]objects.Object, trace []string, err error) {
var v *runtime.VM
defer func() {

View file

@ -14,7 +14,7 @@ import (
// Script can simplify compilation and execution of embedded scripts.
type Script struct {
variables map[string]*Variable
importModules map[string]objects.Importable
modules *objects.ModuleMap
input []byte
maxAllocs int64
maxConstObjects int
@ -59,8 +59,8 @@ func (s *Script) Remove(name string) bool {
}
// SetImports sets import modules.
func (s *Script) SetImports(modules map[string]objects.Importable) {
s.importModules = modules
func (s *Script) SetImports(modules *objects.ModuleMap) {
s.modules = modules
}
// SetMaxAllocs sets the maximum number of objects allocations during the run time.
@ -96,7 +96,7 @@ func (s *Script) Compile() (*Compiled, error) {
return nil, err
}
c := compiler.NewCompiler(srcFile, symbolTable, nil, s.importModules, nil)
c := compiler.NewCompiler(srcFile, symbolTable, nil, s.modules, nil)
c.EnableFileImport(s.enableFileImport)
if err := c.Compile(file); err != nil {
return nil, err

View file

@ -45,8 +45,7 @@ for i:=1; i<=d; i++ {
e := mod1.double(s)
`)
mod1 := &objects.BuiltinModule{
Attrs: map[string]objects.Object{
mod1 := map[string]objects.Object{
"double": &objects.UserFunction{
Value: func(args ...objects.Object) (ret objects.Object, err error) {
arg0, _ := objects.ToInt64(args[0])
@ -54,16 +53,15 @@ e := mod1.double(s)
return
},
},
},
}
scr := script.New(code)
_ = scr.Add("a", 0)
_ = scr.Add("b", 0)
_ = scr.Add("c", 0)
scr.SetImports(map[string]objects.Importable{
"mod1": mod1,
})
mods := objects.NewModuleMap()
mods.AddBuiltinModule("mod1", mod1)
scr.SetImports(mods)
compiled, err := scr.Compile()
assert.NoError(t, err)

View file

@ -12,38 +12,44 @@ import (
func TestScriptSourceModule(t *testing.T) {
// script1 imports "mod1"
scr := script.New([]byte(`out := import("mod")`))
scr.SetImports(map[string]objects.Importable{
"mod": &objects.SourceModule{Src: []byte(`export 5`)},
})
mods := objects.NewModuleMap()
mods.AddSourceModule("mod", []byte(`export 5`))
scr.SetImports(mods)
c, err := scr.Run()
if !assert.NoError(t, err) {
return
}
assert.Equal(t, int64(5), c.Get("out").Value())
// executing module function
scr = script.New([]byte(`fn := import("mod"); out := fn()`))
scr.SetImports(map[string]objects.Importable{
"mod": &objects.SourceModule{Src: []byte(`a := 3; export func() { return a + 5 }`)},
})
mods = objects.NewModuleMap()
mods.AddSourceModule("mod", []byte(`a := 3; export func() { return a + 5 }`))
scr.SetImports(mods)
c, err = scr.Run()
assert.NoError(t, err)
if !assert.NoError(t, err) {
return
}
assert.Equal(t, int64(8), c.Get("out").Value())
scr = script.New([]byte(`out := import("mod")`))
scr.SetImports(map[string]objects.Importable{
"text": &objects.BuiltinModule{
Attrs: map[string]objects.Object{
mods = objects.NewModuleMap()
mods.AddSourceModule("mod", []byte(`text := import("text"); export text.title("foo")`))
mods.AddBuiltinModule("text", map[string]objects.Object{
"title": &objects.UserFunction{Name: "title", Value: func(args ...objects.Object) (ret objects.Object, err error) {
s, _ := objects.ToString(args[0])
return &objects.String{Value: strings.Title(s)}, nil
}},
},
},
"mod": &objects.SourceModule{Src: []byte(`text := import("text"); export text.title("foo")`)},
})
scr.SetImports(mods)
c, err = scr.Run()
assert.NoError(t, err)
if !assert.NoError(t, err) {
return
}
assert.Equal(t, "Foo", c.Get("out").Value())
scr.SetImports(nil)
_, err = scr.Run()
assert.Error(t, err)
if !assert.Error(t, err) {
return
}
}

View file

@ -53,7 +53,7 @@ func TestScript_Run(t *testing.T) {
func TestScript_BuiltinModules(t *testing.T) {
s := script.New([]byte(`math := import("math"); a := math.abs(-19.84)`))
s.SetImports(map[string]objects.Importable{"math": stdlib.BuiltinModules["math"]})
s.SetImports(stdlib.GetModuleMap("math"))
c, err := s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)
@ -64,7 +64,7 @@ func TestScript_BuiltinModules(t *testing.T) {
assert.NotNil(t, c)
compiledGet(t, c, "a", 19.84)
s.SetImports(map[string]objects.Importable{"os": &objects.BuiltinModule{Attrs: map[string]objects.Object{}}})
s.SetImports(stdlib.GetModuleMap("os"))
_, err = s.Run()
assert.Error(t, err)
@ -80,7 +80,7 @@ a := enum.all([1,2,3], func(_, v) {
return v > 0
})
`))
s.SetImports(stdlib.GetModules("enum"))
s.SetImports(stdlib.GetModuleMap("enum"))
c, err := s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)

View file

@ -3,12 +3,12 @@ package stdlib
import "github.com/d5/tengo/objects"
// BuiltinModules are builtin type standard library modules.
var BuiltinModules = map[string]*objects.BuiltinModule{
"math": {Attrs: mathModule},
"os": {Attrs: osModule},
"text": {Attrs: textModule},
"times": {Attrs: timesModule},
"rand": {Attrs: randModule},
"fmt": {Attrs: fmtModule},
"json": {Attrs: jsonModule},
var BuiltinModules = map[string]map[string]objects.Object{
"math": mathModule,
"os": osModule,
"text": textModule,
"times": timesModule,
"rand": randModule,
"fmt": fmtModule,
"json": jsonModule,
}

View file

@ -38,12 +38,10 @@ func main() {
package stdlib
import "github.com/d5/tengo/objects"
// SourceModules are source type standard library modules.
var SourceModules = map[string]*objects.SourceModule{` + "\n")
var SourceModules = map[string]string{` + "\n")
for modName, modSrc := range modules {
out.WriteString("\t\"" + modName + "\": {Src: []byte(`" + modSrc + "`)},\n")
out.WriteString("\t\"" + modName + "\": `" + modSrc + "`,\n")
}
out.WriteString("}\n")

View file

@ -2,11 +2,9 @@
package stdlib
import "github.com/d5/tengo/objects"
// SourceModules are source type standard library modules.
var SourceModules = map[string]*objects.SourceModule{
"enum": {Src: []byte(`export {
var SourceModules = map[string]string{
"enum": `export {
// all returns true if the given function fn evaluates to a truthy value on
// all of the items in the enumerable.
all: func(enumerable, fn) {
@ -46,5 +44,5 @@ var SourceModules = map[string]*objects.SourceModule{
return res
}
}
`)},
`,
}

View file

@ -16,16 +16,17 @@ func AllModuleNames() []string {
return names
}
// GetModules returns the modules for the given names.
// Duplicate names and invalid names are ignore.
func GetModules(names ...string) map[string]objects.Importable {
modules := make(map[string]objects.Importable)
// GetModuleMap returns the module map that includes all modules
// for the given module names.
func GetModuleMap(names ...string) *objects.ModuleMap {
modules := objects.NewModuleMap()
for _, name := range names {
if mod := BuiltinModules[name]; mod != nil {
modules[name] = mod
modules.AddBuiltinModule(name, mod)
}
if mod := SourceModules[name]; mod != nil {
modules[name] = mod
if mod := SourceModules[name]; mod != "" {
modules.AddSourceModule(name, []byte(mod))
}
}

View file

@ -74,25 +74,25 @@ if !is_error(cmd) {
}
func TestGetModules(t *testing.T) {
mods := stdlib.GetModules()
assert.Equal(t, 0, len(mods))
mods := stdlib.GetModuleMap()
assert.Equal(t, 0, mods.Len())
mods = stdlib.GetModules("os")
assert.Equal(t, 1, len(mods))
assert.NotNil(t, mods["os"])
mods = stdlib.GetModuleMap("os")
assert.Equal(t, 1, mods.Len())
assert.NotNil(t, mods.Get("os"))
mods = stdlib.GetModules("os", "rand")
assert.Equal(t, 2, len(mods))
assert.NotNil(t, mods["os"])
assert.NotNil(t, mods["rand"])
mods = stdlib.GetModuleMap("os", "rand")
assert.Equal(t, 2, mods.Len())
assert.NotNil(t, mods.Get("os"))
assert.NotNil(t, mods.Get("rand"))
mods = stdlib.GetModules("text", "text")
assert.Equal(t, 1, len(mods))
assert.NotNil(t, mods["text"])
mods = stdlib.GetModuleMap("text", "text")
assert.Equal(t, 1, mods.Len())
assert.NotNil(t, mods.Get("text"))
mods = stdlib.GetModules("nonexisting", "text")
assert.Equal(t, 1, len(mods))
assert.NotNil(t, mods["text"])
mods = stdlib.GetModuleMap("nonexisting", "text")
assert.Equal(t, 1, mods.Len())
assert.NotNil(t, mods.Get("text"))
}
type callres struct {
@ -156,8 +156,8 @@ func (c callres) expectError() bool {
}
func module(t *testing.T, moduleName string) callres {
mod, ok := stdlib.BuiltinModules[moduleName]
if !ok {
mod := stdlib.GetModuleMap(moduleName).GetBuiltinModule(moduleName)
if mod == nil {
return callres{t: t, e: fmt.Errorf("module not found: %s", moduleName)}
}
@ -231,7 +231,7 @@ func object(v interface{}) objects.Object {
func expect(t *testing.T, input string, expected interface{}) {
s := script.New([]byte(input))
s.SetImports(stdlib.GetModules(stdlib.AllModuleNames()...))
s.SetImports(stdlib.GetModuleMap(stdlib.AllModuleNames()...))
c, err := s.Run()
assert.NoError(t, err)
assert.NotNil(t, c)