1
0

public.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. package catalog
  2. import (
  3. "fmt"
  4. "strings"
  5. "github.com/kyleconroy/sqlc/internal/sql/ast"
  6. "github.com/kyleconroy/sqlc/internal/sql/sqlerr"
  7. )
  8. func (c *Catalog) schemasToSearch(ns string) []string {
  9. if ns == "" {
  10. ns = c.DefaultSchema
  11. }
  12. return append(c.SearchPath, ns)
  13. }
  14. func (c *Catalog) ListFuncsByName(rel *ast.FuncName) ([]Function, error) {
  15. var funcs []Function
  16. lowered := strings.ToLower(rel.Name)
  17. for _, ns := range c.schemasToSearch(rel.Schema) {
  18. s, err := c.getSchema(ns)
  19. if err != nil {
  20. return nil, err
  21. }
  22. for i := range s.Funcs {
  23. if strings.ToLower(s.Funcs[i].Name) == lowered {
  24. funcs = append(funcs, *s.Funcs[i])
  25. }
  26. }
  27. }
  28. return funcs, nil
  29. }
  30. func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
  31. // Do not validate unknown functions
  32. funs, err := c.ListFuncsByName(call.Func)
  33. if err != nil || len(funs) == 0 {
  34. return nil, sqlerr.FunctionNotFound(call.Func.Name)
  35. }
  36. // https://www.postgresql.org/docs/current/sql-syntax-calling-funcs.html
  37. var positional []ast.Node
  38. var named []*ast.NamedArgExpr
  39. if call.Args != nil {
  40. for _, arg := range call.Args.Items {
  41. if narg, ok := arg.(*ast.NamedArgExpr); ok {
  42. named = append(named, narg)
  43. } else {
  44. // The mixed notation combines positional and named notation.
  45. // However, as already mentioned, named arguments cannot precede
  46. // positional arguments.
  47. if len(named) > 0 {
  48. return nil, &sqlerr.Error{
  49. Code: "",
  50. Message: "positional argument cannot follow named argument",
  51. Location: call.Pos(),
  52. }
  53. }
  54. positional = append(positional, arg)
  55. }
  56. }
  57. }
  58. for _, fun := range funs {
  59. args := fun.InArgs()
  60. var defaults int
  61. var variadic bool
  62. known := map[string]struct{}{}
  63. for _, arg := range args {
  64. if arg.HasDefault {
  65. defaults += 1
  66. }
  67. if arg.Mode == ast.FuncParamVariadic {
  68. variadic = true
  69. defaults += 1
  70. }
  71. if arg.Name != "" {
  72. known[arg.Name] = struct{}{}
  73. }
  74. }
  75. if variadic {
  76. if (len(named) + len(positional)) < (len(args) - defaults) {
  77. continue
  78. }
  79. } else {
  80. if (len(named) + len(positional)) > len(args) {
  81. continue
  82. }
  83. if (len(named) + len(positional)) < (len(args) - defaults) {
  84. continue
  85. }
  86. }
  87. // Validate that the provided named arguments exist in the function
  88. var unknownArgName bool
  89. for _, expr := range named {
  90. if expr.Name != nil {
  91. if _, found := known[*expr.Name]; !found {
  92. unknownArgName = true
  93. }
  94. }
  95. }
  96. if unknownArgName {
  97. continue
  98. }
  99. return &fun, nil
  100. }
  101. var sig []string
  102. for range call.Args.Items {
  103. sig = append(sig, "unknown")
  104. }
  105. return nil, &sqlerr.Error{
  106. Code: "42883",
  107. Message: fmt.Sprintf("function %s(%s) does not exist", call.Func.Name, strings.Join(sig, ", ")),
  108. Location: call.Pos(),
  109. // Hint: "No function matches the given name and argument types. You might need to add explicit type casts.",
  110. }
  111. }
  112. func (c *Catalog) GetTable(rel *ast.TableName) (Table, error) {
  113. _, table, err := c.getTable(rel)
  114. if table == nil {
  115. return Table{}, err
  116. } else {
  117. return *table, err
  118. }
  119. }