compat.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. package compiler
  2. import (
  3. "fmt"
  4. "strings"
  5. "github.com/kyleconroy/sqlc/internal/sql/ast"
  6. "github.com/kyleconroy/sqlc/internal/sql/ast/pg"
  7. "github.com/kyleconroy/sqlc/internal/sql/astutils"
  8. )
  9. // This is mainly copy-pasted from internal/postgresql/parse.go
  10. func stringSlice(list *ast.List) []string {
  11. items := []string{}
  12. for _, item := range list.Items {
  13. if n, ok := item.(*pg.String); ok {
  14. items = append(items, n.Str)
  15. continue
  16. }
  17. if n, ok := item.(*ast.String); ok {
  18. items = append(items, n.Str)
  19. continue
  20. }
  21. }
  22. return items
  23. }
  // Relation is an object name split into up to three components:
  // catalog, schema, and the bare name. A component that the parsed
  // input did not supply is left as the empty string.
  type Relation struct {
  	Catalog, Schema, Name string
  }
  29. func parseRelation(node ast.Node) (*Relation, error) {
  30. switch n := node.(type) {
  31. case *ast.List:
  32. parts := stringSlice(n)
  33. switch len(parts) {
  34. case 1:
  35. return &Relation{
  36. Name: parts[0],
  37. }, nil
  38. case 2:
  39. return &Relation{
  40. Schema: parts[0],
  41. Name: parts[1],
  42. }, nil
  43. case 3:
  44. return &Relation{
  45. Catalog: parts[0],
  46. Schema: parts[1],
  47. Name: parts[2],
  48. }, nil
  49. default:
  50. return nil, fmt.Errorf("invalid name: %s", astutils.Join(n, "."))
  51. }
  52. case *pg.RangeVar:
  53. name := Relation{}
  54. if n.Catalogname != nil {
  55. name.Catalog = *n.Catalogname
  56. }
  57. if n.Schemaname != nil {
  58. name.Schema = *n.Schemaname
  59. }
  60. if n.Relname != nil {
  61. name.Name = *n.Relname
  62. }
  63. return &name, nil
  64. case *pg.TypeName:
  65. return parseRelation(n.Names)
  66. default:
  67. return nil, fmt.Errorf("unexpected node type: %T", n)
  68. }
  69. }
  70. func ParseTableName(node ast.Node) (*ast.TableName, error) {
  71. rel, err := parseRelation(node)
  72. if err != nil {
  73. return nil, fmt.Errorf("parse table name: %w", err)
  74. }
  75. return &ast.TableName{
  76. Catalog: rel.Catalog,
  77. Schema: rel.Schema,
  78. Name: rel.Name,
  79. }, nil
  80. }
  81. func ParseTypeName(node ast.Node) (*ast.TypeName, error) {
  82. rel, err := parseRelation(node)
  83. if err != nil {
  84. return nil, fmt.Errorf("parse table name: %w", err)
  85. }
  86. return &ast.TypeName{
  87. Catalog: rel.Catalog,
  88. Schema: rel.Schema,
  89. Name: rel.Name,
  90. }, nil
  91. }
  92. func ParseRelationString(name string) (*Relation, error) {
  93. parts := strings.Split(name, ".")
  94. switch len(parts) {
  95. case 1:
  96. return &Relation{
  97. Name: parts[0],
  98. }, nil
  99. case 2:
  100. return &Relation{
  101. Schema: parts[0],
  102. Name: parts[1],
  103. }, nil
  104. case 3:
  105. return &Relation{
  106. Catalog: parts[0],
  107. Schema: parts[1],
  108. Name: parts[2],
  109. }, nil
  110. default:
  111. return nil, fmt.Errorf("invalid name: %s", name)
  112. }
  113. }