add Indexable / IndexAssignable interface

This commit is contained in:
Daniel Kang 2019-01-20 23:32:58 -08:00
parent bf934d0086
commit 1045afd5a4
13 changed files with 428 additions and 230 deletions

View file

@ -135,6 +135,7 @@ func resolveAssignLHS(expr ast.Expr) (name string, selectors []ast.Expr, err err
selectors = append(selectors, term.Sel) selectors = append(selectors, term.Sel)
return return
case *ast.IndexExpr: case *ast.IndexExpr:
name, selectors, err = resolveAssignLHS(term.Expr) name, selectors, err = resolveAssignLHS(term.Expr)
if err != nil { if err != nil {

View file

@ -1,7 +1,6 @@
package objects package objects
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
@ -79,22 +78,40 @@ func (o *Array) Equals(x Object) bool {
return true return true
} }
// Get returns an element at a given index. // IndexGet returns an element at a given index.
func (o *Array) Get(index int) (Object, error) { func (o *Array) IndexGet(index Object) (res Object, err error) {
if index < 0 || index >= len(o.Value) { intIdx, ok := index.(*Int)
return nil, errors.New("array index out of bounds") if !ok {
err = ErrInvalidIndexType
return
} }
return o.Value[index], nil idxVal := int(intIdx.Value)
if idxVal < 0 || idxVal >= len(o.Value) {
err = ErrIndexOutOfBounds
return
}
res = o.Value[idxVal]
return
} }
// Set sets an element at a given index. // IndexSet sets an element at a given index.
func (o *Array) Set(index int, value Object) error { func (o *Array) IndexSet(index, value Object) (err error) {
if index < 0 || index >= len(o.Value) { intIdx, ok := ToInt(index)
return errors.New("array index out of bounds") if !ok {
err = ErrInvalidTypeConversion
return
} }
o.Value[index] = value if intIdx < 0 || intIdx >= len(o.Value) {
err = ErrIndexOutOfBounds
return
}
o.Value[intIdx] = value
return nil return nil
} }

View file

@ -54,3 +54,23 @@ func (o *Bytes) Equals(x Object) bool {
return bytes.Compare(o.Value, t.Value) == 0 return bytes.Compare(o.Value, t.Value) == 0
} }
// IndexGet returns an element (as Int) at a given index.
func (o *Bytes) IndexGet(index Object) (res Object, err error) {
intIdx, ok := index.(*Int)
if !ok {
err = ErrInvalidIndexType
return
}
idxVal := int(intIdx.Value)
if idxVal < 0 || idxVal >= len(o.Value) {
err = ErrIndexOutOfBounds
return
}
res = &Int{Value: int64(o.Value[idxVal])}
return
}

View file

@ -2,11 +2,17 @@ package objects
import "errors" import "errors"
// ErrNotCallable represents an error for calling on non-function objects. // ErrNotIndexable means the type is not indexable.
var ErrNotCallable = errors.New("not a callable object") var ErrNotIndexable = errors.New("type is not indexable")
// ErrNotIndexable represents an error for indexing on non-indexable objects. // ErrNotIndexAssignable means the type is not index-assignable.
var ErrNotIndexable = errors.New("non-indexable object") var ErrNotIndexAssignable = errors.New("type is not index-assignable")
// ErrIndexOutOfBounds is an error where a given index is out of the bounds.
var ErrIndexOutOfBounds = errors.New("index out of bounds")
// ErrInvalidIndexType means the type is not supported as an index.
var ErrInvalidIndexType = errors.New("invalid index type")
// ErrInvalidOperator represents an error for invalid operator usage. // ErrInvalidOperator represents an error for invalid operator usage.
var ErrInvalidOperator = errors.New("invalid operator") var ErrInvalidOperator = errors.New("invalid operator")

View file

@ -47,11 +47,20 @@ func (o *ImmutableMap) IsFalsy() bool {
return len(o.Value) == 0 return len(o.Value) == 0
} }
// Get returns the value for the given key. // IndexGet returns the value for the given key.
func (o *ImmutableMap) Get(key string) (Object, bool) { func (o *ImmutableMap) IndexGet(index Object) (res Object, err error) {
val, ok := o.Value[key] strIdx, ok := ToString(index)
if !ok {
err = ErrInvalidTypeConversion
return
}
return val, ok val, ok := o.Value[strIdx]
if !ok {
val = UndefinedValue
}
return val, nil
} }
// Equals returns true if the value of the type // Equals returns true if the value of the type

