func.go 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package catalog
  2. import (
  3. "errors"
  4. "github.com/kyleconroy/sqlc/internal/sql/ast"
  5. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  6. )
  7. func (c *Catalog) createFunction(stmt *ast.CreateFunctionStmt) error {
  8. ns := stmt.Func.Schema
  9. if ns == "" {
  10. ns = c.DefaultSchema
  11. }
  12. s, err := c.getSchema(ns)
  13. if err != nil {
  14. return err
  15. }
  16. fn := &Function{
  17. Name: stmt.Func.Name,
  18. Args: make([]*Argument, len(stmt.Params.Items)),
  19. ReturnType: stmt.ReturnType,
  20. }
  21. types := make([]*ast.TypeName, len(stmt.Params.Items))
  22. for i, item := range stmt.Params.Items {
  23. arg := item.(*ast.FuncParam)
  24. var name string
  25. if arg.Name != nil {
  26. name = *arg.Name
  27. }
  28. fn.Args[i] = &Argument{
  29. Name: name,
  30. Type: arg.Type,
  31. Mode: arg.Mode,
  32. HasDefault: arg.DefExpr != nil,
  33. }
  34. types[i] = arg.Type
  35. }
  36. _, idx, err := s.getFunc(stmt.Func, types)
  37. if err == nil && !stmt.Replace {
  38. return sqlerr.RelationExists(stmt.Func.Name)
  39. }
  40. if idx >= 0 {
  41. s.Funcs[idx] = fn
  42. } else {
  43. s.Funcs = append(s.Funcs, fn)
  44. }
  45. return nil
  46. }
  47. func (c *Catalog) dropFunction(stmt *ast.DropFunctionStmt) error {
  48. for _, spec := range stmt.Funcs {
  49. ns := spec.Name.Schema
  50. if ns == "" {
  51. ns = c.DefaultSchema
  52. }
  53. s, err := c.getSchema(ns)
  54. if errors.Is(err, sqlerr.NotFound) && stmt.MissingOk {
  55. continue
  56. } else if err != nil {
  57. return err
  58. }
  59. var idx int
  60. if spec.HasArgs {
  61. _, idx, err = s.getFunc(spec.Name, spec.Args)
  62. } else {
  63. _, idx, err = s.getFuncByName(spec.Name)
  64. }
  65. if errors.Is(err, sqlerr.NotFound) && stmt.MissingOk {
  66. continue
  67. } else if err != nil {
  68. return err
  69. }
  70. s.Funcs = append(s.Funcs[:idx], s.Funcs[idx+1:]...)
  71. }
  72. return nil
  73. }