catalog.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. package catalog
  2. import (
  3. "github.com/kyleconroy/sqlc/internal/sql/ast"
  4. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  5. )
  6. func stringSlice(list *ast.List) []string {
  7. items := []string{}
  8. for _, item := range list.Items {
  9. if n, ok := item.(*ast.String); ok {
  10. items = append(items, n.Str)
  11. }
  12. }
  13. return items
  14. }
  15. type Catalog struct {
  16. Name string
  17. Schemas []*Schema
  18. Comment string
  19. SearchPath []string
  20. DefaultSchema string
  21. }
  22. func (c *Catalog) getSchema(name string) (*Schema, error) {
  23. for i := range c.Schemas {
  24. if c.Schemas[i].Name == name {
  25. return c.Schemas[i], nil
  26. }
  27. }
  28. return nil, sqlerr.SchemaNotFound(name)
  29. }
  30. func (c *Catalog) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
  31. ns := rel.Schema
  32. if ns == "" {
  33. ns = c.DefaultSchema
  34. }
  35. s, err := c.getSchema(ns)
  36. if err != nil {
  37. return nil, -1, err
  38. }
  39. return s.getFunc(rel, tns)
  40. }
  41. func (c *Catalog) getTable(name *ast.TableName) (*Schema, *Table, error) {
  42. ns := name.Schema
  43. if ns == "" {
  44. ns = c.DefaultSchema
  45. }
  46. var s *Schema
  47. for i := range c.Schemas {
  48. if c.Schemas[i].Name == ns {
  49. s = c.Schemas[i]
  50. break
  51. }
  52. }
  53. if s == nil {
  54. return nil, nil, sqlerr.SchemaNotFound(ns)
  55. }
  56. t, _, err := s.getTable(name)
  57. if err != nil {
  58. return nil, nil, err
  59. }
  60. return s, t, nil
  61. }
  62. func (c *Catalog) getType(rel *ast.TypeName) (Type, int, error) {
  63. ns := rel.Schema
  64. if ns == "" {
  65. ns = c.DefaultSchema
  66. }
  67. s, err := c.getSchema(ns)
  68. if err != nil {
  69. return nil, -1, err
  70. }
  71. return s.getType(rel)
  72. }
  73. type Schema struct {
  74. Name string
  75. Tables []*Table
  76. Types []Type
  77. Funcs []*Function
  78. Comment string
  79. }
  80. func sameType(a, b *ast.TypeName) bool {
  81. if a.Catalog != b.Catalog {
  82. return false
  83. }
  84. if a.Schema != b.Schema {
  85. return false
  86. }
  87. if a.Name != b.Name {
  88. return false
  89. }
  90. return true
  91. }
  92. func (s *Schema) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
  93. for i := range s.Funcs {
  94. if s.Funcs[i].Name != rel.Name {
  95. continue
  96. }
  97. args := s.Funcs[i].InArgs()
  98. if len(args) != len(tns) {
  99. continue
  100. }
  101. found := true
  102. for j := range args {
  103. if !sameType(s.Funcs[i].Args[j].Type, tns[j]) {
  104. found = false
  105. break
  106. }
  107. }
  108. if !found {
  109. continue
  110. }
  111. return s.Funcs[i], i, nil
  112. }
  113. return nil, -1, sqlerr.RelationNotFound(rel.Name)
  114. }
  115. func (s *Schema) getFuncByName(rel *ast.FuncName) (*Function, int, error) {
  116. idx := -1
  117. for i := range s.Funcs {
  118. if s.Funcs[i].Name == rel.Name && idx >= 0 {
  119. return nil, -1, sqlerr.FunctionNotUnique(rel.Name)
  120. }
  121. if s.Funcs[i].Name == rel.Name {
  122. idx = i
  123. }
  124. }
  125. if idx < 0 {
  126. return nil, -1, sqlerr.RelationNotFound(rel.Name)
  127. }
  128. return s.Funcs[idx], idx, nil
  129. }
  130. func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) {
  131. for i := range s.Tables {
  132. if s.Tables[i].Rel.Name == rel.Name {
  133. return s.Tables[i], i, nil
  134. }
  135. }
  136. return nil, -1, sqlerr.RelationNotFound(rel.Name)
  137. }
  138. func (s *Schema) getType(rel *ast.TypeName) (Type, int, error) {
  139. for i := range s.Types {
  140. switch typ := s.Types[i].(type) {
  141. case *Enum:
  142. if typ.Name == rel.Name {
  143. return s.Types[i], i, nil
  144. }
  145. }
  146. }
  147. return nil, -1, sqlerr.TypeNotFound(rel.Name)
  148. }
  149. type Table struct {
  150. Rel *ast.TableName
  151. Columns []*Column
  152. Comment string
  153. }
  154. // TODO: Should this just be ast Nodes?
  155. type Column struct {
  156. Name string
  157. Type ast.TypeName
  158. IsNotNull bool
  159. IsArray bool
  160. Comment string
  161. }
  // Type is implemented by the user-defined types stored in Schema.Types.
  // The unexported isType marker method means only types declared in this
  // package can satisfy it.
  type Type interface {
  	// isType is a marker with no behavior.
  	isType()
  	// SetComment replaces the type's comment text.
  	SetComment(string)
  }
  // Enum is a user-defined enumerated type.
  type Enum struct {
  	Name string
  	// Vals holds the enum's values (presumably in declaration order; the
  	// code that fills it is not in this file — confirm).
  	Vals    []string
  	Comment string
  }

  // isType marks *Enum as a Type; it does nothing.
  func (e *Enum) isType() {}

  // SetComment overwrites the enum's Comment with c.
  func (e *Enum) SetComment(c string) {
  	e.Comment = c
  }
  // CompositeType is a user-defined composite (row) type. Only its name and
  // comment are tracked.
  type CompositeType struct {
  	Name    string
  	Comment string
  }

  // SetComment overwrites the composite type's Comment with c.
  func (ct *CompositeType) SetComment(c string) {
  	ct.Comment = c
  }

  // isType marks *CompositeType as a Type; it does nothing.
  func (ct *CompositeType) isType() {}
  185. type Function struct {
  186. Name string
  187. Args []*Argument
  188. ReturnType *ast.TypeName
  189. Comment string
  190. Desc string
  191. }
  192. func (f *Function) InArgs() []*Argument {
  193. var args []*Argument
  194. for _, a := range f.Args {
  195. switch a.Mode {
  196. case ast.FuncParamTable, ast.FuncParamOut:
  197. continue
  198. default:
  199. args = append(args, a)
  200. }
  201. }
  202. return args
  203. }
  204. type Argument struct {
  205. Name string
  206. Type *ast.TypeName
  207. HasDefault bool
  208. Mode ast.FuncParamMode
  209. }
  210. func New(def string) *Catalog {
  211. return &Catalog{
  212. DefaultSchema: def,
  213. Schemas: []*Schema{
  214. &Schema{Name: def},
  215. },
  216. }
  217. }
  218. func (c *Catalog) Build(stmts []ast.Statement) error {
  219. for i := range stmts {
  220. if err := c.Update(stmts[i]); err != nil {
  221. return err
  222. }
  223. }
  224. return nil
  225. }
  226. func (c *Catalog) Update(stmt ast.Statement) error {
  227. if stmt.Raw == nil {
  228. return nil
  229. }
  230. var err error
  231. switch n := stmt.Raw.Stmt.(type) {
  232. case *ast.AlterTableStmt:
  233. err = c.alterTable(n)
  234. case *ast.AlterTableSetSchemaStmt:
  235. err = c.alterTableSetSchema(n)
  236. case *ast.AlterTypeAddValueStmt:
  237. err = c.alterTypeAddValue(n)
  238. case *ast.AlterTypeRenameValueStmt:
  239. err = c.alterTypeRenameValue(n)
  240. case *ast.CommentOnColumnStmt:
  241. err = c.commentOnColumn(n)
  242. case *ast.CommentOnSchemaStmt:
  243. err = c.commentOnSchema(n)
  244. case *ast.CommentOnTableStmt:
  245. err = c.commentOnTable(n)
  246. case *ast.CommentOnTypeStmt:
  247. err = c.commentOnType(n)
  248. case *ast.CompositeTypeStmt:
  249. err = c.createCompositeType(n)
  250. case *ast.CreateEnumStmt:
  251. err = c.createEnum(n)
  252. case *ast.CreateFunctionStmt:
  253. err = c.createFunction(n)
  254. case *ast.CreateSchemaStmt:
  255. err = c.createSchema(n)
  256. case *ast.CreateTableStmt:
  257. err = c.createTable(n)
  258. case *ast.DropFunctionStmt:
  259. err = c.dropFunction(n)
  260. case *ast.DropSchemaStmt:
  261. err = c.dropSchema(n)
  262. case *ast.DropTableStmt:
  263. err = c.dropTable(n)
  264. case *ast.DropTypeStmt:
  265. err = c.dropType(n)
  266. case *ast.RenameColumnStmt:
  267. err = c.renameColumn(n)
  268. case *ast.RenameTableStmt:
  269. err = c.renameTable(n)
  270. }
  271. return err
  272. }