xgo/compiler/compiler.go
2019-02-10 00:32:37 -08:00

707 lines
16 KiB
Go

package compiler
import (
"fmt"
"io"
"reflect"
"github.com/d5/tengo/compiler/ast"
"github.com/d5/tengo/compiler/stdlib"
"github.com/d5/tengo/compiler/token"
"github.com/d5/tengo/objects"
)
// Compiler compiles the AST into a bytecode.
type Compiler struct {
parent *Compiler
moduleName string
constants []objects.Object
symbolTable *SymbolTable
scopes []CompilationScope
scopeIndex int
moduleLoader ModuleLoader
stdModules map[string]*objects.ImmutableMap
compiledModules map[string]*objects.CompiledFunction
loops []*Loop
loopIndex int
trace io.Writer
indent int
}
// NewCompiler creates a Compiler.
// User can optionally provide the symbol table if one wants to add or remove
// some global- or builtin- scope symbols. If not (nil), Compile will create
// a new symbol table and use the default builtin functions. Likewise, standard
// modules can be explicitly provided if user wants to add or remove some modules.
// By default, Compile will use all the standard modules otherwise.
func NewCompiler(symbolTable *SymbolTable, constants []objects.Object, stdModules map[string]*objects.ImmutableMap, trace io.Writer) *Compiler {
mainScope := CompilationScope{
instructions: make([]byte, 0),
}
// symbol table
if symbolTable == nil {
symbolTable = NewSymbolTable()
for idx, fn := range objects.Builtins {
symbolTable.DefineBuiltin(idx, fn.Name)
}
}
// standard modules
if stdModules == nil {
stdModules = stdlib.Modules
}
return &Compiler{
symbolTable: symbolTable,
constants: constants,
scopes: []CompilationScope{mainScope},
scopeIndex: 0,
loopIndex: -1,
trace: trace,
stdModules: stdModules,
compiledModules: make(map[string]*objects.CompiledFunction),
}
}
// Compile compiles the AST node.
func (c *Compiler) Compile(node ast.Node) error {
if c.trace != nil {
if node != nil {
defer un(trace(c, fmt.Sprintf("%s (%s)", node.String(), reflect.TypeOf(node).Elem().Name())))
} else {
defer un(trace(c, "<nil>"))
}
}
switch node := node.(type) {
case *ast.File:
for _, stmt := range node.Stmts {
if err := c.Compile(stmt); err != nil {
return err
}
}
case *ast.ExprStmt:
if err := c.Compile(node.Expr); err != nil {
return err
}
c.emit(OpPop)
case *ast.IncDecStmt:
op := token.AddAssign
if node.Token == token.Dec {
op = token.SubAssign
}
return c.compileAssign([]ast.Expr{node.Expr}, []ast.Expr{&ast.IntLit{Value: 1}}, op)
case *ast.ParenExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
case *ast.BinaryExpr:
if node.Token == token.LAnd || node.Token == token.LOr {
return c.compileLogical(node)
}
if node.Token == token.Less {
if err := c.Compile(node.RHS); err != nil {
return err
}
if err := c.Compile(node.LHS); err != nil {
return err
}
c.emit(OpGreaterThan)
return nil
} else if node.Token == token.LessEq {
if err := c.Compile(node.RHS); err != nil {
return err
}
if err := c.Compile(node.LHS); err != nil {
return err
}
c.emit(OpGreaterThanEqual)
return nil
}
if err := c.Compile(node.LHS); err != nil {
return err
}
if err := c.Compile(node.RHS); err != nil {
return err
}
switch node.Token {
case token.Add:
c.emit(OpAdd)
case token.Sub:
c.emit(OpSub)
case token.Mul:
c.emit(OpMul)
case token.Quo:
c.emit(OpDiv)
case token.Rem:
c.emit(OpRem)
case token.Greater:
c.emit(OpGreaterThan)
case token.GreaterEq:
c.emit(OpGreaterThanEqual)
case token.Equal:
c.emit(OpEqual)
case token.NotEqual:
c.emit(OpNotEqual)
case token.And:
c.emit(OpBAnd)
case token.Or:
c.emit(OpBOr)
case token.Xor:
c.emit(OpBXor)
case token.AndNot:
c.emit(OpBAndNot)
case token.Shl:
c.emit(OpBShiftLeft)
case token.Shr:
c.emit(OpBShiftRight)
default:
return fmt.Errorf("unknown operator: %s", node.Token.String())
}
case *ast.IntLit:
c.emit(OpConstant, c.addConstant(&objects.Int{Value: node.Value}))
case *ast.FloatLit:
c.emit(OpConstant, c.addConstant(&objects.Float{Value: node.Value}))
case *ast.BoolLit:
if node.Value {
c.emit(OpTrue)
} else {
c.emit(OpFalse)
}
case *ast.StringLit:
c.emit(OpConstant, c.addConstant(&objects.String{Value: node.Value}))
case *ast.CharLit:
c.emit(OpConstant, c.addConstant(&objects.Char{Value: node.Value}))
case *ast.UndefinedLit:
c.emit(OpNull)
case *ast.UnaryExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
switch node.Token {
case token.Not:
c.emit(OpLNot)
case token.Sub:
c.emit(OpMinus)
case token.Xor:
c.emit(OpBComplement)
case token.Add:
// do nothing?
default:
return fmt.Errorf("unknown operator: %s", node.Token.String())
}
case *ast.IfStmt:
// open new symbol table for the statement
c.symbolTable = c.symbolTable.Fork(true)
defer func() {
c.symbolTable = c.symbolTable.Parent(false)
}()
if node.Init != nil {
if err := c.Compile(node.Init); err != nil {
return err
}
}
if err := c.Compile(node.Cond); err != nil {
return err
}
// first jump placeholder
jumpPos1 := c.emit(OpJumpFalsy, 0)
if err := c.Compile(node.Body); err != nil {
return err
}
if node.Else != nil {
// second jump placeholder
jumpPos2 := c.emit(OpJump, 0)
// update first jump offset
curPos := len(c.currentInstructions())
c.changeOperand(jumpPos1, curPos)
if err := c.Compile(node.Else); err != nil {
return err
}
// update second jump offset
curPos = len(c.currentInstructions())
c.changeOperand(jumpPos2, curPos)
} else {
// update first jump offset
curPos := len(c.currentInstructions())
c.changeOperand(jumpPos1, curPos)
}
case *ast.ForStmt:
return c.compileForStmt(node)
case *ast.ForInStmt:
return c.compileForInStmt(node)
case *ast.BranchStmt:
if node.Token == token.Break {
curLoop := c.currentLoop()
if curLoop == nil {
return fmt.Errorf("break statement outside loop")
}
pos := c.emit(OpJump, 0)
curLoop.Breaks = append(curLoop.Breaks, pos)
} else if node.Token == token.Continue {
curLoop := c.currentLoop()
if curLoop == nil {
return fmt.Errorf("continue statement outside loop")
}
pos := c.emit(OpJump, 0)
curLoop.Continues = append(curLoop.Continues, pos)
} else {
return fmt.Errorf("unknown branch statement: %s", node.Token.String())
}
case *ast.BlockStmt:
for _, stmt := range node.Stmts {
if err := c.Compile(stmt); err != nil {
return err
}
}
case *ast.AssignStmt:
if err := c.compileAssign(node.LHS, node.RHS, node.Token); err != nil {
return err
}
case *ast.Ident:
symbol, _, ok := c.symbolTable.Resolve(node.Name)
if !ok {
return fmt.Errorf("undefined variable: %s", node.Name)
}
switch symbol.Scope {
case ScopeGlobal:
c.emit(OpGetGlobal, symbol.Index)
case ScopeLocal:
c.emit(OpGetLocal, symbol.Index)
case ScopeBuiltin:
c.emit(OpGetBuiltin, symbol.Index)
case ScopeFree:
c.emit(OpGetFree, symbol.Index)
}
case *ast.ArrayLit:
for _, elem := range node.Elements {
if err := c.Compile(elem); err != nil {
return err
}
}
c.emit(OpArray, len(node.Elements))
case *ast.MapLit:
for _, elt := range node.Elements {
// key
c.emit(OpConstant, c.addConstant(&objects.String{Value: elt.Key}))
// value
if err := c.Compile(elt.Value); err != nil {
return err
}
}
c.emit(OpMap, len(node.Elements)*2)
case *ast.SelectorExpr: // selector on RHS side
if err := c.Compile(node.Expr); err != nil {
return err
}
if err := c.Compile(node.Sel); err != nil {
return err
}
c.emit(OpIndex)
case *ast.IndexExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
if err := c.Compile(node.Index); err != nil {
return err
}
c.emit(OpIndex)
case *ast.SliceExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
if node.Low != nil {
if err := c.Compile(node.Low); err != nil {
return err
}
} else {
c.emit(OpNull)
}
if node.High != nil {
if err := c.Compile(node.High); err != nil {
return err
}
} else {
c.emit(OpNull)
}
c.emit(OpSliceIndex)
case *ast.FuncLit:
c.enterScope()
for _, p := range node.Type.Params.List {
s := c.symbolTable.Define(p.Name)
// function arguments is not assigned directly.
s.LocalAssigned = true
}
if err := c.Compile(node.Body); err != nil {
return err
}
// add OpReturn if function returns nothing
if !c.lastInstructionIs(OpReturnValue) && !c.lastInstructionIs(OpReturn) {
c.emit(OpReturn)
}
freeSymbols := c.symbolTable.FreeSymbols()
numLocals := c.symbolTable.MaxSymbols()
instructions := c.leaveScope()
for _, s := range freeSymbols {
switch s.Scope {
case ScopeLocal:
if !s.LocalAssigned {
// Here, the closure is capturing a local variable that's not yet assigned its value.
// One example is a local recursive function:
//
// func() {
// foo := func(x) {
// // ..
// return foo(x-1)
// }
// }
//
// which translate into
//
// 0000 GETL 0
// 0002 CLOSURE ? 1
// 0006 DEFL 0
//
// . So the local variable (0) is being captured before it's assigned the value.
//
// Solution is to transform the code into something like this:
//
// func() {
// foo := undefined
// foo = func(x) {
// // ..
// return foo(x-1)
// }
// }
//
// that is equivalent to
//
// 0000 NULL
// 0001 DEFL 0
// 0003 GETL 0
// 0005 CLOSURE ? 1
// 0009 SETL 0
//
c.emit(OpNull)
c.emit(OpDefineLocal, s.Index)
s.LocalAssigned = true
}
c.emit(OpGetLocal, s.Index)
case ScopeFree:
c.emit(OpGetFree, s.Index)
}
}
compiledFunction := &objects.CompiledFunction{
Instructions: instructions,
NumLocals: numLocals,
NumParameters: len(node.Type.Params.List),
}
if len(freeSymbols) > 0 {
c.emit(OpClosure, c.addConstant(compiledFunction), len(freeSymbols))
} else {
c.emit(OpConstant, c.addConstant(compiledFunction))
}
case *ast.ReturnStmt:
if c.symbolTable.Parent(true) == nil {
// outside the function
return fmt.Errorf("return statement outside function")
}
if node.Result == nil {
c.emit(OpReturn)
} else {
if err := c.Compile(node.Result); err != nil {
return err
}
c.emit(OpReturnValue)
}
case *ast.CallExpr:
if err := c.Compile(node.Func); err != nil {
return err
}
for _, arg := range node.Args {
if err := c.Compile(arg); err != nil {
return err
}
}
c.emit(OpCall, len(node.Args))
case *ast.ImportExpr:
stdMod, ok := c.stdModules[node.ModuleName]
if ok {
// standard modules contain only globals with no code.
// so no need to compile anything
c.emit(OpConstant, c.addConstant(stdMod))
} else {
userMod, err := c.compileModule(node.ModuleName)
if err != nil {
return err
}
c.emit(OpConstant, c.addConstant(userMod))
c.emit(OpCall, 0)
}
case *ast.ExportStmt:
// export statement must be in top-level scope
if c.scopeIndex != 0 {
return fmt.Errorf("cannot use 'export' inside function")
}
// export statement is simply ignore when compiling non-module code
if c.parent == nil {
break
}
if err := c.Compile(node.Result); err != nil {
return err
}
c.emit(OpImmutable)
c.emit(OpReturnValue)
case *ast.ErrorExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
c.emit(OpError)
case *ast.ImmutableExpr:
if err := c.Compile(node.Expr); err != nil {
return err
}
c.emit(OpImmutable)
case *ast.CondExpr:
if err := c.Compile(node.Cond); err != nil {
return err
}
// first jump placeholder
jumpPos1 := c.emit(OpJumpFalsy, 0)
if err := c.Compile(node.True); err != nil {
return err
}
// second jump placeholder
jumpPos2 := c.emit(OpJump, 0)
// update first jump offset
curPos := len(c.currentInstructions())
c.changeOperand(jumpPos1, curPos)
if err := c.Compile(node.False); err != nil {
return err
}
// update second jump offset
curPos = len(c.currentInstructions())
c.changeOperand(jumpPos2, curPos)
}
return nil
}
// Bytecode returns a compiled bytecode.
func (c *Compiler) Bytecode() *Bytecode {
return &Bytecode{
Instructions: c.currentInstructions(),
Constants: c.constants,
}
}
// SetModuleLoader sets or replaces the current module loader.
// Note that the module loader is used for user modules,
// not for the standard modules.
func (c *Compiler) SetModuleLoader(moduleLoader ModuleLoader) {
c.moduleLoader = moduleLoader
}
func (c *Compiler) fork(moduleName string, symbolTable *SymbolTable) *Compiler {
child := NewCompiler(symbolTable, nil, c.stdModules, c.trace)
child.moduleName = moduleName // name of the module to compile
child.parent = c // parent to set to current compiler
child.moduleLoader = c.moduleLoader // share module loader
return child
}
func (c *Compiler) addConstant(o objects.Object) int {
if c.parent != nil {
// module compilers will use their parent's constants array
return c.parent.addConstant(o)
}
c.constants = append(c.constants, o)
if c.trace != nil {
c.printTrace(fmt.Sprintf("CONST %04d %s", len(c.constants)-1, o))
}
return len(c.constants) - 1
}
func (c *Compiler) addInstruction(b []byte) int {
posNewIns := len(c.currentInstructions())
c.scopes[c.scopeIndex].instructions = append(c.currentInstructions(), b...)
return posNewIns
}
func (c *Compiler) setLastInstruction(op Opcode, pos int) {
c.scopes[c.scopeIndex].lastInstructions[1] = c.scopes[c.scopeIndex].lastInstructions[0]
c.scopes[c.scopeIndex].lastInstructions[0].Opcode = op
c.scopes[c.scopeIndex].lastInstructions[0].Position = pos
}
func (c *Compiler) lastInstructionIs(op Opcode) bool {
if len(c.currentInstructions()) == 0 {
return false
}
return c.scopes[c.scopeIndex].lastInstructions[0].Opcode == op
}
func (c *Compiler) removeLastInstruction() {
lastPos := c.scopes[c.scopeIndex].lastInstructions[0].Position
if c.trace != nil {
c.printTrace(fmt.Sprintf("DELET %s",
FormatInstructions(c.scopes[c.scopeIndex].instructions[lastPos:], lastPos)[0]))
}
c.scopes[c.scopeIndex].instructions = c.currentInstructions()[:lastPos]
c.scopes[c.scopeIndex].lastInstructions[0] = c.scopes[c.scopeIndex].lastInstructions[1]
}
func (c *Compiler) replaceInstruction(pos int, inst []byte) {
copy(c.currentInstructions()[pos:], inst)
if c.trace != nil {
c.printTrace(fmt.Sprintf("REPLC %s",
FormatInstructions(c.scopes[c.scopeIndex].instructions[pos:], pos)[0]))
}
}
func (c *Compiler) changeOperand(opPos int, operand ...int) {
op := Opcode(c.currentInstructions()[opPos])
inst := MakeInstruction(op, operand...)
c.replaceInstruction(opPos, inst)
}
func (c *Compiler) emit(opcode Opcode, operands ...int) int {
inst := MakeInstruction(opcode, operands...)
pos := c.addInstruction(inst)
c.setLastInstruction(opcode, pos)
if c.trace != nil {
c.printTrace(fmt.Sprintf("EMIT %s",
FormatInstructions(c.scopes[c.scopeIndex].instructions[pos:], pos)[0]))
}
return pos
}
func (c *Compiler) printTrace(a ...interface{}) {
const (
dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
n = len(dots)
)
i := 2 * c.indent
for i > n {
_, _ = fmt.Fprint(c.trace, dots)
i -= n
}
_, _ = fmt.Fprint(c.trace, dots[0:i])
_, _ = fmt.Fprintln(c.trace, a...)
}
func trace(c *Compiler, msg string) *Compiler {
c.printTrace(msg, "{")
c.indent++
return c
}
func un(c *Compiler) {
c.indent--
c.printTrace("}")
}