1
0

rewrite_test.go 888 B

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. package postgresql
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/kyleconroy/sqlc/internal/sql/ast"
  6. "github.com/kyleconroy/sqlc/internal/sql/astutils"
  7. "github.com/google/go-cmp/cmp"
  8. )
  9. func TestApply(t *testing.T) {
  10. p := NewParser()
  11. input, err := p.Parse(strings.NewReader("SELECT sqlc.arg(name)"))
  12. if err != nil {
  13. t.Fatal(err)
  14. }
  15. output, err := p.Parse(strings.NewReader("SELECT $1"))
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. expect := &output[0]
  20. actual := astutils.Apply(&input[0], func(cr *astutils.Cursor) bool {
  21. fun, ok := cr.Node().(*ast.FuncCall)
  22. if !ok {
  23. return true
  24. }
  25. if astutils.Join(fun.Funcname, ".") == "sqlc.arg" {
  26. cr.Replace(&ast.ParamRef{
  27. Dollar: true,
  28. Number: 1,
  29. Location: fun.Location,
  30. })
  31. return false
  32. }
  33. return true
  34. }, nil)
  35. if diff := cmp.Diff(expect, actual); diff != "" {
  36. t.Errorf("rewrite mismatch:\n%s", diff)
  37. }
  38. }