123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- package postgresql
- import (
- "strings"
- "testing"
- "github.com/sqlc-dev/sqlc/internal/sql/ast"
- "github.com/sqlc-dev/sqlc/internal/sql/astutils"
- "github.com/google/go-cmp/cmp"
- )
- func TestApply(t *testing.T) {
- p := NewParser()
- input, err := p.Parse(strings.NewReader("SELECT sqlc.arg(name)"))
- if err != nil {
- t.Fatal(err)
- }
- output, err := p.Parse(strings.NewReader("SELECT $1"))
- if err != nil {
- t.Fatal(err)
- }
- expect := &output[0]
- actual := astutils.Apply(&input[0], func(cr *astutils.Cursor) bool {
- fun, ok := cr.Node().(*ast.FuncCall)
- if !ok {
- return true
- }
- if astutils.Join(fun.Funcname, ".") == "sqlc.arg" {
- cr.Replace(&ast.ParamRef{
- Dollar: true,
- Number: 1,
- Location: fun.Location,
- })
- return false
- }
- return true
- }, nil)
- if diff := cmp.Diff(expect, actual); diff != "" {
- t.Errorf("rewrite mismatch:\n%s", diff)
- }
- }
|