output_columns.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. package compiler
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/kyleconroy/sqlc/internal/sql/ast"
  6. "github.com/kyleconroy/sqlc/internal/sql/ast/pg"
  7. "github.com/kyleconroy/sqlc/internal/sql/astutils"
  8. "github.com/kyleconroy/sqlc/internal/sql/lang"
  9. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  10. )
  11. func hasStarRef(cf *pg.ColumnRef) bool {
  12. for _, item := range cf.Fields.Items {
  13. if _, ok := item.(*pg.A_Star); ok {
  14. return true
  15. }
  16. }
  17. return false
  18. }
  19. // Compute the output columns for a statement.
  20. //
  21. // Return an error if column references are ambiguous
  22. // Return an error if column references don't exist
  23. func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
  24. tables, err := sourceTables(qc, node)
  25. if err != nil {
  26. return nil, err
  27. }
  28. var targets *ast.List
  29. switch n := node.(type) {
  30. case *pg.DeleteStmt:
  31. targets = n.ReturningList
  32. case *pg.InsertStmt:
  33. targets = n.ReturningList
  34. case *pg.SelectStmt:
  35. targets = n.TargetList
  36. case *pg.TruncateStmt:
  37. targets = &ast.List{}
  38. case *pg.UpdateStmt:
  39. targets = n.ReturningList
  40. default:
  41. return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
  42. }
  43. var cols []*Column
  44. for _, target := range targets.Items {
  45. res, ok := target.(*pg.ResTarget)
  46. if !ok {
  47. continue
  48. }
  49. switch n := res.Val.(type) {
  50. case *pg.A_Expr:
  51. name := ""
  52. if res.Name != nil {
  53. name = *res.Name
  54. }
  55. switch {
  56. case lang.IsComparisonOperator(astutils.Join(n.Name, "")):
  57. // TODO: Generate a name for these operations
  58. cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
  59. case lang.IsMathematicalOperator(astutils.Join(n.Name, "")):
  60. // TODO: Generate correct numeric type
  61. cols = append(cols, &Column{Name: name, DataType: "pg_catalog.int4", NotNull: true})
  62. default:
  63. cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
  64. }
  65. case *pg.CaseExpr:
  66. name := ""
  67. if res.Name != nil {
  68. name = *res.Name
  69. }
  70. // TODO: The TypeCase code has been copied from below. Instead, we need a recurse function to get the type of a node.
  71. if tc, ok := n.Defresult.(*pg.TypeCast); ok {
  72. if tc.TypeName == nil {
  73. return nil, errors.New("no type name type cast")
  74. }
  75. name := ""
  76. if ref, ok := tc.Arg.(*pg.ColumnRef); ok {
  77. name = astutils.Join(ref.Fields, "_")
  78. }
  79. if res.Name != nil {
  80. name = *res.Name
  81. }
  82. // TODO Validate column names
  83. col := toColumn(tc.TypeName)
  84. col.Name = name
  85. cols = append(cols, col)
  86. } else {
  87. cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
  88. }
  89. case *pg.CoalesceExpr:
  90. for _, arg := range n.Args.Items {
  91. if ref, ok := arg.(*pg.ColumnRef); ok {
  92. columns, err := outputColumnRefs(res, tables, ref)
  93. if err != nil {
  94. return nil, err
  95. }
  96. for _, c := range columns {
  97. c.NotNull = true
  98. cols = append(cols, c)
  99. }
  100. }
  101. }
  102. case *pg.ColumnRef:
  103. if hasStarRef(n) {
  104. // TODO: This code is copied in func expand()
  105. for _, t := range tables {
  106. scope := astutils.Join(n.Fields, ".")
  107. if scope != "" && scope != t.Rel.Name {
  108. continue
  109. }
  110. for _, c := range t.Columns {
  111. cname := c.Name
  112. if res.Name != nil {
  113. cname = *res.Name
  114. }
  115. cols = append(cols, &Column{
  116. Name: cname,
  117. Type: c.Type,
  118. Scope: scope,
  119. Table: c.Table,
  120. DataType: c.DataType,
  121. NotNull: c.NotNull,
  122. IsArray: c.IsArray,
  123. })
  124. }
  125. }
  126. continue
  127. }
  128. columns, err := outputColumnRefs(res, tables, n)
  129. if err != nil {
  130. return nil, err
  131. }
  132. cols = append(cols, columns...)
  133. case *ast.FuncCall:
  134. rel := n.Func
  135. name := rel.Name
  136. if res.Name != nil {
  137. name = *res.Name
  138. }
  139. fun, err := qc.catalog.GetFuncN(rel, len(n.Args.Items))
  140. if err == nil {
  141. cols = append(cols, &Column{Name: name, DataType: dataType(fun.ReturnType), NotNull: true})
  142. } else {
  143. cols = append(cols, &Column{Name: name, DataType: "any"})
  144. }
  145. case *pg.SubLink:
  146. name := "exists"
  147. if res.Name != nil {
  148. name = *res.Name
  149. }
  150. switch n.SubLinkType {
  151. case pg.EXISTS_SUBLINK:
  152. cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
  153. default:
  154. cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
  155. }
  156. case *pg.TypeCast:
  157. if n.TypeName == nil {
  158. return nil, errors.New("no type name type cast")
  159. }
  160. name := ""
  161. if ref, ok := n.Arg.(*pg.ColumnRef); ok {
  162. name = astutils.Join(ref.Fields, "_")
  163. }
  164. if res.Name != nil {
  165. name = *res.Name
  166. }
  167. // TODO Validate column names
  168. col := toColumn(n.TypeName)
  169. col.Name = name
  170. cols = append(cols, col)
  171. default:
  172. name := ""
  173. if res.Name != nil {
  174. name = *res.Name
  175. }
  176. cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
  177. }
  178. }
  179. return cols, nil
  180. }
  181. // Compute the output columns for a statement.
  182. //
  183. // Return an error if column references are ambiguous
  184. // Return an error if column references don't exist
  185. // Return an error if a table is referenced twice
  186. // Return an error if an unknown column is referenced
  187. func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
  188. var list *ast.List
  189. switch n := node.(type) {
  190. case *pg.DeleteStmt:
  191. list = &ast.List{
  192. Items: []ast.Node{n.Relation},
  193. }
  194. case *pg.InsertStmt:
  195. list = &ast.List{
  196. Items: []ast.Node{n.Relation},
  197. }
  198. case *pg.SelectStmt:
  199. list = astutils.Search(n.FromClause, func(node ast.Node) bool {
  200. switch node.(type) {
  201. case *pg.RangeVar, *pg.RangeSubselect:
  202. return true
  203. default:
  204. return false
  205. }
  206. })
  207. case *pg.TruncateStmt:
  208. list = astutils.Search(n.Relations, func(node ast.Node) bool {
  209. _, ok := node.(*pg.RangeVar)
  210. return ok
  211. })
  212. case *pg.UpdateStmt:
  213. list = &ast.List{
  214. Items: append(n.FromClause.Items, n.Relation),
  215. }
  216. default:
  217. return nil, fmt.Errorf("sourceTables: unsupported node type: %T", n)
  218. }
  219. var tables []*Table
  220. for _, item := range list.Items {
  221. switch n := item.(type) {
  222. case *pg.RangeSubselect:
  223. cols, err := outputColumns(qc, n.Subquery)
  224. if err != nil {
  225. return nil, err
  226. }
  227. tables = append(tables, &Table{
  228. Rel: &ast.TableName{
  229. Name: *n.Alias.Aliasname,
  230. },
  231. Columns: cols,
  232. })
  233. case *pg.RangeVar:
  234. fqn, err := ParseTableName(n)
  235. if err != nil {
  236. return nil, err
  237. }
  238. table, cerr := qc.GetTable(fqn)
  239. if cerr != nil {
  240. // TODO: Update error location
  241. // cerr.Location = n.Location
  242. // return nil, *cerr
  243. return nil, cerr
  244. }
  245. if n.Alias != nil {
  246. table.Rel = &ast.TableName{
  247. Catalog: table.Rel.Catalog,
  248. Schema: table.Rel.Schema,
  249. Name: *n.Alias.Aliasname,
  250. }
  251. }
  252. tables = append(tables, table)
  253. default:
  254. return nil, fmt.Errorf("sourceTable: unsupported list item type: %T", n)
  255. }
  256. }
  257. return tables, nil
  258. }
  259. func outputColumnRefs(res *pg.ResTarget, tables []*Table, node *pg.ColumnRef) ([]*Column, error) {
  260. parts := stringSlice(node.Fields)
  261. var name, alias string
  262. switch {
  263. case len(parts) == 1:
  264. name = parts[0]
  265. case len(parts) == 2:
  266. alias = parts[0]
  267. name = parts[1]
  268. default:
  269. return nil, fmt.Errorf("unknown number of fields: %d", len(parts))
  270. }
  271. var cols []*Column
  272. var found int
  273. for _, t := range tables {
  274. if alias != "" && t.Rel.Name != alias {
  275. continue
  276. }
  277. for _, c := range t.Columns {
  278. if c.Name == name {
  279. found += 1
  280. cname := c.Name
  281. if res.Name != nil {
  282. cname = *res.Name
  283. }
  284. cols = append(cols, &Column{
  285. Name: cname,
  286. Type: c.Type,
  287. Table: c.Table,
  288. DataType: c.DataType,
  289. NotNull: c.NotNull,
  290. IsArray: c.IsArray,
  291. })
  292. }
  293. }
  294. }
  295. if found == 0 {
  296. return nil, &sqlerr.Error{
  297. Code: "42703",
  298. Message: fmt.Sprintf("column \"%s\" does not exist", name),
  299. Location: res.Location,
  300. }
  301. }
  302. if found > 1 {
  303. return nil, &sqlerr.Error{
  304. Code: "42703",
  305. Message: fmt.Sprintf("column reference \"%s\" is ambiguous", name),
  306. Location: res.Location,
  307. }
  308. }
  309. return cols, nil
  310. }