diff --git a/compiler/compiler_assign.go b/compiler/compiler_assign.go index 3aceb0f..3efc376 100644 --- a/compiler/compiler_assign.go +++ b/compiler/compiler_assign.go @@ -135,6 +135,7 @@ func resolveAssignLHS(expr ast.Expr) (name string, selectors []ast.Expr, err err selectors = append(selectors, term.Sel) return + case *ast.IndexExpr: name, selectors, err = resolveAssignLHS(term.Expr) if err != nil { diff --git a/objects/array.go b/objects/array.go index 5df9d81..d7109e6 100644 --- a/objects/array.go +++ b/objects/array.go @@ -1,7 +1,6 @@ package objects import ( - "errors" "fmt" "strings" @@ -79,22 +78,40 @@ func (o *Array) Equals(x Object) bool { return true } -// Get returns an element at a given index. -func (o *Array) Get(index int) (Object, error) { - if index < 0 || index >= len(o.Value) { - return nil, errors.New("array index out of bounds") +// IndexGet returns an element at a given index. +func (o *Array) IndexGet(index Object) (res Object, err error) { + intIdx, ok := index.(*Int) + if !ok { + err = ErrInvalidIndexType + return } - return o.Value[index], nil + idxVal := int(intIdx.Value) + + if idxVal < 0 || idxVal >= len(o.Value) { + err = ErrIndexOutOfBounds + return + } + + res = o.Value[idxVal] + + return } -// Set sets an element at a given index. -func (o *Array) Set(index int, value Object) error { - if index < 0 || index >= len(o.Value) { - return errors.New("array index out of bounds") +// IndexSet sets an element at a given index. +func (o *Array) IndexSet(index, value Object) (err error) { + intIdx, ok := ToInt(index) + if !ok { + err = ErrInvalidTypeConversion + return } - o.Value[index] = value + if intIdx < 0 || intIdx >= len(o.Value) { + err = ErrIndexOutOfBounds + return + } + + o.Value[intIdx] = value return nil } diff --git a/objects/bytes.go b/objects/bytes.go index 58e8059..f77d969 100644 --- a/objects/bytes.go +++ b/objects/bytes.go @@ -54,3 +54,23 @@ func (o *Bytes) Equals(x Object) bool { return bytes.Compare(o.Value, t.Value) == 0 } + +// IndexGet returns an element (as Int) at a given index. +func (o *Bytes) IndexGet(index Object) (res Object, err error) { + intIdx, ok := index.(*Int) + if !ok { + err = ErrInvalidIndexType + return + } + + idxVal := int(intIdx.Value) + + if idxVal < 0 || idxVal >= len(o.Value) { + err = ErrIndexOutOfBounds + return + } + + res = &Int{Value: int64(o.Value[idxVal])} + + return +} diff --git a/objects/errors.go b/objects/errors.go index 19fbaa9..53e7d1c 100644 --- a/objects/errors.go +++ b/objects/errors.go @@ -2,11 +2,17 @@ package objects import "errors" -// ErrNotCallable represents an error for calling on non-function objects. -var ErrNotCallable = errors.New("not a callable object") +// ErrNotIndexable means the type is not indexable. +var ErrNotIndexable = errors.New("type is not indexable") -// ErrNotIndexable represents an error for indexing on non-indexable objects. -var ErrNotIndexable = errors.New("non-indexable object") +// ErrNotIndexAssignable means the type is not index-assignable. +var ErrNotIndexAssignable = errors.New("type is not index-assignable") + +// ErrIndexOutOfBounds is an error where a given index is out of the bounds. +var ErrIndexOutOfBounds = errors.New("index out of bounds") + +// ErrInvalidIndexType means the type is not supported as an index. +var ErrInvalidIndexType = errors.New("invalid index type") // ErrInvalidOperator represents an error for invalid operator usage. var ErrInvalidOperator = errors.New("invalid operator") diff --git a/objects/immutable_map.go b/objects/immutable_map.go index 4e29281..70c46e7 100644 --- a/objects/immutable_map.go +++ b/objects/immutable_map.go @@ -47,11 +47,20 @@ func (o *ImmutableMap) IsFalsy() bool { return len(o.Value) == 0 } -// Get returns the value for the given key. -func (o *ImmutableMap) Get(key string) (Object, bool) { - val, ok := o.Value[key] +// IndexGet returns the value for the given key. +func (o *ImmutableMap) IndexGet(index Object) (res Object, err error) { + strIdx, ok := ToString(index) + if !ok { + err = ErrInvalidTypeConversion + return + } - return val, ok + val, ok := o.Value[strIdx] + if !ok { + val = UndefinedValue + } + + return val, nil } // Equals returns true if the value of the type diff --git a/objects/index_assignable.go b/objects/index_assignable.go new file mode 100644 index 0000000..a1c6cbf --- /dev/null +++ b/objects/index_assignable.go @@ -0,0 +1,9 @@ +package objects + +// IndexAssignable is an object that can take an index and a value +// on the left-hand side of the assignment statement. +type IndexAssignable interface { + // IndexSet should take an index Object and a value Object. + // If an error is returned, it will be treated as a run-time error. + IndexSet(index, value Object) error +} diff --git a/objects/indexable.go b/objects/indexable.go new file mode 100644 index 0000000..bbc8163 --- /dev/null +++ b/objects/indexable.go @@ -0,0 +1,9 @@ +package objects + +// Indexable is an object that can take an index and return an object. +type Indexable interface { + // IndexGet should take an index Object and return a result Object or an error. + // If error is returned, the runtime will treat it as a run-time error and ignore returned value. + // If nil is returned as value, it will be converted to Undefined value by the runtime. + IndexGet(index Object) (value Object, err error) +} diff --git a/objects/map.go b/objects/map.go index f7438f7..0bd800c 100644 --- a/objects/map.go +++ b/objects/map.go @@ -47,18 +47,6 @@ func (o *Map) IsFalsy() bool { return len(o.Value) == 0 } -// Get returns the value for the given key. -func (o *Map) Get(key string) (Object, bool) { - val, ok := o.Value[key] - - return val, ok -} - -// Set sets the value for the given key. -func (o *Map) Set(key string, value Object) { - o.Value[key] = value -} - // Equals returns true if the value of the type // is equal to the value of another object. func (o *Map) Equals(x Object) bool { @@ -80,3 +68,32 @@ func (o *Map) Equals(x Object) bool { return true } + +// IndexGet returns the value for the given key. +func (o *Map) IndexGet(index Object) (res Object, err error) { + strIdx, ok := index.(*String) + if !ok { + err = ErrInvalidIndexType + return + } + + val, ok := o.Value[strIdx.Value] + if !ok { + val = UndefinedValue + } + + return val, nil +} + +// IndexSet sets the value for the given key. +func (o *Map) IndexSet(index, value Object) (err error) { + strIdx, ok := ToString(index) + if !ok { + err = ErrInvalidTypeConversion + return + } + + o.Value[strIdx] = value + + return nil +} diff --git a/objects/string.go b/objects/string.go index 67d5ddd..1b6227c 100644 --- a/objects/string.go +++ b/objects/string.go @@ -8,7 +8,8 @@ import ( // String represents a string value. type String struct { - Value string + Value string + runeStr []rune } // TypeName returns the name of the type. @@ -56,3 +57,27 @@ func (o *String) Equals(x Object) bool { return o.Value == t.Value } + +// IndexGet returns a character at a given index. +func (o *String) IndexGet(index Object) (res Object, err error) { + intIdx, ok := index.(*Int) + if !ok { + err = ErrInvalidIndexType + return + } + + idxVal := int(intIdx.Value) + + if o.runeStr == nil { + o.runeStr = []rune(o.Value) + } + + if idxVal < 0 || idxVal >= len(o.runeStr) { + err = ErrIndexOutOfBounds + return + } + + res = &Char{Value: o.runeStr[idxVal]} + + return +} diff --git a/runtime/vm.go b/runtime/vm.go index b8a0e78..a15e8e5 100644 --- a/runtime/vm.go +++ b/runtime/vm.go @@ -485,27 +485,12 @@ func (v *VM) Run() error { numSelectors := int(v.curInsts[v.ip+3]) v.ip += 3 - // pop selector outcomes (left to right) - selectors := make([]interface{}, numSelectors, numSelectors) - for i := 0; i < numSelectors; i++ { - sel := v.stack[v.sp-1] - v.sp-- + // selectors and RHS value + selectors := v.stack[v.sp-numSelectors : v.sp] + val := v.stack[v.sp-numSelectors-1] + v.sp -= numSelectors + 1 - switch sel := (*sel).(type) { - case *objects.String: // map key - selectors[i] = sel.Value - case *objects.Int: // array index - selectors[i] = int(sel.Value) - default: - return fmt.Errorf("invalid selector type: %s", sel.TypeName()) - } - } - - // RHS value - val := v.stack[v.sp-1] - v.sp-- - - if err := selectorAssign(v.globals[globalIndex], val, selectors); err != nil { + if err := indexAssign(v.globals[globalIndex], val, selectors); err != nil { return err } @@ -583,101 +568,22 @@ func (v *VM) Run() error { v.sp -= 2 switch left := (*left).(type) { - case *objects.Array: - idx, ok := (*index).(*objects.Int) - if !ok { - return fmt.Errorf("non-integer array index: %s", left.TypeName()) + case objects.Indexable: + val, err := left.IndexGet(*index) + if err != nil { + return err } - - if idx.Value < 0 || idx.Value >= int64(len(left.Value)) { - return fmt.Errorf("index out of bounds: %d", index) + if val == nil { + val = objects.UndefinedValue } if v.sp >= StackSize { return ErrStackOverflow } - v.stack[v.sp] = &left.Value[idx.Value] - v.sp++ - - case *objects.String: - idx, ok := (*index).(*objects.Int) - if !ok { - return fmt.Errorf("non-integer array index: %s", left.TypeName()) - } - - str := []rune(left.Value) - - if idx.Value < 0 || idx.Value >= int64(len(str)) { - return fmt.Errorf("index out of bounds: %d", index) - } - - if v.sp >= StackSize { - return ErrStackOverflow - } - - var val objects.Object = &objects.Char{Value: str[idx.Value]} - v.stack[v.sp] = &val v.sp++ - case *objects.Bytes: - idx, ok := (*index).(*objects.Int) - if !ok { - return fmt.Errorf("non-integer array index: %s", left.TypeName()) - } - - if idx.Value < 0 || idx.Value >= int64(len(left.Value)) { - return fmt.Errorf("index out of bounds: %d", index) - } - - if v.sp >= StackSize { - return ErrStackOverflow - } - - var val objects.Object = &objects.Int{Value: int64(left.Value[idx.Value])} - - v.stack[v.sp] = &val - v.sp++ - - case *objects.Map: - key, ok := (*index).(*objects.String) - if !ok { - return fmt.Errorf("non-string key: %s", left.TypeName()) - } - - var res = objects.UndefinedValue - val, ok := left.Value[key.Value] - if ok { - res = val - } - - if v.sp >= StackSize { - return ErrStackOverflow - } - - v.stack[v.sp] = &res - v.sp++ - - case *objects.ImmutableMap: - key, ok := (*index).(*objects.String) - if !ok { - return fmt.Errorf("non-string key: %s", left.TypeName()) - } - - var res = objects.UndefinedValue - val, ok := left.Value[key.Value] - if ok { - res = val - } - - if v.sp >= StackSize { - return ErrStackOverflow - } - - v.stack[v.sp] = &res - v.sp++ - case *objects.Error: // err.value key, ok := (*index).(*objects.String) if !ok || key.Value != "value" { @@ -692,7 +598,7 @@ func (v *VM) Run() error { v.sp++ default: - return fmt.Errorf("type %s does not support indexing", left.TypeName()) + return objects.ErrNotIndexable } case compiler.OpSliceIndex: @@ -927,31 +833,16 @@ func (v *VM) Run() error { case compiler.OpSetSelLocal: localIndex := int(v.curInsts[v.ip+1]) numSelectors := int(v.curInsts[v.ip+2]) - v.ip += 2 - // pop selector outcomes (left to right) - selectors := make([]interface{}, numSelectors, numSelectors) - for i := 0; i < numSelectors; i++ { - sel := v.stack[v.sp-1] - v.sp-- - switch sel := (*sel).(type) { - case *objects.String: // map key - selectors[i] = sel.Value - case *objects.Int: // array index - selectors[i] = int(sel.Value) - default: - return fmt.Errorf("invalid selector type: %s", sel.TypeName()) - } - } - - // RHS value - val := v.stack[v.sp-1] // no need to copy value here; selectorAssign uses copy of value - v.sp-- + // selectors and RHS value + selectors := v.stack[v.sp-numSelectors : v.sp] + val := v.stack[v.sp-numSelectors-1] + v.sp -= numSelectors + 1 sp := v.curFrame.basePointer + localIndex - if err := selectorAssign(v.stack[sp], val, selectors); err != nil { + if err := indexAssign(v.stack[sp], val, selectors); err != nil { return err } @@ -1006,27 +897,12 @@ func (v *VM) Run() error { numSelectors := int(v.curInsts[v.ip+2]) v.ip += 2 - // pop selector outcomes (left to right) - selectors := make([]interface{}, numSelectors, numSelectors) - for i := 0; i < numSelectors; i++ { - sel := v.stack[v.sp-1] - v.sp-- + // selectors and RHS value + selectors := v.stack[v.sp-numSelectors : v.sp] + val := v.stack[v.sp-numSelectors-1] + v.sp -= numSelectors + 1 - switch sel := (*sel).(type) { - case *objects.String: // map key - selectors[i] = sel.Value - case *objects.Int: // array index - selectors[i] = int(sel.Value) - default: - return fmt.Errorf("invalid selector type: %s", sel.TypeName()) - } - } - - // RHS value - val := v.stack[v.sp-1] - v.sp-- - - if err := selectorAssign(v.curFrame.freeVars[freeIndex], val, selectors); err != nil { + if err := indexAssign(v.curFrame.freeVars[freeIndex], val, selectors); err != nil { return err } @@ -1204,8 +1080,7 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje v.stack[v.curFrame.basePointer+p] = v.stack[v.sp-numArgs+p] } v.sp -= numArgs + 1 - v.ip = -1 - //v.curFrame.ip = -1 // reset IP to beginning of the frame + v.ip = -1 // reset IP to beginning of the frame // stack after tail-call // @@ -1238,7 +1113,6 @@ func (v *VM) callFunction(fn *objects.CompiledFunction, freeVars []*objects.Obje v.curFrame = &(v.frames[v.framesIndex]) v.curFrame.fn = fn v.curFrame.freeVars = freeVars - //v.curFrame.ip = -1 v.curFrame.basePointer = v.sp - numArgs v.curInsts = fn.Instructions v.ip = -1 @@ -1293,50 +1167,29 @@ func (v *VM) importModule(compiledModule *objects.CompiledModule) error { return nil } -func selectorAssign(dst, src *objects.Object, selectors []interface{}) error { +func indexAssign(dst, src *objects.Object, selectors []*objects.Object) error { numSel := len(selectors) - for idx := 0; idx < numSel; idx++ { - switch sel := selectors[idx].(type) { - case string: - m, isMap := (*dst).(*objects.Map) - if !isMap { - return fmt.Errorf("invalid map object for selector '%s'", sel) - } - - if idx == numSel-1 { - m.Set(sel, *src) - return nil - } - - nxt, found := m.Get(sel) - if !found { - return fmt.Errorf("key not found '%s'", sel) - } - - dst = &nxt - case int: - arr, isArray := (*dst).(*objects.Array) - if !isArray { - return fmt.Errorf("invalid array object for select '[%d]'", sel) - } - - if idx == numSel-1 { - return arr.Set(sel, *src) - } - - nxt, err := arr.Get(sel) - if err != nil { - return err - } - - dst = &nxt - default: - panic(fmt.Errorf("invalid selector term: %T", sel)) + for sidx := numSel - 1; sidx > 0; sidx-- { + indexable, ok := (*dst).(objects.Indexable) + if !ok { + return objects.ErrNotIndexable } + + next, err := indexable.IndexGet(*selectors[sidx]) + if err != nil { + return err + } + + dst = &next } - return nil + indexAssignable, ok := (*dst).(objects.IndexAssignable) + if !ok { + return objects.ErrNotIndexAssignable + } + + return indexAssignable.IndexSet(*selectors[0], *src) } func init() { diff --git a/runtime/vm_indexable_test.go b/runtime/vm_indexable_test.go new file mode 100644 index 0000000..19959ca --- /dev/null +++ b/runtime/vm_indexable_test.go @@ -0,0 +1,200 @@ +package runtime_test + +import ( + "strings" + "testing" + + "github.com/d5/tengo/compiler/token" + "github.com/d5/tengo/objects" +) + +type objectImpl struct{} + +func (objectImpl) TypeName() string { return "" } +func (objectImpl) String() string { return "" } +func (objectImpl) IsFalsy() bool { return false } +func (objectImpl) Equals(another objects.Object) bool { return false } +func (objectImpl) Copy() objects.Object { return nil } +func (objectImpl) BinaryOp(token.Token, objects.Object) (objects.Object, error) { + return nil, objects.ErrInvalidOperator +} + +type StringDict struct { + objectImpl + Value map[string]string +} + +func (o *StringDict) TypeName() string { + return "string-dict" +} + +func (o *StringDict) IndexGet(index objects.Object) (objects.Object, error) { + strIdx, ok := index.(*objects.String) + if !ok { + return nil, objects.ErrInvalidIndexType + } + + for k, v := range o.Value { + if strings.ToLower(strIdx.Value) == strings.ToLower(k) { + return &objects.String{Value: v}, nil + } + } + + return objects.UndefinedValue, nil +} + +func (o *StringDict) IndexSet(index, value objects.Object) error { + strIdx, ok := index.(*objects.String) + if !ok { + return objects.ErrInvalidIndexType + } + + strVal, ok := objects.ToString(value) + if !ok { + return objects.ErrInvalidTypeConversion + } + + o.Value[strings.ToLower(strIdx.Value)] = strVal + + return nil +} + +type StringCircle struct { + objectImpl + Value []string +} + +func (o *StringCircle) TypeName() string { + return "string-circle" +} + +func (o *StringCircle) IndexGet(index objects.Object) (objects.Object, error) { + intIdx, ok := index.(*objects.Int) + if !ok { + return nil, objects.ErrInvalidIndexType + } + + r := int(intIdx.Value) % len(o.Value) + if r < 0 { + r = len(o.Value) + r + } + + return &objects.String{Value: o.Value[r]}, nil +} + +func (o *StringCircle) IndexSet(index, value objects.Object) error { + intIdx, ok := index.(*objects.Int) + if !ok { + return objects.ErrInvalidIndexType + } + + r := int(intIdx.Value) % len(o.Value) + if r < 0 { + r = len(o.Value) + r + } + + strVal, ok := objects.ToString(value) + if !ok { + return objects.ErrInvalidTypeConversion + } + + o.Value[r] = strVal + + return nil +} + +type StringArray struct { + objectImpl + Value []string +} + +func (o *StringArray) TypeName() string { + return "string-array" +} + +func (o *StringArray) IndexGet(index objects.Object) (objects.Object, error) { + intIdx, ok := index.(*objects.Int) + if ok { + if intIdx.Value >= 0 && intIdx.Value < int64(len(o.Value)) { + return &objects.String{Value: o.Value[intIdx.Value]}, nil + } + + return nil, objects.ErrIndexOutOfBounds + } + + strIdx, ok := index.(*objects.String) + if ok { + for vidx, str := range o.Value { + if strIdx.Value == str { + return &objects.Int{Value: int64(vidx)}, nil + } + } + + return objects.UndefinedValue, nil + } + + return nil, objects.ErrInvalidIndexType +} + +func (o *StringArray) IndexSet(index, value objects.Object) error { + strVal, ok := objects.ToString(value) + if !ok { + return objects.ErrInvalidTypeConversion + } + + intIdx, ok := index.(*objects.Int) + if ok { + if intIdx.Value >= 0 && intIdx.Value < int64(len(o.Value)) { + o.Value[intIdx.Value] = strVal + return nil + } + + return objects.ErrIndexOutOfBounds + } + + return objects.ErrInvalidIndexType +} + +func TestIndexable(t *testing.T) { + dict := func() *StringDict { return &StringDict{Value: map[string]string{"a": "foo", "b": "bar"}} } + expectWithSymbols(t, `out = dict["a"]`, "foo", SYM{"dict": dict()}) + expectWithSymbols(t, `out = dict["B"]`, "bar", SYM{"dict": dict()}) + expectWithSymbols(t, `out = dict["x"]`, undefined(), SYM{"dict": dict()}) + expectErrorWithSymbols(t, `out = dict[0]`, SYM{"dict": dict()}) + + strCir := func() *StringCircle { return &StringCircle{Value: []string{"one", "two", "three"}} } + expectWithSymbols(t, `out = cir[0]`, "one", SYM{"cir": strCir()}) + expectWithSymbols(t, `out = cir[1]`, "two", SYM{"cir": strCir()}) + expectWithSymbols(t, `out = cir[-1]`, "three", SYM{"cir": strCir()}) + expectWithSymbols(t, `out = cir[-2]`, "two", SYM{"cir": strCir()}) + expectWithSymbols(t, `out = cir[3]`, "one", SYM{"cir": strCir()}) + expectErrorWithSymbols(t, `out = cir["a"]`, SYM{"cir": strCir()}) + + strArr := func() *StringArray { return &StringArray{Value: []string{"one", "two", "three"}} } + expectWithSymbols(t, `out = arr["one"]`, 0, SYM{"arr": strArr()}) + expectWithSymbols(t, `out = arr["three"]`, 2, SYM{"arr": strArr()}) + expectWithSymbols(t, `out = arr["four"]`, undefined(), SYM{"arr": strArr()}) + expectWithSymbols(t, `out = arr[0]`, "one", SYM{"arr": strArr()}) + expectWithSymbols(t, `out = arr[1]`, "two", SYM{"arr": strArr()}) + expectErrorWithSymbols(t, `out = arr[-1]`, SYM{"arr": strArr()}) +} + +func TestIndexAssignable(t *testing.T) { + dict := func() *StringDict { return &StringDict{Value: map[string]string{"a": "foo", "b": "bar"}} } + expectWithSymbols(t, `dict["a"] = "1984"; out = dict["a"]`, "1984", SYM{"dict": dict()}) + expectWithSymbols(t, `dict["c"] = "1984"; out = dict["c"]`, "1984", SYM{"dict": dict()}) + expectWithSymbols(t, `dict["c"] = 1984; out = dict["C"]`, "1984", SYM{"dict": dict()}) + expectErrorWithSymbols(t, `dict[0] = "1984"`, SYM{"dict": dict()}) + + strCir := func() *StringCircle { return &StringCircle{Value: []string{"one", "two", "three"}} } + expectWithSymbols(t, `cir[0] = "ONE"; out = cir[0]`, "ONE", SYM{"cir": strCir()}) + expectWithSymbols(t, `cir[1] = "TWO"; out = cir[1]`, "TWO", SYM{"cir": strCir()}) + expectWithSymbols(t, `cir[-1] = "THREE"; out = cir[2]`, "THREE", SYM{"cir": strCir()}) + expectWithSymbols(t, `cir[0] = "ONE"; out = cir[3]`, "ONE", SYM{"cir": strCir()}) + expectErrorWithSymbols(t, `cir["a"] = "ONE"`, SYM{"cir": strCir()}) + + strArr := func() *StringArray { return &StringArray{Value: []string{"one", "two", "three"}} } + expectWithSymbols(t, `arr[0] = "ONE"; out = arr[0]`, "ONE", SYM{"arr": strArr()}) + expectWithSymbols(t, `arr[1] = "TWO"; out = arr[1]`, "TWO", SYM{"arr": strArr()}) + expectErrorWithSymbols(t, `arr["one"] = "ONE"`, SYM{"arr": strArr()}) +} diff --git a/runtime/vm_module_test.go b/runtime/vm_module_test.go index b63b8b2..19fcf67 100644 --- a/runtime/vm_module_test.go +++ b/runtime/vm_module_test.go @@ -127,7 +127,9 @@ out = echo(["foo", "bar"]) expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a = 5`, map[string]string{ "mod1": `a := 3`, }) - expectErrorWithUserModules(t, `m1 := import("mod1"); m1.a.b = 5`, map[string]string{ + + // module is immutable but its variables is not necessarily immutable. + expectWithUserModules(t, `m1 := import("mod1"); m1.a.b = 5; out = m1.a.b`, 5, map[string]string{ "mod1": `a := {b: 3}`, }) } diff --git a/runtime/vm_test.go b/runtime/vm_test.go index f3f0722..f4e4bc5 100644 --- a/runtime/vm_test.go +++ b/runtime/vm_test.go @@ -22,11 +22,23 @@ const ( type MAP = map[string]interface{} type ARR = []interface{} +type SYM = map[string]objects.Object func expect(t *testing.T, input string, expected interface{}) { expectWithUserModules(t, input, expected, nil) } +func expectWithSymbols(t *testing.T, input string, expected interface{}, symbols map[string]objects.Object) { + // parse + file := parse(t, input) + if file == nil { + return + } + + // compiler/VM + runVM(t, file, expected, symbols, nil) +} + func expectWithUserModules(t *testing.T, input string, expected interface{}, userModules map[string]string) { // parse file := parse(t, input) @@ -35,7 +47,7 @@ func expectWithUserModules(t *testing.T, input string, expected interface{}, use } // compiler/VM - runVM(t, file, expected, userModules) + runVM(t, file, expected, nil, userModules) } func expectError(t *testing.T, input string) { @@ -50,15 +62,29 @@ func expectErrorWithUserModules(t *testing.T, input string, userModules map[stri } // compiler/VM - runVMError(t, program, userModules) + runVMError(t, program, nil, userModules) } -func runVM(t *testing.T, file *ast.File, expected interface{}, userModules map[string]string) (ok bool) { +func expectErrorWithSymbols(t *testing.T, input string, symbols map[string]objects.Object) { + // parse + program := parse(t, input) + if program == nil { + return + } + + // compiler/VM + runVMError(t, program, symbols, nil) +} + +func runVM(t *testing.T, file *ast.File, expected interface{}, symbols map[string]objects.Object, userModules map[string]string) (ok bool) { expectedObj := toObject(expected) - res, trace, err := traceCompileRun(file, map[string]objects.Object{ - testOut: objectZeroCopy(expectedObj), - }, userModules) + if symbols == nil { + symbols = make(map[string]objects.Object) + } + symbols[testOut] = objectZeroCopy(expectedObj) + + res, trace, err := traceCompileRun(file, symbols, userModules) defer func() { if !ok { @@ -76,8 +102,8 @@ func runVM(t *testing.T, file *ast.File, expected interface{}, userModules map[s } // TODO: should differentiate compile-time error, runtime error, and, error object returned -func runVMError(t *testing.T, file *ast.File, userModules map[string]string) (ok bool) { - _, trace, err := traceCompileRun(file, nil, userModules) +func runVMError(t *testing.T, file *ast.File, symbols map[string]objects.Object, userModules map[string]string) (ok bool) { + _, trace, err := traceCompileRun(file, symbols, userModules) defer func() { if !ok { @@ -174,7 +200,11 @@ func traceCompileRun(file *ast.File, symbols map[string]objects.Object, userModu symTable := compiler.NewSymbolTable() for name, value := range symbols { sym := symTable.Define(name) - globals[sym.Index] = &value + + // should not store pointer to 'value' variable + // which is re-used in each iteration. + valueCopy := value + globals[sym.Index] = &valueCopy } for idx, fn := range objects.Builtins { symTable.DefineBuiltin(idx, fn.Name)