|
- package compiler
- import (
- "errors"
- "fmt"
- "github.com/sqlc-dev/sqlc/internal/sql/ast"
- "github.com/sqlc-dev/sqlc/internal/sql/astutils"
- "github.com/sqlc-dev/sqlc/internal/sql/catalog"
- "github.com/sqlc-dev/sqlc/internal/sql/lang"
- "github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
- )
- // OutputColumns determines which columns a statement will output
- func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
- qc, err := c.buildQueryCatalog(c.catalog, stmt, nil)
- if err != nil {
- return nil, err
- }
- cols, err := c.outputColumns(qc, stmt)
- if err != nil {
- return nil, err
- }
- catCols := make([]*catalog.Column, 0, len(cols))
- for _, col := range cols {
- catCols = append(catCols, &catalog.Column{
- Name: col.Name,
- Type: ast.TypeName{Name: col.DataType},
- IsNotNull: col.NotNull,
- IsUnsigned: col.Unsigned,
- IsArray: col.IsArray,
- ArrayDims: col.ArrayDims,
- Comment: col.Comment,
- Length: col.Length,
- })
- }
- return catCols, nil
- }
- func hasStarRef(cf *ast.ColumnRef) bool {
- for _, item := range cf.Fields.Items {
- if _, ok := item.(*ast.A_Star); ok {
- return true
- }
- }
- return false
- }
- // Compute the output columns for a statement.
- //
- // Return an error if column references are ambiguous
- // Return an error if column references don't exist
- func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
- tables, err := c.sourceTables(qc, node)
- if err != nil {
- return nil, err
- }
- targets := &ast.List{}
- switch n := node.(type) {
- case *ast.DeleteStmt:
- targets = n.ReturningList
- case *ast.InsertStmt:
- targets = n.ReturningList
- case *ast.SelectStmt:
- targets = n.TargetList
- isUnion := len(targets.Items) == 0 && n.Larg != nil
- if n.GroupClause != nil {
- for _, item := range n.GroupClause.Items {
- if err := findColumnForNode(item, tables, targets); err != nil {
- return nil, err
- }
- }
- }
- validateOrderBy := true
- if c.conf.StrictOrderBy != nil {
- validateOrderBy = *c.conf.StrictOrderBy
- }
- if !isUnion && validateOrderBy {
- if n.SortClause != nil {
- for _, item := range n.SortClause.Items {
- sb, ok := item.(*ast.SortBy)
- if !ok {
- continue
- }
- if err := findColumnForNode(sb.Node, tables, targets); err != nil {
- return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
- }
- }
- }
- if n.WindowClause != nil {
- for _, item := range n.WindowClause.Items {
- sb, ok := item.(*ast.List)
- if !ok {
- continue
- }
- for _, single := range sb.Items {
- caseExpr, ok := single.(*ast.CaseExpr)
- if !ok {
- continue
- }
- if err := findColumnForNode(caseExpr.Xpr, tables, targets); err != nil {
- return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
- }
- }
- }
- }
- }
- // For UNION queries, targets is empty and we need to look for the
- // columns in Largs.
- if isUnion {
- return c.outputColumns(qc, n.Larg)
- }
- case *ast.UpdateStmt:
- targets = n.ReturningList
- }
- var cols []*Column
- for _, target := range targets.Items {
- res, ok := target.(*ast.ResTarget)
- if !ok {
- continue
- }
- switch n := res.Val.(type) {
- case *ast.A_Const:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- switch n.Val.(type) {
- case *ast.String:
- cols = append(cols, &Column{Name: name, DataType: "text", NotNull: true})
- case *ast.Integer:
- cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
- case *ast.Float:
- cols = append(cols, &Column{Name: name, DataType: "float", NotNull: true})
- case *ast.Boolean:
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
- default:
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *ast.A_Expr:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- switch op := astutils.Join(n.Name, ""); {
- case lang.IsComparisonOperator(op):
- // TODO: Generate a name for these operations
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
- case lang.IsMathematicalOperator(op):
- cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
- default:
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *ast.BoolExpr:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- notNull := false
- if len(n.Args.Items) == 1 {
- switch n.Boolop {
- case ast.BoolExprTypeIsNull, ast.BoolExprTypeIsNotNull:
- notNull = true
- case ast.BoolExprTypeNot:
- sublink, ok := n.Args.Items[0].(*ast.SubLink)
- if ok && sublink.SubLinkType == ast.EXISTS_SUBLINK {
- notNull = true
- if name == "" {
- name = "not_exists"
- }
- }
- }
- }
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: notNull})
- case *ast.CaseExpr:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- // TODO: The TypeCase and A_Const code has been copied from below. Instead, we
- // need a recurse function to get the type of a node.
- if tc, ok := n.Defresult.(*ast.TypeCast); ok {
- if tc.TypeName == nil {
- return nil, errors.New("no type name type cast")
- }
- name := ""
- if ref, ok := tc.Arg.(*ast.ColumnRef); ok {
- name = astutils.Join(ref.Fields, "_")
- }
- if res.Name != nil {
- name = *res.Name
- }
- // TODO Validate column names
- col := toColumn(tc.TypeName)
- col.Name = name
- cols = append(cols, col)
- } else if aconst, ok := n.Defresult.(*ast.A_Const); ok {
- switch aconst.Val.(type) {
- case *ast.String:
- cols = append(cols, &Column{Name: name, DataType: "text", NotNull: true})
- case *ast.Integer:
- cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
- case *ast.Float:
- cols = append(cols, &Column{Name: name, DataType: "float", NotNull: true})
- case *ast.Boolean:
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
- default:
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- } else {
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *ast.CoalesceExpr:
- name := "coalesce"
- if res.Name != nil {
- name = *res.Name
- }
- var firstColumn *Column
- var shouldNotBeNull bool
- for _, arg := range n.Args.Items {
- if _, ok := arg.(*ast.A_Const); ok {
- shouldNotBeNull = true
- continue
- }
- if ref, ok := arg.(*ast.ColumnRef); ok {
- columns, err := outputColumnRefs(res, tables, ref)
- if err != nil {
- return nil, err
- }
- for _, c := range columns {
- if firstColumn == nil {
- firstColumn = c
- }
- shouldNotBeNull = shouldNotBeNull || c.NotNull
- }
- }
- }
- if firstColumn != nil {
- firstColumn.NotNull = shouldNotBeNull
- firstColumn.skipTableRequiredCheck = true
- cols = append(cols, firstColumn)
- } else {
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *ast.ColumnRef:
- if hasStarRef(n) {
- // add a column with a reference to an embedded table
- if embed, ok := qc.embeds.Find(n); ok {
- cols = append(cols, &Column{
- Name: embed.Table.Name,
- EmbedTable: embed.Table,
- })
- continue
- }
- // TODO: This code is copied in func expand()
- for _, t := range tables {
- scope := astutils.Join(n.Fields, ".")
- if scope != "" && scope != t.Rel.Name {
- continue
- }
- for _, c := range t.Columns {
- cname := c.Name
- if res.Name != nil {
- cname = *res.Name
- }
- cols = append(cols, &Column{
- Name: cname,
- OriginalName: c.Name,
- Type: c.Type,
- Scope: scope,
- Table: c.Table,
- TableAlias: t.Rel.Name,
- DataType: c.DataType,
- NotNull: c.NotNull,
- Unsigned: c.Unsigned,
- IsArray: c.IsArray,
- ArrayDims: c.ArrayDims,
- Length: c.Length,
- })
- }
- }
- continue
- }
- columns, err := outputColumnRefs(res, tables, n)
- if err != nil {
- return nil, err
- }
- cols = append(cols, columns...)
- case *ast.FuncCall:
- rel := n.Func
- name := rel.Name
- if res.Name != nil {
- name = *res.Name
- }
- fun, err := qc.catalog.ResolveFuncCall(n)
- if err == nil {
- cols = append(cols, &Column{
- Name: name,
- DataType: dataType(fun.ReturnType),
- NotNull: !fun.ReturnTypeNullable,
- IsFuncCall: true,
- })
- } else {
- cols = append(cols, &Column{
- Name: name,
- DataType: "any",
- IsFuncCall: true,
- })
- }
- case *ast.SubLink:
- name := "exists"
- if res.Name != nil {
- name = *res.Name
- }
- switch n.SubLinkType {
- case ast.EXISTS_SUBLINK:
- cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
- case ast.EXPR_SUBLINK:
- subcols, err := c.outputColumns(qc, n.Subselect)
- if err != nil {
- return nil, err
- }
- first := subcols[0]
- if res.Name != nil {
- first.Name = *res.Name
- }
- cols = append(cols, first)
- default:
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- case *ast.TypeCast:
- if n.TypeName == nil {
- return nil, errors.New("no type name type cast")
- }
- name := ""
- if ref, ok := n.Arg.(*ast.ColumnRef); ok {
- name = astutils.Join(ref.Fields, "_")
- }
- if res.Name != nil {
- name = *res.Name
- }
- // TODO Validate column names
- col := toColumn(n.TypeName)
- col.Name = name
- // TODO Add correct, real type inference
- if constant, ok := n.Arg.(*ast.A_Const); ok {
- if _, ok := constant.Val.(*ast.Null); ok {
- col.NotNull = false
- }
- }
- cols = append(cols, col)
- case *ast.SelectStmt:
- subcols, err := c.outputColumns(qc, n)
- if err != nil {
- return nil, err
- }
- first := subcols[0]
- if res.Name != nil {
- first.Name = *res.Name
- }
- cols = append(cols, first)
- default:
- name := ""
- if res.Name != nil {
- name = *res.Name
- }
- cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
- }
- }
- if n, ok := node.(*ast.SelectStmt); ok {
- for _, col := range cols {
- if !col.NotNull || col.Table == nil || col.skipTableRequiredCheck {
- continue
- }
- for _, f := range n.FromClause.Items {
- res := isTableRequired(f, col, tableRequired)
- if res != tableNotFound {
- col.NotNull = res == tableRequired
- break
- }
- }
- }
- }
- return cols, nil
- }
// Results of isTableRequired: whether a column's source table appears in a
// FROM item, and if so whether it is guaranteed to contribute a row
// (required) or may be null-extended by an outer join (optional).
const (
	tableNotFound = 0
	tableRequired = 1
	tableOptional = 2
)
- func isTableRequired(n ast.Node, col *Column, prior int) int {
- switch n := n.(type) {
- case *ast.RangeVar:
- tableMatch := *n.Relname == col.Table.Name
- aliasMatch := true
- if n.Alias != nil && col.TableAlias != "" {
- aliasMatch = *n.Alias.Aliasname == col.TableAlias
- }
- if aliasMatch && tableMatch {
- return prior
- }
- case *ast.JoinExpr:
- helper := func(l, r int) int {
- if res := isTableRequired(n.Larg, col, l); res != tableNotFound {
- return res
- }
- if res := isTableRequired(n.Rarg, col, r); res != tableNotFound {
- return res
- }
- return tableNotFound
- }
- switch n.Jointype {
- case ast.JoinTypeLeft:
- return helper(tableRequired, tableOptional)
- case ast.JoinTypeRight:
- return helper(tableOptional, tableRequired)
- case ast.JoinTypeFull:
- return helper(tableOptional, tableOptional)
- case ast.JoinTypeInner:
- return helper(tableRequired, tableRequired)
- }
- case *ast.List:
- for _, item := range n.Items {
- if res := isTableRequired(item, col, prior); res != tableNotFound {
- return res
- }
- }
- }
- return tableNotFound
- }
- type tableVisitor struct {
- list ast.List
- }
- func (r *tableVisitor) Visit(n ast.Node) astutils.Visitor {
- switch n.(type) {
- case *ast.RangeVar, *ast.RangeFunction:
- r.list.Items = append(r.list.Items, n)
- return r
- case *ast.RangeSubselect:
- r.list.Items = append(r.list.Items, n)
- return nil
- default:
- return r
- }
- }
- // Compute the output columns for a statement.
- //
- // Return an error if column references are ambiguous
- // Return an error if column references don't exist
- // Return an error if a table is referenced twice
- // Return an error if an unknown column is referenced
- func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
- list := &ast.List{}
- switch n := node.(type) {
- case *ast.DeleteStmt:
- list = n.Relations
- case *ast.InsertStmt:
- list = &ast.List{
- Items: []ast.Node{n.Relation},
- }
- case *ast.SelectStmt:
- var tv tableVisitor
- astutils.Walk(&tv, n.FromClause)
- list = &tv.list
- case *ast.TruncateStmt:
- list = astutils.Search(n.Relations, func(node ast.Node) bool {
- _, ok := node.(*ast.RangeVar)
- return ok
- })
- case *ast.RefreshMatViewStmt:
- list = astutils.Search(n.Relation, func(node ast.Node) bool {
- _, ok := node.(*ast.RangeVar)
- return ok
- })
- case *ast.UpdateStmt:
- var tv tableVisitor
- astutils.Walk(&tv, n.FromClause)
- astutils.Walk(&tv, n.Relations)
- list = &tv.list
- }
- var tables []*Table
- for _, item := range list.Items {
- item := item
- switch n := item.(type) {
- case *ast.RangeFunction:
- var funcCall *ast.FuncCall
- switch f := n.Functions.Items[0].(type) {
- case *ast.List:
- switch fi := f.Items[0].(type) {
- case *ast.FuncCall:
- funcCall = fi
- case *ast.SQLValueFunction:
- continue // TODO handle this correctly
- default:
- continue
- }
- case *ast.FuncCall:
- funcCall = f
- default:
- return nil, fmt.Errorf("sourceTables: unsupported function call type %T", n.Functions.Items[0])
- }
- // If the function or table can't be found, don't error out. There
- // are many queries that depend on functions unknown to sqlc.
- fn, err := qc.GetFunc(funcCall.Func)
- if err != nil {
- continue
- }
- var table *Table
- if fn.ReturnType != nil {
- table, err = qc.GetTable(&ast.TableName{
- Catalog: fn.ReturnType.Catalog,
- Schema: fn.ReturnType.Schema,
- Name: fn.ReturnType.Name,
- })
- }
- if table == nil || err != nil {
- if n.Alias != nil && len(n.Alias.Colnames.Items) > 0 {
- table = &Table{}
- for _, colName := range n.Alias.Colnames.Items {
- table.Columns = append(table.Columns, &Column{
- Name: colName.(*ast.String).Str,
- DataType: "any",
- })
- }
- } else {
- colName := fn.Rel.Name
- if n.Alias != nil {
- colName = *n.Alias.Aliasname
- }
- table = &Table{
- Rel: &ast.TableName{
- Catalog: fn.Rel.Catalog,
- Schema: fn.Rel.Schema,
- Name: fn.Rel.Name,
- },
- }
- if len(fn.Outs) > 0 {
- for _, arg := range fn.Outs {
- table.Columns = append(table.Columns, &Column{
- Name: arg.Name,
- DataType: arg.Type.Name,
- })
- }
- }
- if fn.ReturnType != nil {
- table.Columns = []*Column{
- {
- Name: colName,
- DataType: fn.ReturnType.Name,
- },
- }
- }
- }
- }
- if n.Alias != nil {
- table.Rel = &ast.TableName{
- Name: *n.Alias.Aliasname,
- }
- }
- tables = append(tables, table)
- case *ast.RangeSubselect:
- cols, err := c.outputColumns(qc, n.Subquery)
- if err != nil {
- return nil, err
- }
- tables = append(tables, &Table{
- Rel: &ast.TableName{
- Name: *n.Alias.Aliasname,
- },
- Columns: cols,
- })
- case *ast.RangeVar:
- fqn, err := ParseTableName(n)
- if err != nil {
- return nil, err
- }
- if qc == nil {
- return nil, fmt.Errorf("query catalog is empty")
- }
- table, cerr := qc.GetTable(fqn)
- if cerr != nil {
- // TODO: Update error location
- // cerr.Location = n.Location
- // return nil, *cerr
- return nil, cerr
- }
- if n.Alias != nil {
- table.Rel = &ast.TableName{
- Catalog: table.Rel.Catalog,
- Schema: table.Rel.Schema,
- Name: *n.Alias.Aliasname,
- }
- }
- tables = append(tables, table)
- default:
- return nil, fmt.Errorf("sourceTable: unsupported list item type: %T", n)
- }
- }
- return tables, nil
- }
- func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef) ([]*Column, error) {
- parts := stringSlice(node.Fields)
- var schema, name, alias string
- switch {
- case len(parts) == 1:
- name = parts[0]
- case len(parts) == 2:
- alias = parts[0]
- name = parts[1]
- case len(parts) == 3:
- schema = parts[0]
- alias = parts[1]
- name = parts[2]
- default:
- return nil, fmt.Errorf("unknown number of fields: %d", len(parts))
- }
- var cols []*Column
- var found int
- for _, t := range tables {
- if schema != "" && t.Rel.Schema != schema {
- continue
- }
- if alias != "" && t.Rel.Name != alias {
- continue
- }
- for _, c := range t.Columns {
- if c.Name == name {
- found += 1
- cname := c.Name
- if res.Name != nil {
- cname = *res.Name
- }
- cols = append(cols, &Column{
- Name: cname,
- Type: c.Type,
- Table: c.Table,
- TableAlias: alias,
- DataType: c.DataType,
- NotNull: c.NotNull,
- Unsigned: c.Unsigned,
- IsArray: c.IsArray,
- ArrayDims: c.ArrayDims,
- Length: c.Length,
- EmbedTable: c.EmbedTable,
- OriginalName: c.Name,
- })
- }
- }
- }
- if found == 0 {
- return nil, &sqlerr.Error{
- Code: "42703",
- Message: fmt.Sprintf("column %q does not exist", name),
- Location: res.Location,
- }
- }
- if found > 1 {
- return nil, &sqlerr.Error{
- Code: "42703",
- Message: fmt.Sprintf("column reference %q is ambiguous", name),
- Location: res.Location,
- }
- }
- return cols, nil
- }
- func findColumnForNode(item ast.Node, tables []*Table, targetList *ast.List) error {
- ref, ok := item.(*ast.ColumnRef)
- if !ok {
- return nil
- }
- return findColumnForRef(ref, tables, targetList)
- }
- func findColumnForRef(ref *ast.ColumnRef, tables []*Table, targetList *ast.List) error {
- parts := stringSlice(ref.Fields)
- var alias, name string
- if len(parts) == 1 {
- name = parts[0]
- } else if len(parts) == 2 {
- alias = parts[0]
- name = parts[1]
- }
- var found int
- for _, t := range tables {
- if alias != "" && t.Rel.Name != alias {
- continue
- }
- // Find matching column
- for _, c := range t.Columns {
- if c.Name == name {
- found++
- break
- }
- }
- }
- // Find matching alias if necessary
- if found == 0 {
- for _, c := range targetList.Items {
- resTarget, ok := c.(*ast.ResTarget)
- if !ok {
- continue
- }
- if resTarget.Name != nil && *resTarget.Name == name {
- found++
- }
- }
- }
- if found == 0 {
- return &sqlerr.Error{
- Code: "42703",
- Message: fmt.Sprintf("column reference %q not found", name),
- Location: ref.Location,
- }
- }
- if found > 1 {
- return &sqlerr.Error{
- Code: "42703",
- Message: fmt.Sprintf("column reference %q is ambiguous", name),
- Location: ref.Location,
- }
- }
- return nil
- }
|