View file

@ -0,0 +1,9 @@
package objects
// IndexAssignable is an object that can take an index and a value
// on the left-hand side of the assignment statement.
type IndexAssignable interface {
// IndexSet should take an index Object and a value Object.
// If an error is returned, it will be treated as a run-time error.
IndexSet(index, value Object) error
}

9
objects/indexable.go Normal file
View file

@ -0,0 +1,9 @@
package objects
// Indexable is an object that can take an index and return an object.
type Indexable interface {
// IndexGet should take an index Object and return a result Object or an error.
// If error is returned, the runtime will treat it as a run-time error and ignore returned value.
// If nil is returned as value, it will be converted to Undefined value by the runtime.
IndexGet(index Object) (value Object, err error)
}

View file

@ -47,18 +47,6 @@ func (o *Map) IsFalsy() bool {
return len(o.Value) == 0 return len(o.Value) == 0
} }
// Get returns the value for the given key.
func (o *Map) Get(key string) (Object, bool) {
val, ok := o.Value[key]
return val, ok
}
// Set sets the value for the given key.
func (o *Map) Set(key string, value Object) {
o.Value[key] = value
}
// Equals returns true if the value of the type // Equals returns true if the value of the type
// is equal to the value of another object. // is equal to the value of another object.
func (o *Map) Equals(x Object) bool { func (o *Map) Equals(x Object) bool {
@ -80,3 +68,32 @@ func (o *Map) Equals(x Object) bool {
return true return true
} }
// IndexGet returns the value for the given key.
func (o *Map) IndexGet(index Object) (res Object, err error) {
strIdx, ok := index.(*String)
if !ok {
err = ErrInvalidIndexType
return
}
val, ok := o.Value[strIdx.Value]
if !ok {
val = UndefinedValue
}
return val, nil
}
// IndexSet sets the value for the given key.
func (o *Map) IndexSet(index, value Object) (err error) {
strIdx, ok := ToString(index)
if !ok {
err = ErrInvalidTypeConversion
return
}
o.Value[strIdx] = value
return nil
}

View file

@ -9,6 +9,7 @@ import (
// String represents a string value. // String represents a string value.
type String struct { type String struct {
Value string Value string
runeStr []rune
} }
// TypeName returns the name of the type. // TypeName returns the name of the type.
@ -56,3 +57,27 @@ func (o *String) Equals(x Object) bool {
return o.Value == t.Value return o.Value == t.Value
} }
// IndexGet returns a character at a given index.
func (o *String) IndexGet(index Object) (res Object, err error) {
intIdx, ok := index.(*Int)
if !ok {
err = ErrInvalidIndexType
return
}
idxVal := int(intIdx.Value)
if o.runeStr == nil {
o.runeStr = []rune(o.Value)
}
if idxVal < 0 || idxVal >= len(o.runeStr) {
err = ErrIndexOutOfBounds
return
}
res = &Char{Value: o.runeStr[idxVal]}
return
}

View file

@ -485,27 +485,12 @@ func (v *VM) Run() error {
numSelectors := int(v.curInsts[v.ip+3]) numSelectors := int(v.curInsts[v.ip+3])
v.ip += 3 v.ip += 3
// pop selector outcomes (left to right) // selectors and RHS value
selectors := make([]interface{}, numSelectors, numSelectors) selectors := v.stack[v.sp-numSelectors : v.sp]
for i := 0; i < numSelectors; i++ { val := v.stack[v.sp-numSelectors-1]
sel := v.stack[v.sp-1] v.sp -= numSelectors + 1
v.sp--
switch sel := (*sel).(type) { if err := indexAssign(v.globals[globalIndex], val, selectors); err != nil {
case *objects.String: // map key
selectors[i] = sel.Value
case *objects.Int: // array index
selectors[i] = int(sel.Value)
default:
return fmt.Errorf("invalid selector type: %s", sel.TypeName())
}
}
// RHS value
val := v.stack[v.sp-1]
v.sp--
if err := selectorAssign(v.globals[globalIndex], val, selectors); err != nil {
return err return err
} }
@ -583,101 +568,22 @@ func (v *VM) Run() error {
v.sp -= 2 v.sp -= 2
switch left := (*left).(type) { switch left := (*left).(type) {
case *objects.Array: case objects.Indexable:
idx, ok := (*index).(*objects.Int) val, err := left.IndexGet(*index)
if !ok { if err != nil {
return fmt.Errorf("non-integer array index: %s", left.TypeName()) return err
} }
if val == nil {
if idx.Value < 0 || idx.Value >= int64(len(left.Value)) { val = objects.UndefinedValue
return fmt.Errorf("index out of bounds: %d", index)
} }
if v.sp >= StackSize { if v.sp >= StackSize {
return ErrStackOverflow return ErrStackOverflow
} }
v.stack[v.sp] = &left.Value[idx.Value]
v.sp++
case *objects.String:
idx, ok := (*index).(*objects.Int)
if !ok {
return fmt.Errorf("non-integer array index: %s", left.TypeName())
}
str := []rune(left.Value)
if idx.Value < 0 || idx.Value >= int64(len(str)) {
return fmt.Errorf("index out of bounds: %d", index)
}
if v.sp >= StackSize {
return ErrStackOverflow
}
var val objects.Object = &objects.Char{Value: str[idx.Value]}
v.stack[v.sp] = &val v.stack[v.sp] = &val
v.sp++ v.sp++
case *objects.Bytes:
idx, ok := (*index).(*objects.Int)
if !ok {
return fmt.Errorf("non-integer array index: %s", left.TypeName())
}
if idx.Value < 0 || idx.Value >= int64(len(left.Value)) {
return fmt.Errorf("index out of bounds: %d", index)
}
if v.sp >= StackSize {
return ErrStackOverflow
}
var val objects.Object = &objects.Int{Value: int64(left.Value[idx.Value])}
v.stack[v.sp] = &val
v.sp++
case *objects.Map:
key, ok := (*index).(*objects.String)
if !ok {
return fmt.Errorf("non-string key: %s", left.TypeName())
}
var res = objects.UndefinedValue
val, ok := left.Value[key.Value]
if ok {
res = val
}
if v.sp >= StackSize {
return ErrStackOverflow
}
v.stack[v.sp] = &res
v.sp++
case *objects.ImmutableMap:
key, ok := (*index).(*objects.String)
if !ok {
return fmt.Errorf("non-string key: %s", left.TypeName())
}
var res = objects.UndefinedValue
val, ok := left.Value[key.Value]
if ok {
res = val
}
if v.sp >= StackSize {
return ErrStackOverflow
}
v.stack[v.sp] = &res
v.sp++
case *objects.Error: // err.value case *objects.Error: // err.value
key, ok := (*index).(*objects.String) key, ok := (*index).(*objects.String)
if !ok || key.Value != "value" { if !ok || key.Value != "value" {
@ -692,7 +598,7 @@ func (v *VM) Run() error {
v.sp++ v.sp++
default: default:
return fmt.Errorf("type %s does not support indexing", left.TypeName()) return objects.ErrNotIndexable
} }
case compiler.OpSliceIndex: case compiler.OpSliceIndex:
@ -927,31 +833,16 @@ func (v *VM) Run() error {
case compiler.OpSetSelLocal: case compiler.OpSetSelLocal:
localIndex := int(v.curInsts[v.ip+1]) localIndex := int(v.curInsts[v.ip+1])
numSelectors := int(v.curInsts[v.ip+2]) numSelectors := int(v.curInsts[v.ip+2])
v.ip += 2 v.ip += 2
// pop selector outcomes (left to right)
selectors := make([]interface{}, numSelectors, numSelectors)
for i := 0; i < numSelectors; i++ {
sel := v.stack[v.sp-1]
v.sp--
switch sel := (*sel).(type) { // selectors and RHS value
case *objects.String: // map key selectors := v.stack[v.sp-numSelectors : v.sp]
selectors[i] = sel.Value val := v.stack[v.sp-numSelectors-1]
case *objects.Int: // array index v.sp -= numSelectors + 1
selectors[i] = int(sel.Value)
default:
return fmt.Errorf("invalid selector type: %s", sel.TypeName())
}
}
// RHS value
val := v.stack[v.sp-1] // no need to copy value here; selectorAssign uses copy of value
v.sp--
sp := v.curFrame.basePointer + localIndex sp := v.curFrame.basePointer + localIndex
if err := selectorAssign(v.stack[sp], val, selectors); err != nil { if err := indexAssign(v.stack[sp], val, selectors); err != nil {
return err return err
} }
@ -1006,27 +897,12 @@ func (v *VM) Run() error {
numSelectors := int(v.curInsts[v.ip+2]) numSelectors := int(v.curInsts[v.ip+2])
v.ip += 2 v.ip += 2
// pop selector outcomes (left to right) // selectors and RHS value
selectors := make([]interface{}, numSelectors, numSelectors) selectors := v.stack[v.sp-numSelectors : v.sp]
for i := 0; i < numSelectors; i++ { val := v.stack[v.sp-numSelectors-1]
sel := v.stack[v.sp-1] v.sp -= numSelectors + 1
v.sp--
switch sel := (*sel).(type) { if err := indexAssign(v.curFrame.freeVars[freeIndex], val, selectors); err != nil {
case *objects.String: // map key
selectors[i] = sel.Value
case *objects.Int: // array index
selectors[i] = int(sel.Value)
default:
return fmt.Errorf("invalid selector type: %s", sel.TypeName())
}
}
// RHS value
val := v.stack[v.sp-1]
v.sp--
if err := selectorAssign(v.curFrame.freeVars[freeIndex], val, selectors); err != nil {
return err return err
} }
@ -1204,8 +1080,7 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje
v.stack[v.curFrame.basePointer+p] = v.stack[v.sp-numArgs+p] v.stack[v.curFrame.basePointer+p] = v.stack[v.sp-numArgs+p]
} }
v.sp -= numArgs + 1 v.sp -= numArgs + 1
v.ip = -1 v.ip = -1 // reset IP to beginning of the frame
//v.curFrame.ip = -1 // reset IP to beginning of the frame
// stack after tail-call // stack after tail-call
// //
@ -1238,7 +1113,6 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje
v.curFrame = &(v.frames[v.framesIndex]) v.curFrame = &(v.frames[v.framesIndex])
v.curFrame.fn = fn v.curFrame.fn = fn
v.curFrame.freeVars = freeVars v.curFrame.freeVars = freeVars
//v.curFrame.ip = -1
v.curFrame.basePointer = v.sp - numArgs v.curFrame.basePointer = v.sp - numArgs
v.curInsts = fn.Instructions v.curInsts = fn.Instructions
v.ip = -1 v.ip = -1
@ -1293,50 +1167,29 @@ func (v *VM) importModule(compiledModule *objects.CompiledModule) error {
return nil return nil
} }
func selectorAssign(dst, src *objects.Object, selectors []interface{}) error { func indexAssign(dst, src *objects.Object, selectors []*objects.Object) error {
numSel := len(selectors) numSel := len(selectors)
for idx := 0; idx < numSel; idx++ { for sidx := numSel - 1; sidx > 0; sidx-- {
switch sel := selectors[idx].(type) { indexable, ok := (*dst).(objects.Indexable)
case string: if !ok {
m, isMap := (*dst).(*objects.Map) return objects.ErrNotIndexable
if !isMap {
return fmt.Errorf("invalid map object for selector '%s'", sel)
} }
if idx == numSel-1 { next, err := indexable.IndexGet(*selectors[sidx])
m.Set(sel, *src)
return nil
}
nxt, found := m.Get(sel)
if !found {
return fmt.Errorf("key not found '%s'", sel)
}
dst = &nxt
case int:
arr, isArray := (*dst).(*objects.Array)
if !isArray {
return fmt.Errorf("invalid array object for select '[%d]'", sel)
}
if idx == numSel-1 {
return arr.Set(sel, *src)
}
nxt, err := arr.Get(sel)
if err != nil { if err != nil {
return err return err
} }
dst = &nxt dst = &next
default:
panic(fmt.Errorf("invalid selector term: %T", sel))
}
} }
return nil indexAssignable, ok := (*dst).(objects.IndexAssignable)
if !ok {
return objects.ErrNotIndexAssignable
}
return indexAssignable.IndexSet(*selectors[0], *src)
} }
func init() { func init() {

View file

@ -0,0 +1,200 @@
package runtime_test
import (
"strings"
"testing"
"github.com/d5/tengo/compiler/token"
"github.com/d5/tengo/objects"
)
type objectImpl struct{}
func (objectImpl) TypeName() string { return "" }
func (objectImpl) String() string { return "" }
func (objectImpl) IsFalsy() bool { return false }
func (objectImpl) Equals(another objects.Object) bool { return false }
func (objectImpl) Copy() objects.Object { return nil }
func (objectImpl) BinaryOp(token.Token, objects.Object) (objects.Object, error) {
return nil, objects.ErrInvalidOperator
}
type StringDict struct {
objectImpl
Value map[string]string
}
func (o *StringDict) TypeName() string {
return "string-dict"
}
func (o *StringDict) IndexGet(index objects.Object) (objects.Object, error) {
strIdx, ok := index.(*objects.String)
if !ok {
return nil, objects.ErrInvalidIndexType
}
for k, v := range o.Value {
if strings.ToLower(strIdx.Value) == strings.ToLower(k) {
return &objects.String{Value: v}, nil
}
}
return objects.UndefinedValue, nil
}
func (o *StringDict) IndexSet(index, value objects.Object) error {
strIdx, ok := index.(*objects.String)
if !ok {
return objects.ErrInvalidIndexType
}
strVal, ok := objects.ToString(value)
if !ok {
return objects.ErrInvalidTypeConversion
}
o.Value[strings.ToLower(strIdx.Value)] = strVal
return nil
}
type StringCircle struct {
objectImpl
Value []string
}
func (o *StringCircle) TypeName() string {
return "string-circle"
}
func (o *StringCircle) IndexGet(index objects.Object) (objects.Object, error) {
intIdx, ok := index.(*objects.Int)
if !ok {
return nil, objects.ErrInvalidIndexType
}
r := int(intIdx.Value) % len(o.Value)
if r < 0 {
r = len(o.Value) + r
}
return &objects.String{Value: o.Value[r]}, nil
}
func (o *StringCircle) IndexSet(index, value objects.Object) error {
intIdx, ok := index.(*objects.Int)
if !ok {
return objects.ErrInvalidIndexType
}
r := int(intIdx.Value) % len(o.Value)
if r < 0 {
r = len(o.Value) + r
}
strVal, ok := objects.ToString(value)
if !ok {
return objects.ErrInvalidTypeConversion
}
o.Value[r] = strVal
return nil
}
type StringArray struct {
objectImpl
Value []string
}
func (o *StringArray) TypeName() string {
return "string-array"
}
func (o *StringArray) IndexGet(index objects.Object) (objects.Object, error) {
intIdx, ok := index.(*objects.Int)
if ok {
if intIdx.Value >= 0 && intIdx.Value < int64(len(o.Value)) {
return &objects.String{Value: o.Value[intIdx.Value]}, nil
}
return nil, objects.ErrIndexOutOfBounds
}
strIdx, ok := index.(*objects.String)
if ok {
for vidx, str := range o.Value {
if strIdx.Value == str {
return &objects.Int{Value: int64(vidx)}, nil
}
}
return objects.UndefinedValue, nil
}
return nil, objects.ErrInvalidIndexType
}
func (o *StringArray) IndexSet(index, value objects.Object) error {
strVal, ok := objects.ToString(value)
if !ok {
return objects.ErrInvalidTypeConversion
}
intIdx, ok := index.(*objects.Int)
if ok {
if intIdx.Value >= 0 && intIdx.Value < int64(len(o.Value)) {
o.Value[intIdx.Value] = strVal
return nil
}
return objects.ErrIndexOutOfBounds
}
return objects.ErrInvalidIndexType
}
func TestIndexable(t *testing.T) {
dict := func() *StringDict { return &StringDict{Value: map[string]string{"a": "foo", "b": "bar"}} }
expectWithSymbols(t, `out = dict["a"]`, "foo", SYM{"dict": dict()})
expectWithSymbols(t, `out = dict["B"]`, "bar", SYM{"dict": dict()})
expectWithSymbols(t, `out = dict["x"]`, undefined(), SYM{"dict": dict()})
expectErrorWithSymbols(t, `out = dict[0]`, SYM{"dict": dict()})
strCir := func() *StringCircle { return &StringCircle{Value: []string{"one", "two", "three"}} }
expectWithSymbols(t, `out = cir[0]`, "one", SYM{"cir": strCir()})
expectWithSymbols(t, `out = cir[1]`, "two", SYM{"cir": strCir()})
expectWithSymbols(t, `out = cir[-1]`, "three", SYM{"cir": strCir()})
expectWithSymbols(t, `out = cir[-2]`, "two", SYM{"cir": strCir()})
expectWithSymbols(t, `out = cir[3]`, "one", SYM{"cir": strCir()})
expectErrorWithSymbols(t, `out = cir["a"]`, SYM{"cir": strCir()})
strArr := func() *StringArray { return &StringArray{Value: []string{"one", "two", "three"}} }
expectWithSymbols(t, `out = arr["one"]`, 0, SYM{"arr": strArr()})
expectWithSymbols(t, `out = arr["three"]`, 2, SYM{"arr": strArr()})
expectWithSymbols(t, `out = arr["four"]`, undefined(), SYM{"arr": strArr()})
expectWithSymbols(t, `out = arr[0]`, "one", SYM{"arr": strArr()})
expectWithSymbols(t, `out = arr[1]`, "two", SYM{"arr": strArr()})
expectErrorWithSymbols(t, `out = arr[-1]`, SYM{"arr": strArr()})
}
func TestIndexAssignable(t *testing.T) {
dict := func() *StringDict { return &StringDict{Value: map[string]string{"a": "foo", "b": "bar"}} }
expectWithSymbols(t, `dict["a"] = "1984"; out = dict["a"]`, "1984", SYM{"dict": dict()})
expectWithSymbols(t, `dict["c"] = "1984"; out = dict["c"]`, "1984", SYM{"dict": dict()})
expectWithSymbols(t, `dict["c"] = 1984; out = dict["C"]`, "1984", SYM{"dict": dict()})
expectErrorWithSymbols(t, `dict[0] = "1984"`, SYM{"dict": dict()})
strCir := func() *StringCircle { return &StringCircle{Value: []string{"one", "two", "three"}} }
expectWithSymbols(t, `cir[0] = "ONE"; out = cir[0]`, "ONE", SYM{"cir": strCir()})
expectWithSymbols(t, `cir[1] = "TWO"; out = cir[1]`, "TWO", SYM{"cir": strCir()})
expectWithSymbols(t, `cir[-1] = "THREE"; out = cir[2]`, "THREE", SYM{"cir": strCir()})
expectWithSymbols(t, `cir[0] = "ONE"; out = cir[3]`, "ONE", SYM{"cir": strCir()})
expectErrorWithSymbols(t, `cir["a"] = "ONE"`, SYM{"cir": strCir()})
strArr := func() *StringArray { return &StringArray{Value: []string{"one", "two", "three"}} }
expectWithSymbols(t, `arr[0] = "ONE"; out = arr[0]`, "ONE", SYM{"arr": strArr()})
expectWithSymbols(t, `arr[1] = "TWO"; out = arr[1]`, "TWO", SYM{"arr": strArr()})
expectErrorWithSymbols(t, `arr["one"] = "ONE"`, SYM{"arr": strArr()})
}

View file

@ -127,7 +127,9 @@ out = echo(["foo", "bar"])
expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a = 5`, map[string]string{ expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a = 5`, map[string]string{
"mod1": `a := 3`, "mod1": `a := 3`,
}) })
expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a.b = 5`, map[string]string{
// module is immutable but its variables is not necessarily immutable.
expectWithUserModules(t, `m1 := import("mod1"); m1.a.b = 5; out = m1.a.b`, 5, map[string]string{
"mod1": `a := {b: 3}`, "mod1": `a := {b: 3}`,
}) })
} }

View file

@ -22,11 +22,23 @@ const (
type MAP = map[string]interface{} type MAP = map[string]interface{}
type ARR = []interface{} type ARR = []interface{}
type SYM = map[string]objects.Object
func expect(t *testing.T, input string, expected interface{}) { func expect(t *testing.T, input string, expected interface{}) {
expectWithUserModules(t, input, expected, nil) expectWithUserModules(t, input, expected, nil)
} }
func expectWithSymbols(t *testing.T, input string, expected interface{}, symbols map[string]objects.Object) {
// parse
file := parse(t, input)
if file == nil {
return
}
// compiler/VM
runVM(t, file, expected, symbols, nil)
}
func expectWithUserModules(t *testing.T, input string, expected interface{}, userModules map[string]string) { func expectWithUserModules(t *testing.T, input string, expected interface{}, userModules map[string]string) {
// parse // parse
file := parse(t, input) file := parse(t, input)
@ -35,7 +47,7 @@ func expectWithUserModules(t *testing.T, input string, expected interface{}, use
} }
// compiler/VM // compiler/VM
runVM(t, file, expected, userModules) runVM(t, file, expected, nil, userModules)
} }
func expectError(t *testing.T, input string) { func expectError(t *testing.T, input string) {
@ -50,15 +62,29 @@ func expectErrorWithUserModules(t *testing.T, input string, userModules map[stri
} }
// compiler/VM // compiler/VM
runVMError(t, program, userModules) runVMError(t, program, nil, userModules)
} }
func runVM(t *testing.T, file *ast.File, expected interface{}, userModules map[string]string) (ok bool) { func expectErrorWithSymbols(t *testing.T, input string, symbols map[string]objects.Object) {
// parse
program := parse(t, input)
if program == nil {
return
}
// compiler/VM
runVMError(t, program, symbols, nil)
}
func runVM(t *testing.T, file *ast.File, expected interface{}, symbols map[string]objects.Object, userModules map[string]string) (ok bool) {
expectedObj := toObject(expected) expectedObj := toObject(expected)
res, trace, err := traceCompileRun(file, map[string]objects.Object{ if symbols == nil {
testOut: objectZeroCopy(expectedObj), symbols = make(map[string]objects.Object)
}, userModules) }
symbols[testOut] = objectZeroCopy(expectedObj)
res, trace, err := traceCompileRun(file, symbols, userModules)
defer func() { defer func() {
if !ok { if !ok {
@ -76,8 +102,8 @@ func runVM(t *testing.T, file *ast.File, expected interface{}, userModules map[s
} }
// TODO: should differentiate compile-time error, runtime error, and, error object returned // TODO: should differentiate compile-time error, runtime error, and, error object returned
func runVMError(t *testing.T, file *ast.File, userModules map[string]string) (ok bool) { func runVMError(t *testing.T, file *ast.File, symbols map[string]objects.Object, userModules map[string]string) (ok bool) {
_, trace, err := traceCompileRun(file, nil, userModules) _, trace, err := traceCompileRun(file, symbols, userModules)
defer func() { defer func() {
if !ok { if !ok {
@ -174,7 +200,11 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu
symTable := compiler.NewSymbolTable() symTable := compiler.NewSymbolTable()
for name, value := range symbols { for name, value := range symbols {
sym := symTable.Define(name) sym := symTable.Define(name)
globals[sym.Index] = &value
// should not store pointer to 'value' variable
// which is re-used in each iteration.
valueCopy := value
globals[sym.Index] = &valueCopy
} }
for idx, fn := range objects.Builtins { for idx, fn := range objects.Builtins {
symTable.DefineBuiltin(idx, fn.Name) symTable.DefineBuiltin(idx, fn.Name)