123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249 |
- package catalog
- import (
- "errors"
- "fmt"
- "github.com/kyleconroy/sqlc/internal/sql/ast"
- "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
- )
- func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error {
- var implemented bool
- for _, item := range stmt.Cmds.Items {
- switch cmd := item.(type) {
- case *ast.AlterTableCmd:
- switch cmd.Subtype {
- case ast.AT_AddColumn:
- implemented = true
- case ast.AT_AlterColumnType:
- implemented = true
- case ast.AT_DropColumn:
- implemented = true
- case ast.AT_DropNotNull:
- implemented = true
- case ast.AT_SetNotNull:
- implemented = true
- }
- }
- }
- if !implemented {
- return nil
- }
- _, table, err := c.getTable(stmt.Table)
- if err != nil {
- return err
- }
- for _, cmd := range stmt.Cmds.Items {
- switch cmd := cmd.(type) {
- case *ast.AlterTableCmd:
- idx := -1
- // Lookup column names for column-related commands
- switch cmd.Subtype {
- case ast.AT_AlterColumnType,
- ast.AT_DropColumn,
- ast.AT_DropNotNull,
- ast.AT_SetNotNull:
- for i, c := range table.Columns {
- if c.Name == *cmd.Name {
- idx = i
- break
- }
- }
- if idx < 0 && !cmd.MissingOk {
- return sqlerr.ColumnNotFound(table.Rel.Name, *cmd.Name)
- }
- // If a missing column is allowed, skip this command
- if idx < 0 && cmd.MissingOk {
- continue
- }
- }
- switch cmd.Subtype {
- case ast.AT_AddColumn:
- for _, c := range table.Columns {
- if c.Name == cmd.Def.Colname {
- return sqlerr.ColumnExists(table.Rel.Name, c.Name)
- }
- }
- table.Columns = append(table.Columns, &Column{
- Name: cmd.Def.Colname,
- Type: *cmd.Def.TypeName,
- IsNotNull: cmd.Def.IsNotNull,
- IsArray: cmd.Def.IsArray,
- Length: cmd.Def.Length,
- })
- case ast.AT_AlterColumnType:
- table.Columns[idx].Type = *cmd.Def.TypeName
- table.Columns[idx].IsArray = cmd.Def.IsArray
- case ast.AT_DropColumn:
- table.Columns = append(table.Columns[:idx], table.Columns[idx+1:]...)
- case ast.AT_DropNotNull:
- table.Columns[idx].IsNotNull = false
- case ast.AT_SetNotNull:
- table.Columns[idx].IsNotNull = true
- }
- }
- }
- return nil
- }
- func (c *Catalog) alterTableSetSchema(stmt *ast.AlterTableSetSchemaStmt) error {
- ns := stmt.Table.Schema
- if ns == "" {
- ns = c.DefaultSchema
- }
- oldSchema, err := c.getSchema(ns)
- if err != nil {
- return err
- }
- tbl, idx, err := oldSchema.getTable(stmt.Table)
- if err != nil {
- return err
- }
- tbl.Rel.Schema = *stmt.NewSchema
- newSchema, err := c.getSchema(*stmt.NewSchema)
- if err != nil {
- return err
- }
- if _, _, err := newSchema.getTable(stmt.Table); err == nil {
- return sqlerr.RelationExists(stmt.Table.Name)
- }
- oldSchema.Tables = append(oldSchema.Tables[:idx], oldSchema.Tables[idx+1:]...)
- newSchema.Tables = append(newSchema.Tables, tbl)
- return nil
- }
- func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
- ns := stmt.Name.Schema
- if ns == "" {
- ns = c.DefaultSchema
- }
- schema, err := c.getSchema(ns)
- if err != nil {
- return err
- }
- _, _, err = schema.getTable(stmt.Name)
- if err == nil && stmt.IfNotExists {
- return nil
- } else if err == nil {
- return sqlerr.RelationExists(stmt.Name.Name)
- }
- tbl := Table{Rel: stmt.Name, Comment: stmt.Comment}
- for _, inheritTable := range stmt.Inherits {
- t, _, err := schema.getTable(inheritTable)
- if err != nil {
- return err
- }
- tbl.Columns = append(tbl.Columns, t.Columns...)
- }
- if stmt.ReferTable != nil && len(stmt.Cols) != 0 {
- return errors.New("create table node cannot have both a ReferTable and Cols")
- }
- if stmt.ReferTable != nil {
- _, original, err := c.getTable(stmt.ReferTable)
- if err != nil {
- return err
- }
- for _, col := range original.Columns {
- newCol := *col // make a copy, so changes to the ReferTable don't propagate
- tbl.Columns = append(tbl.Columns, &newCol)
- }
- } else {
- for _, col := range stmt.Cols {
- tc := &Column{
- Name: col.Colname,
- Type: *col.TypeName,
- IsNotNull: col.IsNotNull,
- IsArray: col.IsArray,
- Comment: col.Comment,
- Length: col.Length,
- }
- if col.Vals != nil {
- typeName := ast.TypeName{
- Name: fmt.Sprintf("%s_%s", stmt.Name.Name, col.Colname),
- }
- s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals}
- if err := c.createEnum(s); err != nil {
- return err
- }
- tc.Type = typeName
- }
- tbl.Columns = append(tbl.Columns, tc)
- }
- }
- schema.Tables = append(schema.Tables, &tbl)
- return nil
- }
- func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
- for _, name := range stmt.Tables {
- ns := name.Schema
- if ns == "" {
- ns = c.DefaultSchema
- }
- schema, err := c.getSchema(ns)
- if errors.Is(err, sqlerr.NotFound) && stmt.IfExists {
- continue
- } else if err != nil {
- return err
- }
- _, idx, err := schema.getTable(name)
- if errors.Is(err, sqlerr.NotFound) && stmt.IfExists {
- continue
- } else if err != nil {
- return err
- }
- schema.Tables = append(schema.Tables[:idx], schema.Tables[idx+1:]...)
- }
- return nil
- }
- func (c *Catalog) renameColumn(stmt *ast.RenameColumnStmt) error {
- _, tbl, err := c.getTable(stmt.Table)
- if err != nil {
- return err
- }
- idx := -1
- for i := range tbl.Columns {
- if tbl.Columns[i].Name == stmt.Col.Name {
- idx = i
- }
- if tbl.Columns[i].Name == *stmt.NewName {
- return sqlerr.ColumnExists(tbl.Rel.Name, *stmt.NewName)
- }
- }
- if idx == -1 {
- return sqlerr.ColumnNotFound(tbl.Rel.Name, stmt.Col.Name)
- }
- tbl.Columns[idx].Name = *stmt.NewName
- return nil
- }
- func (c *Catalog) renameTable(stmt *ast.RenameTableStmt) error {
- sch, tbl, err := c.getTable(stmt.Table)
- if err != nil {
- return err
- }
- if _, _, err := sch.getTable(&ast.TableName{Name: *stmt.NewName}); err == nil {
- return sqlerr.RelationExists(*stmt.NewName)
- }
- if stmt.NewName != nil {
- tbl.Rel.Name = *stmt.NewName
- }
- return nil
- }
|