123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331 |
- package compiler
- import (
- "errors"
- "fmt"
- "github.com/kyleconroy/sqlc/internal/sql/ast"
- "github.com/kyleconroy/sqlc/internal/sql/ast/pg"
- "github.com/kyleconroy/sqlc/internal/sql/astutils"
- "github.com/kyleconroy/sqlc/internal/sql/lang"
- "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
- )
- func hasStarRef(cf *pg.ColumnRef) bool {
- for _, item := range cf.Fields.Items {
- if _, ok := item.(*pg.A_Star); ok {
- return true
- }
- }
- return false
- }
- // Compute the output columns for a statement.
- //
- // Return an error if column references are ambiguous
- // Return an error if column references don't exist
- func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
- tables, err := sourceTables(qc, node)
- if err != nil {
- return nil, err
- }
- var targets *ast.List
- switch n := node.(type) {
- case *pg.DeleteStmt:
- targets = n.ReturningList
- case *pg.InsertStmt:
- targets = n.ReturningList
- case *pg.SelectStmt:
- targets = n.TargetList
- case *pg.TruncateStmt:
- targets = &ast.List{}
- case *pg.UpdateStmt:
- targets = n.ReturningList
- default:
- return nil, fmt.Errorf("outputColumns: unsupported node type: %T", n)
- }
- var cols []*Column
- for _, target := range targets.Items {
- res, ok := target.(*pg.ResTarget)
- if !ok {
- continue
- }
- switch n := res.Val.(type) {
- case *pg.A_Expr:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- switch {
- case lang.IsComparisonOperator(astutils.Join(n.Name, "")):
- // TODO: Generate a name for these operations
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
- case lang.IsMathematicalOperator(astutils.Join(n.Name, "")):
- // TODO: Generate correct numeric type
- cols = append(cols, &Column{Name: name, DataType: "pg_catalog.int4", NotNull: true})
- default:
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *pg.CaseExpr:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- // TODO: The TypeCase code has been copied from below. Instead, we need a recurse function to get the type of a node.
- if tc, ok := n.Defresult.(*pg.TypeCast); ok {
- if tc.TypeName == nil {
- return nil, errors.New("no type name type cast")
- }
- name := ""
- if ref, ok := tc.Arg.(*pg.ColumnRef); ok {
- name = astutils.Join(ref.Fields, "_")
- }
- if res.Name != nil {
- name = *res.Name
- }
- // TODO Validate column names
- col := toColumn(tc.TypeName)
- col.Name = name
- cols = append(cols, col)
- } else {
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *pg.CoalesceExpr:
- for _, arg := range n.Args.Items {
- if ref, ok := arg.(*pg.ColumnRef); ok {
- columns, err := outputColumnRefs(res, tables, ref)
- if err != nil {
- return nil, err
- }
- for _, c := range columns {
- c.NotNull = true
- cols = append(cols, c)
- }
- }
- }
- case *pg.ColumnRef:
- if hasStarRef(n) {
- // TODO: This code is copied in func expand()
- for _, t := range tables {
- scope := astutils.Join(n.Fields, ".")
- if scope != "" && scope != t.Rel.Name {
- continue
- }
- for _, c := range t.Columns {
- cname := c.Name
- if res.Name != nil {
- cname = *res.Name
- }
- cols = append(cols, &Column{
- Name: cname,
- Type: c.Type,
- Scope: scope,
- Table: c.Table,
- DataType: c.DataType,
- NotNull: c.NotNull,
- IsArray: c.IsArray,
- })
- }
- }
- continue
- }
- columns, err := outputColumnRefs(res, tables, n)
- if err != nil {
- return nil, err
- }
- cols = append(cols, columns...)
- case *ast.FuncCall:
- rel := n.Func
- name := rel.Name
- if res.Name != nil {
- name = *res.Name
- }
- fun, err := qc.catalog.GetFuncN(rel, len(n.Args.Items))
- if err == nil {
- cols = append(cols, &Column{Name: name, DataType: dataType(fun.ReturnType), NotNull: true})
- } else {
- cols = append(cols, &Column{Name: name, DataType: "any"})
- }
- case *pg.SubLink:
- name := "exists"
- if res.Name != nil {
- name = *res.Name
- }
- switch n.SubLinkType {
- case pg.EXISTS_SUBLINK:
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
- default:
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *pg.TypeCast:
- if n.TypeName == nil {
- return nil, errors.New("no type name type cast")
- }
- name := ""
- if ref, ok := n.Arg.(*pg.ColumnRef); ok {
- name = astutils.Join(ref.Fields, "_")
- }
- if res.Name != nil {
- name = *res.Name
- }
- // TODO Validate column names
- col := toColumn(n.TypeName)
- col.Name = name
- cols = append(cols, col)
- default:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- }
- return cols, nil
- }
- // Compute the output columns for a statement.
- //
- // Return an error if column references are ambiguous
- // Return an error if column references don't exist
- // Return an error if a table is referenced twice
- // Return an error if an unknown column is referenced
- func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
- var list *ast.List
- switch n := node.(type) {
- case *pg.DeleteStmt:
- list = &ast.List{
- Items: []ast.Node{n.Relation},
- }
- case *pg.InsertStmt:
- list = &ast.List{
- Items: []ast.Node{n.Relation},
- }
- case *pg.SelectStmt:
- list = astutils.Search(n.FromClause, func(node ast.Node) bool {
- switch node.(type) {
- case *pg.RangeVar, *pg.RangeSubselect:
- return true
- default:
- return false
- }
- })
- case *pg.TruncateStmt:
- list = astutils.Search(n.Relations, func(node ast.Node) bool {
- _, ok := node.(*pg.RangeVar)
- return ok
- })
- case *pg.UpdateStmt:
- list = &ast.List{
- Items: append(n.FromClause.Items, n.Relation),
- }
- default:
- return nil, fmt.Errorf("sourceTables: unsupported node type: %T", n)
- }
- var tables []*Table
- for _, item := range list.Items {
- switch n := item.(type) {
- case *pg.RangeSubselect:
- cols, err := outputColumns(qc, n.Subquery)
- if err != nil {
- return nil, err
- }
- tables = append(tables, &Table{
- Rel: &ast.TableName{
- Name: *n.Alias.Aliasname,
- },
- Columns: cols,
- })
- case *pg.RangeVar:
- fqn, err := ParseTableName(n)
- if err != nil {
- return nil, err
- }
- table, cerr := qc.GetTable(fqn)
- if cerr != nil {
- // TODO: Update error location
- // cerr.Location = n.Location
- // return nil, *cerr
- return nil, cerr
- }
- if n.Alias != nil {
- table.Rel = &ast.TableName{
- Catalog: table.Rel.Catalog,
- Schema: table.Rel.Schema,
- Name: *n.Alias.Aliasname,
- }
- }
- tables = append(tables, table)
- default:
- return nil, fmt.Errorf("sourceTable: unsupported list item type: %T", n)
- }
- }
- return tables, nil
- }
- func outputColumnRefs(res *pg.ResTarget, tables []*Table, node *pg.ColumnRef) ([]*Column, error) {
- parts := stringSlice(node.Fields)
- var name, alias string
- switch {
- case len(parts) == 1:
- name = parts[0]
- case len(parts) == 2:
- alias = parts[0]
- name = parts[1]
- default:
- return nil, fmt.Errorf("unknown number of fields: %d", len(parts))
- }
- var cols []*Column
- var found int
- for _, t := range tables {
- if alias != "" && t.Rel.Name != alias {
- continue
- }
- for _, c := range t.Columns {
- if c.Name == name {
- found += 1
- cname := c.Name
- if res.Name != nil {
- cname = *res.Name
- }
- cols = append(cols, &Column{
- Name: cname,
- Type: c.Type,
- Table: c.Table,
- DataType: c.DataType,
- NotNull: c.NotNull,
- IsArray: c.IsArray,
- })
- }
- }
- }
- if found == 0 {
- return nil, &sqlerr.Error{
- Code: "42703",
- Message: fmt.Sprintf("column \"%s\" does not exist", name),
- Location: res.Location,
- }
- }
- if found > 1 {
- return nil, &sqlerr.Error{
- Code: "42703",
- Message: fmt.Sprintf("column reference \"%s\" is ambiguous", name),
- Location: res.Location,
- }
- }
- return cols, nil
- }
|