catalog.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. package catalog
  2. import (
  3. "strings"
  4. "github.com/kyleconroy/sqlc/internal/sql/ast"
  5. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  6. )
  7. func stringSlice(list *ast.List) []string {
  8. items := []string{}
  9. for _, item := range list.Items {
  10. if n, ok := item.(*ast.String); ok {
  11. items = append(items, n.Str)
  12. }
  13. }
  14. return items
  15. }
// Catalog is an in-memory model of a database's schemas. It is populated by
// applying parsed DDL statements through Build and Update.
type Catalog struct {
	Comment string
	// DefaultSchema is the schema used by getFunc, getTable and getType when
	// a name carries no explicit schema qualifier.
	DefaultSchema string
	Name          string
	// Schemas lists every known schema; lookups scan it linearly and match
	// on exact Name.
	Schemas []*Schema
	// NOTE(review): SearchPath is not read in this file; presumably the
	// schema search order for unqualified names — confirm.
	SearchPath []string
	// LoadExtension maps an extension name to the schema it contributes.
	// NOTE(review): not called in this file; presumably used when handling
	// CREATE EXTENSION — confirm.
	LoadExtension func(string) *Schema
	// TODO: un-export
	// Extensions is a set of extension names; New initializes it empty.
	Extensions map[string]struct{}
}
  26. func (c *Catalog) getSchema(name string) (*Schema, error) {
  27. for i := range c.Schemas {
  28. if c.Schemas[i].Name == name {
  29. return c.Schemas[i], nil
  30. }
  31. }
  32. return nil, sqlerr.SchemaNotFound(name)
  33. }
  34. func (c *Catalog) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
  35. ns := rel.Schema
  36. if ns == "" {
  37. ns = c.DefaultSchema
  38. }
  39. s, err := c.getSchema(ns)
  40. if err != nil {
  41. return nil, -1, err
  42. }
  43. return s.getFunc(rel, tns)
  44. }
  45. func (c *Catalog) getTable(name *ast.TableName) (*Schema, *Table, error) {
  46. ns := name.Schema
  47. if ns == "" {
  48. ns = c.DefaultSchema
  49. }
  50. var s *Schema
  51. for i := range c.Schemas {
  52. if c.Schemas[i].Name == ns {
  53. s = c.Schemas[i]
  54. break
  55. }
  56. }
  57. if s == nil {
  58. return nil, nil, sqlerr.SchemaNotFound(ns)
  59. }
  60. t, _, err := s.getTable(name)
  61. if err != nil {
  62. return nil, nil, err
  63. }
  64. return s, t, nil
  65. }
  66. func (c *Catalog) getType(rel *ast.TypeName) (Type, int, error) {
  67. ns := rel.Schema
  68. if ns == "" {
  69. ns = c.DefaultSchema
  70. }
  71. s, err := c.getSchema(ns)
  72. if err != nil {
  73. return nil, -1, err
  74. }
  75. return s.getType(rel)
  76. }
