123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372 |
- package catalog
- import (
- "strings"
- "github.com/kyleconroy/sqlc/internal/sql/ast"
- "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
- )
- func stringSlice(list *ast.List) []string {
- items := []string{}
- for _, item := range list.Items {
- if n, ok := item.(*ast.String); ok {
- items = append(items, n.Str)
- }
- }
- return items
- }
- type Catalog struct {
- Comment string
- DefaultSchema string
- Name string
- Schemas []*Schema
- SearchPath []string
- LoadExtension func(string) *Schema
- // TODO: un-export
- Extensions map[string]struct{}
- }
- func (c *Catalog) getSchema(name string) (*Schema, error) {
- for i := range c.Schemas {
- if c.Schemas[i].Name == name {
- return c.Schemas[i], nil
- }
- }
- return nil, sqlerr.SchemaNotFound(name)
- }
- func (c *Catalog) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
- ns := rel.Schema
- if ns == "" {
- ns = c.DefaultSchema
- }
- s, err := c.getSchema(ns)
- if err != nil {
- return nil, -1, err
- }
- return s.getFunc(rel, tns)
- }
- func (c *Catalog) getTable(name *ast.TableName) (*Schema, *Table, error) {
- ns := name.Schema
- if ns == "" {
- ns = c.DefaultSchema
- }
- var s *Schema
- for i := range c.Schemas {
- if c.Schemas[i].Name == ns {
- s = c.Schemas[i]
- break
- }
- }
- if s == nil {
- return nil, nil, sqlerr.SchemaNotFound(ns)
- }
- t, _, err := s.getTable(name)
- if err != nil {
- return nil, nil, err
- }
- return s, t, nil
- }
- func (c *Catalog) getType(rel *ast.TypeName) (Type, int, error) {
- ns := rel.Schema
- if ns == "" {
- ns = c.DefaultSchema
- }
- s, err := c.getSchema(ns)
- if err != nil {
- return nil, -1, err
- }
- return s.getType(rel)
- }
- type Schema struct {
- Name string
- Tables []*Table
- Types []Type
- Funcs []*Function
- Comment string
- }
- func sameType(a, b *ast.TypeName) bool {
- if a.Catalog != b.Catalog {
- return false
- }
- // The pg_catalog schema is searched by default, so take that into
- // account when comparing schemas
- aSchema := a.Schema
- bSchema := b.Schema
- if aSchema == "pg_catalog" {
- aSchema = ""
- }
- if bSchema == "pg_catalog" {
- bSchema = ""
- }
- if aSchema != bSchema {
- return false
- }
- if a.Name != b.Name {
- return false
- }
- return true
- }
- func (s *Schema) getFunc(rel *ast.FuncName, tns []*ast.TypeName) (*Function, int, error) {
- for i := range s.Funcs {
- if strings.ToLower(s.Funcs[i].Name) != strings.ToLower(rel.Name) {
- continue
- }
- args := s.Funcs[i].InArgs()
- if len(args) != len(tns) {
- continue
- }
- found := true
- for j := range args {
- if !sameType(s.Funcs[i].Args[j].Type, tns[j]) {
- found = false
- break
- }
- }
- if !found {
- continue
- }
- return s.Funcs[i], i, nil
- }
- return nil, -1, sqlerr.RelationNotFound(rel.Name)
- }
- func (s *Schema) getFuncByName(rel *ast.FuncName) (*Function, int, error) {
- idx := -1
- name := strings.ToLower(rel.Name)
- for i := range s.Funcs {
- lowered := strings.ToLower(s.Funcs[i].Name)
- if lowered == name && idx >= 0 {
- return nil, -1, sqlerr.FunctionNotUnique(rel.Name)
- }
- if lowered == name {
- idx = i
- }
- }
- if idx < 0 {
- return nil, -1, sqlerr.RelationNotFound(rel.Name)
- }
- return s.Funcs[idx], idx, nil
- }
- func (s *Schema) getTable(rel *ast.TableName) (*Table, int, error) {
- for i := range s.Tables {
- if s.Tables[i].Rel.Name == rel.Name {
- return s.Tables[i], i, nil
- }
- }
- return nil, -1, sqlerr.RelationNotFound(rel.Name)
- }
- func (s *Schema) getType(rel *ast.TypeName) (Type, int, error) {
- for i := range s.Types {
- switch typ := s.Types[i].(type) {
- case *Enum:
- if typ.Name == rel.Name {
- return s.Types[i], i, nil
- }
- }
- }
- return nil, -1, sqlerr.TypeNotFound(rel.Name)
- }
- type Table struct {
- Rel *ast.TableName
- Columns []*Column
- Comment string
- }
- // TODO: Should this just be ast Nodes?
- type Column struct {
- Name string
- Type ast.TypeName
- IsNotNull bool
- IsArray bool
- Comment string
- Length *int
- }
// Type is the interface stored in Schema.Types. In this file it is
// implemented by *Enum and *CompositeType. Because isType is unexported,
// only types declared in this package can satisfy it.
type Type interface {
	// isType is a marker method with no behavior.
	isType()

	// SetComment replaces the comment attached to the type.
	SetComment(string)
}
// Enum is a named type with a list of string values and a comment.
// *Enum provides the isType and SetComment methods.
type Enum struct {
	Name    string
	Vals    []string
	Comment string
}

