expand.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. package compiler
  2. import (
  3. "fmt"
  4. "strings"
  5. "github.com/kyleconroy/sqlc/internal/config"
  6. "github.com/kyleconroy/sqlc/internal/source"
  7. "github.com/kyleconroy/sqlc/internal/sql/ast"
  8. "github.com/kyleconroy/sqlc/internal/sql/astutils"
  9. )
  10. func (c *Compiler) expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, error) {
  11. list := astutils.Search(raw, func(node ast.Node) bool {
  12. switch node.(type) {
  13. case *ast.DeleteStmt:
  14. case *ast.InsertStmt:
  15. case *ast.SelectStmt:
  16. case *ast.UpdateStmt:
  17. default:
  18. return false
  19. }
  20. return true
  21. })
  22. if len(list.Items) == 0 {
  23. return nil, nil
  24. }
  25. var edits []source.Edit
  26. for _, item := range list.Items {
  27. edit, err := c.expandStmt(qc, raw, item)
  28. if err != nil {
  29. return nil, err
  30. }
  31. edits = append(edits, edit...)
  32. }
  33. return edits, nil
  34. }
  35. func (c *Compiler) quoteIdent(ident string) string {
  36. if c.parser.IsReservedKeyword(ident) {
  37. switch c.conf.Engine {
  38. case config.EngineMySQL:
  39. return "`" + ident + "`"
  40. default:
  41. return "\"" + ident + "\""
  42. }
  43. }
  44. return ident
  45. }
  46. func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) {
  47. tables, err := sourceTables(qc, node)
  48. if err != nil {
  49. return nil, err
  50. }
  51. var targets *ast.List
  52. switch n := node.(type) {
  53. case *ast.DeleteStmt:
  54. targets = n.ReturningList
  55. case *ast.InsertStmt:
  56. targets = n.ReturningList
  57. case *ast.SelectStmt:
  58. targets = n.TargetList
  59. case *ast.UpdateStmt:
  60. targets = n.ReturningList
  61. default:
  62. return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
  63. }
  64. var edits []source.Edit
  65. for _, target := range targets.Items {
  66. res, ok := target.(*ast.ResTarget)
  67. if !ok {
  68. continue
  69. }
  70. ref, ok := res.Val.(*ast.ColumnRef)
  71. if !ok {
  72. continue
  73. }
  74. if !hasStarRef(ref) {
  75. continue
  76. }
  77. var parts, cols []string
  78. for _, f := range ref.Fields.Items {
  79. switch field := f.(type) {
  80. case *ast.String:
  81. parts = append(parts, field.Str)
  82. case *ast.A_Star:
  83. parts = append(parts, "*")
  84. default:
  85. return nil, fmt.Errorf("unknown field in ColumnRef: %T", f)
  86. }
  87. }
  88. scope := astutils.Join(ref.Fields, ".")
  89. counts := map[string]int{}
  90. if scope == "" {
  91. for _, t := range tables {
  92. for _, c := range t.Columns {
  93. counts[c.Name] += 1
  94. }
  95. }
  96. }
  97. for _, t := range tables {
  98. if scope != "" && scope != t.Rel.Name {
  99. continue
  100. }
  101. tableName := c.quoteIdent(t.Rel.Name)
  102. scopeName := c.quoteIdent(scope)
  103. for _, column := range t.Columns {
  104. cname := column.Name
  105. if res.Name != nil {
  106. cname = *res.Name
  107. }
  108. cname = c.quoteIdent(cname)
  109. if scope != "" {
  110. cname = scopeName + "." + cname
  111. }
  112. if counts[cname] > 1 {
  113. cname = tableName + "." + cname
  114. }
  115. cols = append(cols, cname)
  116. }
  117. }
  118. var old []string
  119. for _, p := range parts {
  120. old = append(old, c.quoteIdent(p))
  121. }
  122. edits = append(edits, source.Edit{
  123. Location: res.Location - raw.StmtLocation,
  124. Old: strings.Join(old, "."),
  125. New: strings.Join(cols, ", "),
  126. })
  127. }
  128. return edits, nil
  129. }