parameters.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package rewrite
  2. import (
  3. "fmt"
  4. "github.com/kyleconroy/sqlc/internal/source"
  5. "github.com/kyleconroy/sqlc/internal/sql/ast"
  6. "github.com/kyleconroy/sqlc/internal/sql/ast/pg"
  7. "github.com/kyleconroy/sqlc/internal/sql/astutils"
  8. "github.com/kyleconroy/sqlc/internal/sql/named"
  9. )
  10. // Given an AST node, return the string representation of names
  11. func flatten(root ast.Node) (string, bool) {
  12. sw := &stringWalker{}
  13. astutils.Walk(sw, root)
  14. return sw.String, sw.IsConst
  15. }
  16. type stringWalker struct {
  17. String string
  18. IsConst bool
  19. }
  20. func (s *stringWalker) Visit(node ast.Node) astutils.Visitor {
  21. if _, ok := node.(*pg.A_Const); ok {
  22. s.IsConst = true
  23. }
  24. if n, ok := node.(*pg.String); ok {
  25. s.String += n.Str
  26. }
  27. return s
  28. }
  29. func isNamedParamSignCast(node ast.Node) bool {
  30. expr, ok := node.(*pg.A_Expr)
  31. if !ok {
  32. return false
  33. }
  34. _, cast := expr.Rexpr.(*pg.TypeCast)
  35. return astutils.Join(expr.Name, ".") == "@" && cast
  36. }
  37. func NamedParameters(raw *ast.RawStmt) (*ast.RawStmt, map[int]string, []source.Edit) {
  38. foundFunc := astutils.Search(raw, named.IsParamFunc)
  39. foundSign := astutils.Search(raw, named.IsParamSign)
  40. if len(foundFunc.Items)+len(foundSign.Items) == 0 {
  41. return raw, map[int]string{}, nil
  42. }
  43. args := map[string]int{}
  44. argn := 0
  45. var edits []source.Edit
  46. node := astutils.Apply(raw, func(cr *astutils.Cursor) bool {
  47. node := cr.Node()
  48. switch {
  49. case named.IsParamFunc(node):
  50. fun := node.(*ast.FuncCall)
  51. param, isConst := flatten(fun.Args)
  52. if num, ok := args[param]; ok {
  53. cr.Replace(&pg.ParamRef{
  54. Number: num,
  55. Location: fun.Location,
  56. })
  57. } else {
  58. argn += 1
  59. args[param] = argn
  60. cr.Replace(&pg.ParamRef{
  61. Number: argn,
  62. Location: fun.Location,
  63. })
  64. }
  65. // TODO: This code assumes that sqlc.arg(name) is on a single line
  66. var old string
  67. if isConst {
  68. old = fmt.Sprintf("sqlc.arg('%s')", param)
  69. } else {
  70. old = fmt.Sprintf("sqlc.arg(%s)", param)
  71. }
  72. edits = append(edits, source.Edit{
  73. Location: fun.Location - raw.StmtLocation,
  74. Old: old,
  75. New: fmt.Sprintf("$%d", args[param]),
  76. })
  77. return false
  78. case isNamedParamSignCast(node):
  79. expr := node.(*pg.A_Expr)
  80. cast := expr.Rexpr.(*pg.TypeCast)
  81. param, _ := flatten(cast.Arg)
  82. if num, ok := args[param]; ok {
  83. cast.Arg = &pg.ParamRef{
  84. Number: num,
  85. Location: expr.Location,
  86. }
  87. cr.Replace(cast)
  88. } else {
  89. argn += 1
  90. args[param] = argn
  91. cast.Arg = &pg.ParamRef{
  92. Number: argn,
  93. Location: expr.Location,
  94. }
  95. cr.Replace(cast)
  96. }
  97. // TODO: This code assumes that @foo::bool is on a single line
  98. edits = append(edits, source.Edit{
  99. Location: expr.Location - raw.StmtLocation,
  100. Old: fmt.Sprintf("@%s", param),
  101. New: fmt.Sprintf("$%d", args[param]),
  102. })
  103. return false
  104. case named.IsParamSign(node):
  105. expr := node.(*pg.A_Expr)
  106. param, _ := flatten(expr.Rexpr)
  107. if num, ok := args[param]; ok {
  108. cr.Replace(&pg.ParamRef{
  109. Number: num,
  110. Location: expr.Location,
  111. })
  112. } else {
  113. argn += 1
  114. args[param] = argn
  115. cr.Replace(&pg.ParamRef{
  116. Number: argn,
  117. Location: expr.Location,
  118. })
  119. }
  120. // TODO: This code assumes that @foo is on a single line
  121. edits = append(edits, source.Edit{
  122. Location: expr.Location - raw.StmtLocation,
  123. Old: fmt.Sprintf("@%s", param),
  124. New: fmt.Sprintf("$%d", args[param]),
  125. })
  126. return false
  127. default:
  128. return true
  129. }
  130. }, nil)
  131. named := map[int]string{}
  132. for k, v := range args {
  133. named[v] = k
  134. }
  135. return node.(*ast.RawStmt), named, edits
  136. }