// SetComment replaces the enum's comment with c.
func (e *Enum) SetComment(c string) { e.Comment = c }

// isType is a marker method with no behavior.
func (e *Enum) isType() {}
// CompositeType is a named type carrying only a name and a comment.
// *CompositeType provides the isType and SetComment methods.
type CompositeType struct {
	Name    string
	Comment string
}

// isType is a marker method with no behavior.
func (ct *CompositeType) isType() {}

// SetComment replaces the composite type's comment with c.
func (ct *CompositeType) SetComment(c string) { ct.Comment = c }
- type Function struct {
- Name string
- Args []*Argument
- ReturnType *ast.TypeName
- Comment string
- Desc string
- ReturnTypeNullable bool
- }
- func (f *Function) InArgs() []*Argument {
- var args []*Argument
- for _, a := range f.Args {
- switch a.Mode {
- case ast.FuncParamTable, ast.FuncParamOut:
- continue
- default:
- args = append(args, a)
- }
- }
- return args
- }
- type Argument struct {
- Name string
- Type *ast.TypeName
- HasDefault bool
- Mode ast.FuncParamMode
- }
- func New(def string) *Catalog {
- return &Catalog{
- DefaultSchema: def,
- Schemas: []*Schema{
- {Name: def},
- },
- Extensions: map[string]struct{}{},
- }
- }
- func (c *Catalog) Build(stmts []ast.Statement) error {
- for i := range stmts {
- if err := c.Update(stmts[i], nil); err != nil {
- return err
- }
- }
- return nil
- }
- // An interface is used to resolve a circular import between the catalog and compiler packages.
- // The createView function requires access to functions in the compiler package to parse the SELECT
- // statement that defines the view.
- type columnGenerator interface {
- OutputColumns(node ast.Node) ([]*Column, error)
- }
- func (c *Catalog) Update(stmt ast.Statement, colGen columnGenerator) error {
- if stmt.Raw == nil {
- return nil
- }
- var err error
- switch n := stmt.Raw.Stmt.(type) {
- case *ast.AlterTableStmt:
- err = c.alterTable(n)
- case *ast.AlterTableSetSchemaStmt:
- err = c.alterTableSetSchema(n)
- case *ast.AlterTypeAddValueStmt:
- err = c.alterTypeAddValue(n)
- case *ast.AlterTypeRenameValueStmt:
- err = c.alterTypeRenameValue(n)
- case *ast.CommentOnColumnStmt:
- err = c.commentOnColumn(n)
- case *ast.CommentOnSchemaStmt:
- err = c.commentOnSchema(n)
- case *ast.CommentOnTableStmt:
- err = c.commentOnTable(n)
- case *ast.CommentOnTypeStmt:
- err = c.commentOnType(n)
- case *ast.CompositeTypeStmt:
- err = c.createCompositeType(n)
- case *ast.CreateEnumStmt:
- err = c.createEnum(n)
- case *ast.CreateExtensionStmt:
- err = c.createExtension(n)
- case *ast.CreateFunctionStmt:
- err = c.createFunction(n)
- case *ast.CreateSchemaStmt:
- err = c.createSchema(n)
- case *ast.CreateTableStmt:
- err = c.createTable(n)
- case *ast.ViewStmt:
- err = c.createView(n, colGen)
- case *ast.DropFunctionStmt:
- err = c.dropFunction(n)
- case *ast.DropSchemaStmt:
- err = c.dropSchema(n)
- case *ast.DropTableStmt:
- err = c.dropTable(n)
- case *ast.DropTypeStmt:
- err = c.dropType(n)
- case *ast.RenameColumnStmt:
- err = c.renameColumn(n)
- case *ast.RenameTableStmt:
- err = c.renameTable(n)
- case *ast.RenameTypeStmt:
- err = c.renameType(n)
- case *ast.List:
- for _, nn := range n.Items {
- if err = c.Update(ast.Statement{
- Raw: &ast.RawStmt{
- Stmt: nn,
- StmtLocation: stmt.Raw.StmtLocation,
- StmtLen: stmt.Raw.StmtLen,
- },
- }, colGen); err != nil {
- return err
- }
- }
- }
- return err
- }
|