123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- package compiler
- import (
- "fmt"
- "github.com/sqlc-dev/sqlc/internal/sql/ast"
- "github.com/sqlc-dev/sqlc/internal/sql/astutils"
- )
- func findParameters(root ast.Node) ([]paramRef, []error) {
- refs := make([]paramRef, 0)
- errors := make([]error, 0)
- v := paramSearch{seen: make(map[int]struct{}), refs: &refs, errs: &errors}
- astutils.Walk(v, root)
- if len(*v.errs) > 0 {
- return refs, *v.errs
- } else {
- return refs, nil
- }
- }
- type paramRef struct {
- parent ast.Node
- rv *ast.RangeVar
- ref *ast.ParamRef
- name string // Named parameter support
- }
- type paramSearch struct {
- parent ast.Node
- rangeVar *ast.RangeVar
- refs *[]paramRef
- seen map[int]struct{}
- errs *[]error
- // XXX: Gross state hack for limit
- limitCount ast.Node
- limitOffset ast.Node
- }
- type limitCount struct {
- }
- func (l *limitCount) Pos() int {
- return 0
- }
- type limitOffset struct {
- }
- func (l *limitOffset) Pos() int {
- return 0
- }
- func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
- switch n := node.(type) {
- case *ast.A_Expr:
- p.parent = node
- case *ast.BetweenExpr:
- p.parent = node
- case *ast.CallStmt:
- p.parent = n.FuncCall
- case *ast.DeleteStmt:
- if n.LimitCount != nil {
- p.limitCount = n.LimitCount
- }
- case *ast.FuncCall:
- p.parent = node
- case *ast.InsertStmt:
- if s, ok := n.SelectStmt.(*ast.SelectStmt); ok {
- for i, item := range s.TargetList.Items {
- target, ok := item.(*ast.ResTarget)
- if !ok {
- continue
- }
- ref, ok := target.Val.(*ast.ParamRef)
- if !ok {
- continue
- }
- if len(n.Cols.Items) <= i {
- *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
- return p
- }
- *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
- p.seen[ref.Location] = struct{}{}
- }
- for _, item := range s.ValuesLists.Items {
- vl, ok := item.(*ast.List)
- if !ok {
- continue
- }
- for i, v := range vl.Items {
- ref, ok := v.(*ast.ParamRef)
- if !ok {
- continue
- }
- if len(n.Cols.Items) <= i {
- *p.errs = append(*p.errs, fmt.Errorf("INSERT has more expressions than target columns"))
- return p
- }
- *p.refs = append(*p.refs, paramRef{parent: n.Cols.Items[i], ref: ref, rv: n.Relation})
- p.seen[ref.Location] = struct{}{}
- }
- }
- }
- case *ast.UpdateStmt:
- for _, item := range n.TargetList.Items {
- target, ok := item.(*ast.ResTarget)
- if !ok {
- continue
- }
- ref, ok := target.Val.(*ast.ParamRef)
- if !ok {
- continue
- }
- for _, relation := range n.Relations.Items {
- rv, ok := relation.(*ast.RangeVar)
- if !ok {
- continue
- }
- *p.refs = append(*p.refs, paramRef{parent: target, ref: ref, rv: rv})
- }
- p.seen[ref.Location] = struct{}{}
- }
- if n.LimitCount != nil {
- p.limitCount = n.LimitCount
- }
- case *ast.RangeVar:
- p.rangeVar = n
- case *ast.ResTarget:
- p.parent = node
- case *ast.SelectStmt:
- if n.LimitCount != nil {
- p.limitCount = n.LimitCount
- }
- if n.LimitOffset != nil {
- p.limitOffset = n.LimitOffset
- }
- case *ast.TypeCast:
- p.parent = node
- case *ast.ParamRef:
- parent := p.parent
- if count, ok := p.limitCount.(*ast.ParamRef); ok {
- if n.Number == count.Number {
- parent = &limitCount{}
- }
- }
- if offset, ok := p.limitOffset.(*ast.ParamRef); ok {
- if n.Number == offset.Number {
- parent = &limitOffset{}
- }
- }
- if _, found := p.seen[n.Location]; found {
- break
- }
- // Special, terrible case for *ast.MultiAssignRef
- set := true
- if res, ok := parent.(*ast.ResTarget); ok {
- if multi, ok := res.Val.(*ast.MultiAssignRef); ok {
- set = false
- if row, ok := multi.Source.(*ast.RowExpr); ok {
- for i, arg := range row.Args.Items {
- if ref, ok := arg.(*ast.ParamRef); ok {
- if multi.Colno == i+1 && ref.Number == n.Number {
- set = true
- }
- }
- }
- }
- }
- }
- if set {
- *p.refs = append(*p.refs, paramRef{parent: parent, ref: n, rv: p.rangeVar})
- p.seen[n.Location] = struct{}{}
- }
- return nil
- case *ast.In:
- if n.Sel == nil {
- p.parent = node
- } else {
- if sel, ok := n.Sel.(*ast.SelectStmt); ok && sel.FromClause != nil {
- from := sel.FromClause
- if schema, ok := from.Items[0].(*ast.RangeVar); ok && schema != nil {
- p.rangeVar = &ast.RangeVar{
- Catalogname: schema.Catalogname,
- Schemaname: schema.Schemaname,
- Relname: schema.Relname,
- }
- }
- }
- }
- if _, ok := n.Expr.(*ast.ParamRef); ok {
- p.Visit(n.Expr)
- }
- }
- return p
- }
|