123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- package rewrite
- import (
- "fmt"
- "github.com/kyleconroy/sqlc/internal/source"
- "github.com/kyleconroy/sqlc/internal/sql/ast"
- "github.com/kyleconroy/sqlc/internal/sql/ast/pg"
- "github.com/kyleconroy/sqlc/internal/sql/astutils"
- "github.com/kyleconroy/sqlc/internal/sql/named"
- )
- // Given an AST node, return the string representation of names
- func flatten(root ast.Node) (string, bool) {
- sw := &stringWalker{}
- astutils.Walk(sw, root)
- return sw.String, sw.IsConst
- }
// stringWalker is the astutils.Visitor used by flatten. It accumulates the
// text of the pg.String nodes it visits and records whether any pg.A_Const
// node was visited. Its zero value is ready to use.
type stringWalker struct {
	// String is the concatenation of every pg.String.Str visited so far.
	String string
	// IsConst becomes true once a pg.A_Const node has been visited.
	IsConst bool
}
- func (s *stringWalker) Visit(node ast.Node) astutils.Visitor {
- if _, ok := node.(*pg.A_Const); ok {
- s.IsConst = true
- }
- if n, ok := node.(*pg.String); ok {
- s.String += n.Str
- }
- return s
- }
- func isNamedParamSignCast(node ast.Node) bool {
- expr, ok := node.(*pg.A_Expr)
- if !ok {
- return false
- }
- _, cast := expr.Rexpr.(*pg.TypeCast)
- return astutils.Join(expr.Name, ".") == "@" && cast
- }
- func NamedParameters(raw *ast.RawStmt) (*ast.RawStmt, map[int]string, []source.Edit) {
- foundFunc := astutils.Search(raw, named.IsParamFunc)
- foundSign := astutils.Search(raw, named.IsParamSign)
- if len(foundFunc.Items)+len(foundSign.Items) == 0 {
- return raw, map[int]string{}, nil
- }
- args := map[string]int{}
- argn := 0
- var edits []source.Edit
- node := astutils.Apply(raw, func(cr *astutils.Cursor) bool {
- node := cr.Node()
- switch {
- case named.IsParamFunc(node):
- fun := node.(*ast.FuncCall)
- param, isConst := flatten(fun.Args)
- if num, ok := args[param]; ok {
- cr.Replace(&pg.ParamRef{
- Number: num,
- Location: fun.Location,
- })
- } else {
- argn += 1
- args[param] = argn
- cr.Replace(&pg.ParamRef{
- Number: argn,
- Location: fun.Location,
- })
- }
- // TODO: This code assumes that sqlc.arg(name) is on a single line
- var old string
- if isConst {
- old = fmt.Sprintf("sqlc.arg('%s')", param)
- } else {
- old = fmt.Sprintf("sqlc.arg(%s)", param)
- }
- edits = append(edits, source.Edit{
- Location: fun.Location - raw.StmtLocation,
- Old: old,
- New: fmt.Sprintf("$%d", args[param]),
- })
- return false
- case isNamedParamSignCast(node):
- expr := node.(*pg.A_Expr)
- cast := expr.Rexpr.(*pg.TypeCast)
- param, _ := flatten(cast.Arg)
- if num, ok := args[param]; ok {
- cast.Arg = &pg.ParamRef{
- Number: num,
- Location: expr.Location,
- }
- cr.Replace(cast)
- } else {
- argn += 1
- args[param] = argn
- cast.Arg = &pg.ParamRef{
- Number: argn,
- Location: expr.Location,
- }
- cr.Replace(cast)
- }
- // TODO: This code assumes that @foo::bool is on a single line
- edits = append(edits, source.Edit{
- Location: expr.Location - raw.StmtLocation,
- Old: fmt.Sprintf("@%s", param),
- New: fmt.Sprintf("$%d", args[param]),
- })
- return false
- case named.IsParamSign(node):
- expr := node.(*pg.A_Expr)
- param, _ := flatten(expr.Rexpr)
- if num, ok := args[param]; ok {
- cr.Replace(&pg.ParamRef{
- Number: num,
- Location: expr.Location,
- })
- } else {
- argn += 1
- args[param] = argn
- cr.Replace(&pg.ParamRef{
- Number: argn,
- Location: expr.Location,
- })
- }
- // TODO: This code assumes that @foo is on a single line
- edits = append(edits, source.Edit{
- Location: expr.Location - raw.StmtLocation,
- Old: fmt.Sprintf("@%s", param),
- New: fmt.Sprintf("$%d", args[param]),
- })
- return false
- default:
- return true
- }
- }, nil)
- named := map[int]string{}
- for k, v := range args {
- named[v] = k
- }
- return node.(*ast.RawStmt), named, edits
- }
|