// Schema is a named namespace within a Catalog holding tables, types and
// functions. The getTable, getType and getFunc lookups scan these slices
// linearly and report the index of the element they find.
type Schema struct {
	Name   string
	Tables []*Table
	// Types holds both *Enum and *CompositeType values.
	Types []Type
	// Funcs may hold several overloads sharing a name; getFunc distinguishes
	// them by input argument types.
	Funcs   []*Function
	Comment string
}
  84. func sameType(a, b *ast.TypeName) bool {
  85. if a.Catalog != b.Catalog {
  86. return false
  87. }
  88. // The pg_catalog schema is searched by default, so take that into
  89. // account when comparing schemas
  90. aSchema := a.Schema
  91. bSchema := b.Schema
  92. if aSchema == "pg_catalog" {
  93. aSchema = ""
  94. }
  95. if bSchema == "pg_catalog" {
  96. bSchema = ""
  97. }
  98. if aSchema != bSchema {
  99. return false
  100. }
  101. if a.Name != b.Name {
  102. return false
  103. }
  104. return true
  105. }
  106. func (s *Schema) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
  107. for i := range s.Funcs {
  108. if strings.ToLower(s.Funcs[i].Name) != strings.ToLower(rel.Name) {
  109. continue
  110. }
  111. args := s.Funcs[i].InArgs()
  112. if len(args) != len(tns) {
  113. continue
  114. }
  115. found := true
  116. for j := range args {
  117. if !sameType(s.Funcs[i].Args[j].Type, tns[j]) {
  118. found = false
  119. break
  120. }
  121. }
  122. if !found {
  123. continue
  124. }
  125. return s.Funcs[i], i, nil
  126. }
  127. return nil, -1, sqlerr.RelationNotFound(rel.Name)
  128. }
  129. func (s *Schema) getFuncByName(rel *ast.FuncName) (*Function, int, error) {
  130. idx := -1
  131. name := strings.ToLower(rel.Name)
  132. for i := range s.Funcs {
  133. lowered := strings.ToLower(s.Funcs[i].Name)
  134. if lowered == name && idx >= 0 {
  135. return nil, -1, sqlerr.FunctionNotUnique(rel.Name)
  136. }
  137. if lowered == name {
  138. idx = i
  139. }
  140. }
  141. if idx < 0 {
  142. return nil, -1, sqlerr.RelationNotFound(rel.Name)
  143. }
  144. return s.Funcs[idx], idx, nil
  145. }
  146. func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) {
  147. for i := range s.Tables {
  148. if s.Tables[i].Rel.Name == rel.Name {
  149. return s.Tables[i], i, nil
  150. }
  151. }
  152. return nil, -1, sqlerr.RelationNotFound(rel.Name)
  153. }
  154. func (s *Schema) getType(rel *ast.TypeName) (Type, int, error) {
  155. for i := range s.Types {
  156. switch typ := s.Types[i].(type) {
  157. case *Enum:
  158. if typ.Name == rel.Name {
  159. return s.Types[i], i, nil
  160. }
  161. }
  162. }
  163. return nil, -1, sqlerr.TypeNotFound(rel.Name)
  164. }
// Table is a table within a Schema.
type Table struct {
	// Rel is the table's name; Schema.getTable matches on Rel.Name only.
	Rel     *ast.TableName
	Columns []*Column
	Comment string
}
// Column describes a single column. Tables hold their columns as []*Column,
// and columnGenerator.OutputColumns returns the same type.
//
// TODO: Should this just be ast Nodes?
type Column struct {
	Name string
	// Type is stored by value, unlike Argument.Type and Function.ReturnType,
	// which are pointers.
	Type      ast.TypeName
	IsNotNull bool
	IsArray   bool
	Comment   string
	// Length is optional (nil when absent). NOTE(review): presumably the
	// declared length of types such as varchar(n) — confirm.
	Length *int
}
// Type is a user-defined type stored in Schema.Types. Because isType is
// unexported, only types declared in this package (*Enum, *CompositeType)
// can implement it.
type Type interface {
	isType()
	// SetComment replaces the type's comment.
	SetComment(string)
}
// Enum is an enum type held in a Schema's Types list. Vals holds its labels
// in slice order (presumably declaration order — confirm against the enum
// creation/alteration handlers).
type Enum struct {
	Name    string
	Vals    []string
	Comment string
}

// SetComment replaces the enum's Comment.
func (e *Enum) SetComment(c string) {
	e.Comment = c
}

// isType marks *Enum as an implementation of Type.
func (e *Enum) isType() {
}
// CompositeType is a composite type held in a Schema's Types list. Only its
// name and comment are recorded; its attributes are not modeled here.
type CompositeType struct {
	Name    string
	Comment string
}

// isType marks *CompositeType as an implementation of Type.
func (ct *CompositeType) isType() {
}

// SetComment replaces the composite type's Comment.
func (ct *CompositeType) SetComment(c string) {
	ct.Comment = c
}
// Function is a function (possibly one of several overloads sharing Name)
// within a Schema.
type Function struct {
	Name string
	// Args holds every declared argument, including OUT and TABLE ones;
	// InArgs filters it down to the arguments used for overload matching.
	Args       []*Argument
	ReturnType *ast.TypeName
	Comment    string
	Desc       string
	// NOTE(review): not read in this file; presumably whether the return
	// value may be NULL — confirm.
	ReturnTypeNullable bool
}
  210. func (f *Function) InArgs() []*Argument {
  211. var args []*Argument
  212. for _, a := range f.Args {
  213. switch a.Mode {
  214. case ast.FuncParamTable, ast.FuncParamOut:
  215. continue
  216. default:
  217. args = append(args, a)
  218. }
  219. }
  220. return args
  221. }
