find_params.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package compiler
  2. import (
  3. "fmt"
  4. "github.com/sqlc-dev/sqlc/internal/sql/ast"
  5. "github.com/sqlc-dev/sqlc/internal/sql/astutils"
  6. )
  7. func findParameters(root ast.Node) ([]paramRef, []error) {
  8. refs := make([]paramRef, 0)
  9. errors := make([]error, 0)
  10. v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors}
  11. astutils.Walk(v, root)
  12. if len(*v.errs) > 0 {
  13. return refs, *v.errs
  14. } else {
  15. return refs, nil
  16. }
  17. }
  18. type paramRef struct {
  19. parent ast.Node
  20. rv *ast.RangeVar
  21. ref *ast.ParamRef
  22. name string // Named parameter support
  23. }
  24. type paramSearch struct {
  25. parent ast.Node
  26. rangeVar *ast.RangeVar
  27. refs *[]paramRef
  28. seen map[int]struct{}
  29. errs *[]error
  30. // XXX: Gross state hack for limit
  31. limitCount ast.Node
  32. limitOffset ast.Node
  33. }
  // limitCount is a data-less marker node. paramSearch.Visit uses it as the
  // parent of a parameter that supplies a LIMIT count.
  type limitCount struct{}

  // Pos reports the marker's position, which is always 0.
  func (*limitCount) Pos() int { return 0 }
  // limitOffset is a data-less marker node. paramSearch.Visit uses it as the
  // parent of a parameter that supplies an OFFSET.
  type limitOffset struct{}

  // Pos reports the marker's position, which is always 0.
  func (*limitOffset) Pos() int { return 0 }
  44. func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
  45. switch n := node.(type) {
  46. case *ast.A_Expr:
  47. p.parent = node
  48. case *ast.BetweenExpr:
  49. p.parent = node
  50. case *ast.CallStmt:
  51. p.parent = n.FuncCall
  52. case *ast.DeleteStmt:
  53. if n.LimitCount != nil {
  54. p.limitCount = n.LimitCount
  55. }
  56. case *ast.FuncCall:
  57. p.parent = node
  58. case *ast.InsertStmt:
  59. if s, ok := n.SelectStmt.(*ast.SelectStmt); ok {
  60. for i, item := range s.TargetList.Items {
  61. target, ok := item.(*ast.ResTarget)
  62. if !ok {
  63. continue
  64. }
  65. ref, ok := target.Val.(*ast.ParamRef)
  66. if !ok {
  67. continue
  68. }
  69. if len(n.Cols.Items) <= i {
  70. *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
  71. return p
  72. }
  73. *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
  74. p.seen[ref.Location] = struct{}{}
  75. }
  76. for _, item := range s.ValuesLists.Items {
  77. vl, ok := item.(*ast.List)
  78. if !ok {
  79. continue
  80. }
  81. for i, v := range vl.Items {
  82. ref, ok := v.(*ast.ParamRef)
  83. if !ok {
  84. continue
  85. }
  86. if len(n.Cols.Items) <= i {
  87. *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
  88. return p
  89. }
  90. *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
  91. p.seen[ref.Location] = struct{}{}
  92. }
  93. }
  94. }
  95. case *ast.UpdateStmt:
  96. for _, item := range n.TargetList.Items {
  97. target, ok := item.(*ast.ResTarget)
  98. if !ok {
  99. continue
  100. }
  101. ref, ok := target.Val.(*ast.ParamRef)
  102. if !ok {
  103. continue
  104. }
  105. for _, relation := range n.Relations.Items {
  106. rv, ok := relation.(*ast.RangeVar)
  107. if !ok {
  108. continue
  109. }
  110. *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
  111. }
  112. p.seen[ref.Location] = struct{}{}
  113. }
  114. if n.LimitCount != nil {
  115. p.limitCount = n.LimitCount
  116. }
  117. case *ast.RangeVar:
  118. p.rangeVar = n
  119. case *ast.ResTarget:
  120. p.parent = node
  121. case *ast.SelectStmt:
  122. if n.LimitCount != nil {
  123. p.limitCount = n.LimitCount
  124. }
  125. if n.LimitOffset != nil {
  126. p.limitOffset = n.LimitOffset
  127. }
  128. case *ast.TypeCast:
  129. p.parent = node
  130. case *ast.ParamRef:
  131. parent := p.parent
  132. if count, ok := p.limitCount.(*ast.ParamRef); ok {
  133. if n.Number == count.Number {
  134. parent = &limitCount{}
  135. }
  136. }
  137. if offset, ok := p.limitOffset.(*ast.ParamRef); ok {
  138. if n.Number == offset.Number {
  139. parent = &limitOffset{}
  140. }
  141. }
  142. if _, found := p.seen[n.Location]; found {
  143. break
  144. }
  145. // Special, terrible case for *ast.MultiAssignRef
  146. set := true
  147. if res, ok := parent.(*ast.ResTarget); ok {
  148. if multi, ok := res.Val.(*ast.MultiAssignRef); ok {
  149. set = false
  150. if row, ok := multi.Source.(*ast.RowExpr); ok {
  151. for i, arg := range row.Args.Items {
  152. if ref, ok := arg.(*ast.ParamRef); ok {
  153. if multi.Colno == i+1 && ref.Number == n.Number {
  154. set = true
  155. }
  156. }
  157. }
  158. }
  159. }
  160. }
  161. if set {
  162. *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
  163. p.seen[n.Location] = struct{}{}
  164. }
  165. return nil
  166. case *ast.In:
  167. if n.Sel == nil {
  168. p.parent = node
  169. } else {
  170. if sel, ok := n.Sel.(*ast.SelectStmt); ok && sel.FromClause != nil {
  171. from := sel.FromClause
  172. if schema, ok := from.Items[0].(*ast.RangeVar); ok && schema != nil {
  173. p.rangeVar = &ast.RangeVar{
  174. Catalogname: schema.Catalogname,
  175. Schemaname: schema.Schemaname,
  176. Relname: schema.Relname,
  177. }
  178. }
  179. }
  180. }
  181. if _, ok := n.Expr.(*ast.ParamRef); ok {
  182. p.Visit(n.Expr)
  183. }
  184. }
  185. return p
  186. }