find_params.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package compiler
  2. import (
  3. "github.com/kyleconroy/sqlc/internal/sql/ast"
  4. "github.com/kyleconroy/sqlc/internal/sql/ast/pg"
  5. "github.com/kyleconroy/sqlc/internal/sql/astutils"
  6. )
  7. func findParameters(root ast.Node) []paramRef {
  8. refs := make([]paramRef, 0)
  9. v := paramSearch{seen: make(map[int]struct{}), refs: &refs}
  10. astutils.Walk(v, root)
  11. return refs
  12. }
  13. type paramRef struct {
  14. parent ast.Node
  15. rv *pg.RangeVar
  16. ref *pg.ParamRef
  17. name string // Named parameter support
  18. }
  19. type paramSearch struct {
  20. parent ast.Node
  21. rangeVar *pg.RangeVar
  22. refs *[]paramRef
  23. seen map[int]struct{}
  24. // XXX: Gross state hack for limit
  25. limitCount ast.Node
  26. limitOffset ast.Node
  27. }
  28. type limitCount struct {
  29. }
  30. func (l *limitCount) Pos() int {
  31. return 0
  32. }
  33. type limitOffset struct {
  34. }
  35. func (l *limitOffset) Pos() int {
  36. return 0
  37. }
  38. func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
  39. switch n := node.(type) {
  40. case *pg.A_Expr:
  41. p.parent = node
  42. case *ast.FuncCall:
  43. p.parent = node
  44. case *pg.InsertStmt:
  45. if s, ok := n.SelectStmt.(*pg.SelectStmt); ok {
  46. for i, item := range s.TargetList.Items {
  47. target, ok := item.(*pg.ResTarget)
  48. if !ok {
  49. continue
  50. }
  51. ref, ok := target.Val.(*pg.ParamRef)
  52. if !ok {
  53. continue
  54. }
  55. // TODO: Out-of-bounds panic
  56. *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
  57. p.seen[ref.Location] = struct{}{}
  58. }
  59. for _, item := range s.ValuesLists.Items {
  60. vl, ok := item.(*ast.List)
  61. if !ok {
  62. continue
  63. }
  64. for i, v := range vl.Items {
  65. ref, ok := v.(*pg.ParamRef)
  66. if !ok {
  67. continue
  68. }
  69. // TODO: Out-of-bounds panic
  70. *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
  71. p.seen[ref.Location] = struct{}{}
  72. }
  73. }
  74. }
  75. case *pg.RangeVar:
  76. p.rangeVar = n
  77. case *pg.ResTarget:
  78. p.parent = node
  79. case *pg.SelectStmt:
  80. if n.LimitCount != nil {
  81. p.limitCount = n.LimitCount
  82. }
  83. if n.LimitOffset != nil {
  84. p.limitOffset = n.LimitOffset
  85. }
  86. case *pg.TypeCast:
  87. p.parent = node
  88. case *pg.ParamRef:
  89. parent := p.parent
  90. if count, ok := p.limitCount.(*pg.ParamRef); ok {
  91. if n.Number == count.Number {
  92. parent = &limitCount{}
  93. }
  94. }
  95. if offset, ok := p.limitOffset.(*pg.ParamRef); ok {
  96. if n.Number == offset.Number {
  97. parent = &limitOffset{}
  98. }
  99. }
  100. if _, found := p.seen[n.Location]; found {
  101. break
  102. }
  103. // Special, terrible case for *pg.MultiAssignRef
  104. set := true
  105. if res, ok := parent.(*pg.ResTarget); ok {
  106. if multi, ok := res.Val.(*pg.MultiAssignRef); ok {
  107. set = false
  108. if row, ok := multi.Source.(*pg.RowExpr); ok {
  109. for i, arg := range row.Args.Items {
  110. if ref, ok := arg.(*pg.ParamRef); ok {
  111. if multi.Colno == i+1 && ref.Number == n.Number {
  112. set = true
  113. }
  114. }
  115. }
  116. }
  117. }
  118. }
  119. if set {
  120. *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
  121. p.seen[n.Location] = struct{}{}
  122. }
  123. return nil
  124. }
  125. return p
  126. }