// Argument is one declared argument of a Function.
type Argument struct {
	Name string
	// Type is compared with sameType when resolving overloads.
	Type *ast.TypeName
	// NOTE(review): presumably true when the argument declares a DEFAULT
	// value — confirm.
	HasDefault bool
	// Mode decides whether the argument is returned by Function.InArgs
	// (OUT and TABLE arguments are excluded).
	Mode ast.FuncParamMode
}
  228. func New(def string) *Catalog {
  229. return &Catalog{
  230. DefaultSchema: def,
  231. Schemas: []*Schema{
  232. {Name: def},
  233. },
  234. Extensions: map[string]struct{}{},
  235. }
  236. }
  237. func (c *Catalog) Build(stmts []ast.Statement) error {
  238. for i := range stmts {
  239. if err := c.Update(stmts[i], nil); err != nil {
  240. return err
  241. }
  242. }
  243. return nil
  244. }
// columnGenerator computes the output columns of a query node.
//
// It exists to break an import cycle between the catalog and compiler
// packages: createView needs compiler functionality to parse the SELECT
// statement that defines a view, but this package cannot import the
// compiler directly, so the compiler is passed in through this interface.
type columnGenerator interface {
	OutputColumns(node ast.Node) ([]*Column, error)
}
  251. func (c *Catalog) Update(stmt ast.Statement, colGen columnGenerator) error {
  252. if stmt.Raw == nil {
  253. return nil
  254. }
  255. var err error
  256. switch n := stmt.Raw.Stmt.(type) {
  257. case *ast.AlterTableStmt:
  258. err = c.alterTable(n)
  259. case *ast.AlterTableSetSchemaStmt:
  260. err = c.alterTableSetSchema(n)
  261. case *ast.AlterTypeAddValueStmt:
  262. err = c.alterTypeAddValue(n)
  263. case *ast.AlterTypeRenameValueStmt:
  264. err = c.alterTypeRenameValue(n)
  265. case *ast.CommentOnColumnStmt:
  266. err = c.commentOnColumn(n)
  267. case *ast.CommentOnSchemaStmt:
  268. err = c.commentOnSchema(n)
  269. case *ast.CommentOnTableStmt:
  270. err = c.commentOnTable(n)
  271. case *ast.CommentOnTypeStmt:
  272. err = c.commentOnType(n)
  273. case *ast.CompositeTypeStmt:
  274. err = c.createCompositeType(n)
  275. case *ast.CreateEnumStmt:
  276. err = c.createEnum(n)
  277. case *ast.CreateExtensionStmt:
  278. err = c.createExtension(n)
  279. case *ast.CreateFunctionStmt:
  280. err = c.createFunction(n)
  281. case *ast.CreateSchemaStmt:
  282. err = c.createSchema(n)
  283. case *ast.CreateTableStmt:
  284. err = c.createTable(n)
  285. case *ast.ViewStmt:
  286. err = c.createView(n, colGen)
  287. case *ast.DropFunctionStmt:
  288. err = c.dropFunction(n)
  289. case *ast.DropSchemaStmt:
  290. err = c.dropSchema(n)
  291. case *ast.DropTableStmt:
  292. err = c.dropTable(n)
  293. case *ast.DropTypeStmt:
  294. err = c.dropType(n)
  295. case *ast.RenameColumnStmt:
  296. err = c.renameColumn(n)
  297. case *ast.RenameTableStmt:
  298. err = c.renameTable(n)
  299. case *ast.RenameTypeStmt:
  300. err = c.renameType(n)
  301. case *ast.List:
  302. for _, nn := range n.Items {
  303. if err = c.Update(ast.Statement{
  304. Raw: &ast.RawStmt{
  305. Stmt: nn,
  306. StmtLocation: stmt.Raw.StmtLocation,
  307. StmtLen: stmt.Raw.StmtLen,
  308. },
  309. }, colGen); err != nil {
  310. return err
  311. }
  312. }
  313. }
  314. return err
  315. }