fix bytecode encoding/decoding (#152)

* register ImmutableMap type to gob

* fix bytecode decoding

* disallow empty module name
This commit is contained in:
Daniel 2019-03-19 09:43:03 -07:00 committed by GitHub
parent 2520caf581
commit e785e38bf8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 152 additions and 31 deletions

View file

@ -17,27 +17,6 @@ type Bytecode struct {
Constants []objects.Object Constants []objects.Object
} }
// Decode reads Bytecode data from the reader.
func (b *Bytecode) Decode(r io.Reader) error {
dec := gob.NewDecoder(r)
if err := dec.Decode(&b.FileSet); err != nil {
return err
}
// TODO: files in b.FileSet.File does not have their 'set' field properly set to b.FileSet
// as it's private field and not serialized by gob encoder/decoder.
if err := dec.Decode(&b.MainFunction); err != nil {
return err
}
if err := dec.Decode(&b.Constants); err != nil {
return err
}
return nil
}
// Encode writes Bytecode data to the writer. // Encode writes Bytecode data to the writer.
func (b *Bytecode) Encode(w io.Writer) error { func (b *Bytecode) Encode(w io.Writer) error {
enc := gob.NewEncoder(w) enc := gob.NewEncoder(w)
@ -92,9 +71,20 @@ func (b *Bytecode) FormatConstants() (output []string) {
func init() { func init() {
gob.Register(&source.FileSet{}) gob.Register(&source.FileSet{})
gob.Register(&source.File{}) gob.Register(&source.File{})
gob.Register(&objects.Array{})
gob.Register(&objects.Bool{})
gob.Register(&objects.Bytes{})
gob.Register(&objects.Char{}) gob.Register(&objects.Char{})
gob.Register(&objects.Closure{})
gob.Register(&objects.CompiledFunction{}) gob.Register(&objects.CompiledFunction{})
gob.Register(&objects.Error{})
gob.Register(&objects.Float{}) gob.Register(&objects.Float{})
gob.Register(&objects.ImmutableArray{})
gob.Register(&objects.ImmutableMap{})
gob.Register(&objects.Int{}) gob.Register(&objects.Int{})
gob.Register(&objects.Map{})
gob.Register(&objects.String{}) gob.Register(&objects.String{})
gob.Register(&objects.Time{})
gob.Register(&objects.Undefined{})
gob.Register(&objects.UserFunction{})
} }

View file

@ -0,0 +1,62 @@
package compiler
import (
"encoding/gob"
"io"
"github.com/d5/tengo/objects"
)
// Decode reads Bytecode data from the reader.
func (b *Bytecode) Decode(r io.Reader) error {
dec := gob.NewDecoder(r)
if err := dec.Decode(&b.FileSet); err != nil {
return err
}
// TODO: files in b.FileSet.File does not have their 'set' field properly set to b.FileSet
// as it's private field and not serialized by gob encoder/decoder.
if err := dec.Decode(&b.MainFunction); err != nil {
return err
}
if err := dec.Decode(&b.Constants); err != nil {
return err
}
for i, v := range b.Constants {
b.Constants[i] = fixDecoded(v)
}
return nil
}
func fixDecoded(o objects.Object) objects.Object {
switch o := o.(type) {
case *objects.Bool:
if o.IsFalsy() {
return objects.FalseValue
}
return objects.TrueValue
case *objects.Undefined:
return objects.UndefinedValue
case *objects.Array:
for i, v := range o.Value {
o.Value[i] = fixDecoded(v)
}
case *objects.ImmutableArray:
for i, v := range o.Value {
o.Value[i] = fixDecoded(v)
}
case *objects.Map:
for k, v := range o.Value {
o.Value[k] = fixDecoded(v)
}
case *objects.ImmutableMap:
for k, v := range o.Value {
o.Value[k] = fixDecoded(v)
}
}
return o
}

View file

@ -25,7 +25,7 @@ func (b *Bytecode) RemoveDuplicates() {
indexMap[curIdx] = len(deduped) indexMap[curIdx] = len(deduped)
deduped = append(deduped, c) deduped = append(deduped, c)
case *objects.ImmutableMap: case *objects.ImmutableMap:
modName := c.Value["__module_name__"].(*objects.String).Value modName := moduleName(c)
newIdx, ok := immutableMaps[modName] newIdx, ok := immutableMaps[modName]
if modName != "" && ok { if modName != "" && ok {
indexMap[curIdx] = newIdx indexMap[curIdx] = newIdx
@ -72,7 +72,7 @@ func (b *Bytecode) RemoveDuplicates() {
deduped = append(deduped, c) deduped = append(deduped, c)
} }
default: default:
panic(fmt.Errorf("invalid constant type: %s", c.TypeName())) panic(fmt.Errorf("unsupported top-level constant type: %s", c.TypeName()))
} }
} }
@ -119,3 +119,11 @@ func updateConstIndexes(insts []byte, indexMap map[int]int) {
i += 1 + read i += 1 + read
} }
} }
func moduleName(mod *objects.ImmutableMap) string {
if modName, ok := mod.Value["__module_name__"].(*objects.String); ok {
return modName.Value
}
return ""
}

