Compile can take custom stdlibs

This commit is contained in:
Daniel Kang 2019-01-18 09:19:45 -08:00
parent bff59006d1
commit a8d838ad3e
7 changed files with 46 additions and 14 deletions

View file

@ -193,7 +193,7 @@ func compileFile(file *ast.File) (time.Duration, *compiler.Bytecode, error) {
start := time.Now() start := time.Now()
c := compiler.NewCompiler(symTable, nil) c := compiler.NewCompiler(symTable, nil, nil)
if err := c.Compile(file); err != nil { if err := c.Compile(file); err != nil {
return time.Since(start), nil, err return time.Since(start), nil, err
} }

View file

@ -171,7 +171,11 @@ func doRepl(in io.Reader, out io.Writer) {
fileSet := source.NewFileSet() fileSet := source.NewFileSet()
globals := make([]*objects.Object, runtime.GlobalsSize) globals := make([]*objects.Object, runtime.GlobalsSize)
symbolTable := compiler.NewSymbolTable() symbolTable := compiler.NewSymbolTable()
for idx, fn := range objects.Builtins {
symbolTable.DefineBuiltin(idx, fn.Name)
}
for { for {
_, _ = fmt.Fprintf(out, replPrompt) _, _ = fmt.Fprintf(out, replPrompt)
@ -191,7 +195,7 @@ func doRepl(in io.Reader, out io.Writer) {
file = addPrints(file) file = addPrints(file)
c := compiler.NewCompiler(symbolTable, nil) c := compiler.NewCompiler(symbolTable, nil, nil)
if err := c.Compile(file); err != nil { if err := c.Compile(file); err != nil {
_, _ = fmt.Fprintf(out, "Compilation error:\n %s\n", err.Error()) _, _ = fmt.Fprintf(out, "Compilation error:\n %s\n", err.Error())
continue continue
@ -218,7 +222,7 @@ func compileSrc(src []byte, filename string) (*compiler.Bytecode, error) {
return nil, err return nil, err
} }
c := compiler.NewCompiler(nil, nil) c := compiler.NewCompiler(nil, nil, nil)
if err := c.Compile(file); err != nil { if err := c.Compile(file); err != nil {
return nil, err return nil, err
} }

View file

@ -20,6 +20,7 @@ type Compiler struct {
scopes []CompilationScope scopes []CompilationScope
scopeIndex int scopeIndex int
moduleLoader ModuleLoader moduleLoader ModuleLoader
stdModules map[string]*objects.ImmutableMap
compiledModules map[string]*objects.CompiledModule compiledModules map[string]*objects.CompiledModule
loops []*Loop loops []*Loop
loopIndex int loopIndex int
@ -28,17 +29,28 @@ type Compiler struct {
} }
// NewCompiler creates a Compiler. // NewCompiler creates a Compiler.
func NewCompiler(symbolTable *SymbolTable, trace io.Writer) *Compiler { // User can optionally provide the symbol table if one wants to add or remove
// some global- or builtin- scope symbols. If not (nil), Compile will create
// a new symbol table and use the default builtin functions. Likewise, standard
// modules can be explicitly provided if user wants to add or remove some modules.
// By default, Compile will use all the standard modules otherwise.
func NewCompiler(symbolTable *SymbolTable, stdModules map[string]*objects.ImmutableMap, trace io.Writer) *Compiler {
mainScope := CompilationScope{ mainScope := CompilationScope{
instructions: make([]byte, 0), instructions: make([]byte, 0),
} }
// symbol table
if symbolTable == nil { if symbolTable == nil {
symbolTable = NewSymbolTable() symbolTable = NewSymbolTable()
for idx, fn := range objects.Builtins {
symbolTable.DefineBuiltin(idx, fn.Name)
}
} }
for idx, fn := range objects.Builtins { // standard modules
symbolTable.DefineBuiltin(idx, fn.Name) if stdModules == nil {
stdModules = stdlib.Modules
} }
return &Compiler{ return &Compiler{
@ -47,6 +59,7 @@ func NewCompiler(symbolTable *SymbolTable, trace io.Writer) *Compiler {
scopeIndex: 0, scopeIndex: 0,
loopIndex: -1, loopIndex: -1,
trace: trace, trace: trace,
stdModules: stdModules,
compiledModules: make(map[string]*objects.CompiledModule), compiledModules: make(map[string]*objects.CompiledModule),
} }
} }
@ -440,7 +453,7 @@ func (c *Compiler) Compile(node ast.Node) error {
c.emit(OpCall, len(node.Args)) c.emit(OpCall, len(node.Args))
case *ast.ImportExpr: case *ast.ImportExpr:
stdMod, ok := stdlib.Modules[node.ModuleName] stdMod, ok := c.stdModules[node.ModuleName]
if ok { if ok {
// standard modules contain only globals with no code. // standard modules contain only globals with no code.
// so no need to compile anything // so no need to compile anything
@ -474,12 +487,14 @@ func (c *Compiler) Bytecode() *Bytecode {
} }
// SetModuleLoader sets or replaces the current module loader. // SetModuleLoader sets or replaces the current module loader.
// Note that the module loader is used for user modules,
// not for the standard modules.
func (c *Compiler) SetModuleLoader(moduleLoader ModuleLoader) { func (c *Compiler) SetModuleLoader(moduleLoader ModuleLoader) {
c.moduleLoader = moduleLoader c.moduleLoader = moduleLoader
} }
func (c *Compiler) fork(moduleName string, symbolTable *SymbolTable) *Compiler { func (c *Compiler) fork(moduleName string, symbolTable *SymbolTable) *Compiler {
child := NewCompiler(symbolTable, c.trace) child := NewCompiler(symbolTable, c.stdModules, c.trace)
child.moduleName = moduleName // name of the module to compile child.moduleName = moduleName // name of the module to compile
child.parent = c // parent to set to current compiler child.parent = c // parent to set to current compiler
child.moduleLoader = c.moduleLoader // share module loader child.moduleLoader = c.moduleLoader // share module loader

View file

@ -972,9 +972,12 @@ func traceCompile(input string, symbols map[string]objects.Object) (res *compile
for name := range symbols { for name := range symbols {
symTable.Define(name) symTable.Define(name)
} }
for idx, fn := range objects.Builtins {
symTable.DefineBuiltin(idx, fn.Name)
}
tr := &tracer{} tr := &tracer{}
c := compiler.NewCompiler(symTable, tr) c := compiler.NewCompiler(symTable, nil, tr)
parsed, err := p.ParseFile() parsed, err := p.ParseFile()
if err != nil { if err != nil {
return return

View file

@ -1,10 +1,13 @@
package objects package objects
// Builtins contains all known builtin functions. // NamedBuiltinFunc is a named builtin function.
var Builtins = []struct { type NamedBuiltinFunc struct {
Name string Name string
Func CallableFunc Func CallableFunc
}{ }
// Builtins contains all default builtin functions.
var Builtins = []NamedBuiltinFunc{
{ {
Name: "print", Name: "print",
Func: builtinPrint, Func: builtinPrint,

View file

@ -176,9 +176,12 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu
sym := symTable.Define(name) sym := symTable.Define(name)
globals[sym.Index] = &value globals[sym.Index] = &value
} }
for idx, fn := range objects.Builtins {
symTable.DefineBuiltin(idx, fn.Name)
}
tr := &tracer{} tr := &tracer{}
c := compiler.NewCompiler(symTable, tr) c := compiler.NewCompiler(symTable, nil, tr)
c.SetModuleLoader(func(moduleName string) ([]byte, error) { c.SetModuleLoader(func(moduleName string) ([]byte, error) {
if src, ok := userModules[moduleName]; ok { if src, ok := userModules[moduleName]; ok {
return []byte(src), nil return []byte(src), nil

View file

@ -63,7 +63,7 @@ func (s *Script) Compile() (*Compiled, error) {
return nil, fmt.Errorf("parse error: %s", err.Error()) return nil, fmt.Errorf("parse error: %s", err.Error())
} }
c := compiler.NewCompiler(symbolTable, nil) c := compiler.NewCompiler(symbolTable, nil, nil)
if err := c.Compile(file); err != nil { if err := c.Compile(file); err != nil {
return nil, err return nil, err
} }
@ -94,6 +94,10 @@ func (s *Script) prepCompile() (symbolTable *compiler.SymbolTable, globals []*ob
} }
symbolTable = compiler.NewSymbolTable() symbolTable = compiler.NewSymbolTable()
for idx, fn := range objects.Builtins {
symbolTable.DefineBuiltin(idx, fn.Name)
}
globals = make([]*objects.Object, len(names), len(names)) globals = make([]*objects.Object, len(names), len(names))
for idx, name := range names { for idx, name := range names {