table.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. package catalog
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/kyleconroy/sqlc/internal/sql/ast"
  6. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  7. )
  8. func (c *Catalog) alterTable(stmt *ast.AlterTableStmt) error {
  9. var implemented bool
  10. for _, item := range stmt.Cmds.Items {
  11. switch cmd := item.(type) {
  12. case *ast.AlterTableCmd:
  13. switch cmd.Subtype {
  14. case ast.AT_AddColumn:
  15. implemented = true
  16. case ast.AT_AlterColumnType:
  17. implemented = true
  18. case ast.AT_DropColumn:
  19. implemented = true
  20. case ast.AT_DropNotNull:
  21. implemented = true
  22. case ast.AT_SetNotNull:
  23. implemented = true
  24. }
  25. }
  26. }
  27. if !implemented {
  28. return nil
  29. }
  30. _, table, err := c.getTable(stmt.Table)
  31. if err != nil {
  32. return err
  33. }
  34. for _, cmd := range stmt.Cmds.Items {
  35. switch cmd := cmd.(type) {
  36. case *ast.AlterTableCmd:
  37. idx := -1
  38. // Lookup column names for column-related commands
  39. switch cmd.Subtype {
  40. case ast.AT_AlterColumnType,
  41. ast.AT_DropColumn,
  42. ast.AT_DropNotNull,
  43. ast.AT_SetNotNull:
  44. for i, c := range table.Columns {
  45. if c.Name == *cmd.Name {
  46. idx = i
  47. break
  48. }
  49. }
  50. if idx < 0 && !cmd.MissingOk {
  51. return sqlerr.ColumnNotFound(table.Rel.Name, *cmd.Name)
  52. }
  53. // If a missing column is allowed, skip this command
  54. if idx < 0 && cmd.MissingOk {
  55. continue
  56. }
  57. }
  58. switch cmd.Subtype {
  59. case ast.AT_AddColumn:
  60. for _, c := range table.Columns {
  61. if c.Name == cmd.Def.Colname {
  62. return sqlerr.ColumnExists(table.Rel.Name, c.Name)
  63. }
  64. }
  65. table.Columns = append(table.Columns, &Column{
  66. Name: cmd.Def.Colname,
  67. Type: *cmd.Def.TypeName,
  68. IsNotNull: cmd.Def.IsNotNull,
  69. IsArray: cmd.Def.IsArray,
  70. Length: cmd.Def.Length,
  71. })
  72. case ast.AT_AlterColumnType:
  73. table.Columns[idx].Type = *cmd.Def.TypeName
  74. table.Columns[idx].IsArray = cmd.Def.IsArray
  75. case ast.AT_DropColumn:
  76. table.Columns = append(table.Columns[:idx], table.Columns[idx+1:]...)
  77. case ast.AT_DropNotNull:
  78. table.Columns[idx].IsNotNull = false
  79. case ast.AT_SetNotNull:
  80. table.Columns[idx].IsNotNull = true
  81. }
  82. }
  83. }
  84. return nil
  85. }
  86. func (c *Catalog) alterTableSetSchema(stmt *ast.AlterTableSetSchemaStmt) error {
  87. ns := stmt.Table.Schema
  88. if ns == "" {
  89. ns = c.DefaultSchema
  90. }
  91. oldSchema, err := c.getSchema(ns)
  92. if err != nil {
  93. return err
  94. }
  95. tbl, idx, err := oldSchema.getTable(stmt.Table)
  96. if err != nil {
  97. return err
  98. }
  99. tbl.Rel.Schema = *stmt.NewSchema
  100. newSchema, err := c.getSchema(*stmt.NewSchema)
  101. if err != nil {
  102. return err
  103. }
  104. if _, _, err := newSchema.getTable(stmt.Table); err == nil {
  105. return sqlerr.RelationExists(stmt.Table.Name)
  106. }
  107. oldSchema.Tables = append(oldSchema.Tables[:idx], oldSchema.Tables[idx+1:]...)
  108. newSchema.Tables = append(newSchema.Tables, tbl)
  109. return nil
  110. }
  111. func (c *Catalog) createTable(stmt *ast.CreateTableStmt) error {
  112. ns := stmt.Name.Schema
  113. if ns == "" {
  114. ns = c.DefaultSchema
  115. }
  116. schema, err := c.getSchema(ns)
  117. if err != nil {
  118. return err
  119. }
  120. _, _, err = schema.getTable(stmt.Name)
  121. if err == nil && stmt.IfNotExists {
  122. return nil
  123. } else if err == nil {
  124. return sqlerr.RelationExists(stmt.Name.Name)
  125. }
  126. tbl := Table{Rel: stmt.Name, Comment: stmt.Comment}
  127. for _, inheritTable := range stmt.Inherits {
  128. t, _, err := schema.getTable(inheritTable)
  129. if err != nil {
  130. return err
  131. }
  132. tbl.Columns = append(tbl.Columns, t.Columns...)
  133. }
  134. if stmt.ReferTable != nil && len(stmt.Cols) != 0 {
  135. return errors.New("create table node cannot have both a ReferTable and Cols")
  136. }
  137. if stmt.ReferTable != nil {
  138. _, original, err := c.getTable(stmt.ReferTable)
  139. if err != nil {
  140. return err
  141. }
  142. for _, col := range original.Columns {
  143. newCol := *col // make a copy, so changes to the ReferTable don't propagate
  144. tbl.Columns = append(tbl.Columns, &newCol)
  145. }
  146. } else {
  147. for _, col := range stmt.Cols {
  148. tc := &Column{
  149. Name: col.Colname,
  150. Type: *col.TypeName,
  151. IsNotNull: col.IsNotNull,
  152. IsArray: col.IsArray,
  153. Comment: col.Comment,
  154. Length: col.Length,
  155. }
  156. if col.Vals != nil {
  157. typeName := ast.TypeName{
  158. Name: fmt.Sprintf("%s_%s", stmt.Name.Name, col.Colname),
  159. }
  160. s := &ast.CreateEnumStmt{TypeName: &typeName, Vals: col.Vals}
  161. if err := c.createEnum(s); err != nil {
  162. return err
  163. }
  164. tc.Type = typeName
  165. }
  166. tbl.Columns = append(tbl.Columns, tc)
  167. }
  168. }
  169. schema.Tables = append(schema.Tables, &tbl)
  170. return nil
  171. }
  172. func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
  173. for _, name := range stmt.Tables {
  174. ns := name.Schema
  175. if ns == "" {
  176. ns = c.DefaultSchema
  177. }
  178. schema, err := c.getSchema(ns)
  179. if errors.Is(err, sqlerr.NotFound) && stmt.IfExists {
  180. continue
  181. } else if err != nil {
  182. return err
  183. }
  184. _, idx, err := schema.getTable(name)
  185. if errors.Is(err, sqlerr.NotFound) && stmt.IfExists {
  186. continue
  187. } else if err != nil {
  188. return err
  189. }
  190. schema.Tables = append(schema.Tables[:idx], schema.Tables[idx+1:]...)
  191. }
  192. return nil
  193. }
  194. func (c *Catalog) renameColumn(stmt *ast.RenameColumnStmt) error {
  195. _, tbl, err := c.getTable(stmt.Table)
  196. if err != nil {
  197. return err
  198. }
  199. idx := -1
  200. for i := range tbl.Columns {
  201. if tbl.Columns[i].Name == stmt.Col.Name {
  202. idx = i
  203. }
  204. if tbl.Columns[i].Name == *stmt.NewName {
  205. return sqlerr.ColumnExists(tbl.Rel.Name, *stmt.NewName)
  206. }
  207. }
  208. if idx == -1 {
  209. return sqlerr.ColumnNotFound(tbl.Rel.Name, stmt.Col.Name)
  210. }
  211. tbl.Columns[idx].Name = *stmt.NewName
  212. return nil
  213. }
  214. func (c *Catalog) renameTable(stmt *ast.RenameTableStmt) error {
  215. sch, tbl, err := c.getTable(stmt.Table)
  216. if err != nil {
  217. return err
  218. }
  219. if _, _, err := sch.getTable(&ast.TableName{Name: *stmt.NewName}); err == nil {
  220. return sqlerr.RelationExists(*stmt.NewName)
  221. }
  222. if stmt.NewName != nil {
  223. tbl.Rel.Name = *stmt.NewName
  224. }
  225. return nil
  226. }