View file

@ -3,6 +3,7 @@ package compiler_test
import ( import (
"bytes" "bytes"
"testing" "testing"
"time"
"github.com/d5/tengo/assert" "github.com/d5/tengo/assert"
"github.com/d5/tengo/compiler" "github.com/d5/tengo/compiler"
@ -38,10 +39,64 @@ func TestBytecode(t *testing.T) {
compiler.MakeInstruction(compiler.OpConstant, 6), compiler.MakeInstruction(compiler.OpConstant, 6),
compiler.MakeInstruction(compiler.OpPop)), compiler.MakeInstruction(compiler.OpPop)),
objectsArray( objectsArray(
intObject(55), &objects.Int{Value: 55},
intObject(66), &objects.Int{Value: 66},
intObject(77), &objects.Int{Value: 77},
intObject(88), &objects.Int{Value: 88},
&objects.ImmutableMap{
Value: map[string]objects.Object{
"array": &objects.ImmutableArray{
Value: []objects.Object{
&objects.Int{Value: 1},
&objects.Int{Value: 2},
&objects.Int{Value: 3},
objects.TrueValue,
objects.FalseValue,
objects.UndefinedValue,
},
},
"true": objects.TrueValue,
"false": objects.FalseValue,
"bytes": &objects.Bytes{Value: make([]byte, 16)},
"char": &objects.Char{Value: 'Y'},
"error": &objects.Error{Value: &objects.String{Value: "some error"}},
"float": &objects.Float{Value: -19.84},
"immutable_array": &objects.ImmutableArray{
Value: []objects.Object{
&objects.Int{Value: 1},
&objects.Int{Value: 2},
&objects.Int{Value: 3},
objects.TrueValue,
objects.FalseValue,
objects.UndefinedValue,
},
},
"immutable_map": &objects.ImmutableMap{
Value: map[string]objects.Object{
"a": &objects.Int{Value: 1},
"b": &objects.Int{Value: 2},
"c": &objects.Int{Value: 3},
"d": objects.TrueValue,
"e": objects.FalseValue,
"f": objects.UndefinedValue,
},
},
"int": &objects.Int{Value: 91},
"map": &objects.Map{
Value: map[string]objects.Object{
"a": &objects.Int{Value: 1},
"b": &objects.Int{Value: 2},
"c": &objects.Int{Value: 3},
"d": objects.TrueValue,
"e": objects.FalseValue,
"f": objects.UndefinedValue,
},
},
"string": &objects.String{Value: "foo bar"},
"time": &objects.Time{Value: time.Now()},
"undefined": objects.UndefinedValue,
},
},
compiledFunction(1, 0, compiledFunction(1, 0,
compiler.MakeInstruction(compiler.OpConstant, 3), compiler.MakeInstruction(compiler.OpConstant, 3),
compiler.MakeInstruction(compiler.OpSetLocal, 0), compiler.MakeInstruction(compiler.OpSetLocal, 0),
@ -179,10 +234,10 @@ func TestBytecode_CountObjects(t *testing.T) {
b := bytecode( b := bytecode(
concat(), concat(),
objectsArray( objectsArray(
intObject(55), &objects.Int{Value: 55},
intObject(66), &objects.Int{Value: 66},
intObject(77), &objects.Int{Value: 77},
intObject(88), &objects.Int{Value: 88},
compiledFunction(1, 0, compiledFunction(1, 0,
compiler.MakeInstruction(compiler.OpConstant, 3), compiler.MakeInstruction(compiler.OpConstant, 3),
compiler.MakeInstruction(compiler.OpReturnValue)), compiler.MakeInstruction(compiler.OpReturnValue)),

View file

@ -509,6 +509,10 @@ func (c *Compiler) Compile(node ast.Node) error {
c.emit(node, OpCall, len(node.Args)) c.emit(node, OpCall, len(node.Args))
case *ast.ImportExpr: case *ast.ImportExpr:
if node.ModuleName == "" {
return c.errorf(node, "empty module name")
}
if mod, ok := c.importModules[node.ModuleName]; ok { if mod, ok := c.importModules[node.ModuleName]; ok {
v, err := mod.Import(node.ModuleName) v, err := mod.Import(node.ModuleName)
if err != nil { if err != nil {

View file

@ -891,6 +891,8 @@ r["x"] = {
@k:1 @k:1
} }
`, "Parse Error: illegal character U+0040 '@'\n\tat test:3:5 (and 10 more errors)") // too many errors `, "Parse Error: illegal character U+0040 '@'\n\tat test:3:5 (and 10 more errors)") // too many errors
expectError(t, `import("")`, "empty module name")
} }
func concat(instructions ...[]byte) []byte { func concat(instructions ...[]byte) []byte {