config_test.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. package config
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/google/go-cmp/cmp"
  6. )
  7. const missingVersion = `{
  8. }`
  9. const missingPackages = `{
  10. "version": "1"
  11. }`
  12. const unknownVersion = `{
  13. "version": "foo"
  14. }`
  15. const unknownFields = `{
  16. "version": "1",
  17. "foo": "bar"
  18. }`
  19. func TestBadConfigs(t *testing.T) {
  20. for _, test := range []struct {
  21. name string
  22. err string
  23. json string
  24. }{
  25. {
  26. "missing version",
  27. "no version number",
  28. missingVersion,
  29. },
  30. {
  31. "missing packages",
  32. "no packages",
  33. missingPackages,
  34. },
  35. {
  36. "unknown version",
  37. "invalid version number",
  38. unknownVersion,
  39. },
  40. {
  41. "unknown fields",
  42. `yaml: unmarshal errors:
  43. line 3: field foo not found in type config.V1GenerateSettings`,
  44. unknownFields,
  45. },
  46. } {
  47. tt := test
  48. t.Run(tt.name, func(t *testing.T) {
  49. _, err := ParseConfig(strings.NewReader(tt.json))
  50. if err == nil {
  51. t.Fatalf("expected err; got nil")
  52. }
  53. if diff := cmp.Diff(err.Error(), tt.err); diff != "" {
  54. t.Errorf("differed (-want +got):\n%s", diff)
  55. }
  56. })
  57. }
  58. }
  59. func TestTypeOverrides(t *testing.T) {
  60. for _, test := range []struct {
  61. override Override
  62. pkg string
  63. typeName string
  64. basic bool
  65. }{
  66. {
  67. Override{
  68. DBType: "uuid",
  69. GoType: "github.com/segmentio/ksuid.KSUID",
  70. },
  71. "github.com/segmentio/ksuid",
  72. "ksuid.KSUID",
  73. false,
  74. },
  75. // TODO: Add test for struct pointers
  76. //
  77. // {
  78. // Override{
  79. // DBType: "uuid",
  80. // GoType: "github.com/segmentio/*ksuid.KSUID",
  81. // },
  82. // "github.com/segmentio/ksuid",
  83. // "*ksuid.KSUID",
  84. // false,
  85. // },
  86. {
  87. Override{
  88. DBType: "citext",
  89. GoType: "string",
  90. },
  91. "",
  92. "string",
  93. true,
  94. },
  95. } {
  96. tt := test
  97. t.Run(tt.override.GoType, func(t *testing.T) {
  98. if err := tt.override.Parse(); err != nil {
  99. t.Fatalf("override parsing failed; %s", err)
  100. }
  101. if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
  102. t.Errorf("type name mismatch;\n%s", diff)
  103. }
  104. if diff := cmp.Diff(tt.pkg, tt.override.GoPackage); diff != "" {
  105. t.Errorf("package mismatch;\n%s", diff)
  106. }
  107. if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
  108. t.Errorf("basic mismatch;\n%s", diff)
  109. }
  110. })
  111. }
  112. for _, test := range []struct {
  113. override Override
  114. err string
  115. }{
  116. {
  117. Override{
  118. DBType: "uuid",
  119. GoType: "Pointer",
  120. },
  121. "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
  122. },
  123. {
  124. Override{
  125. DBType: "uuid",
  126. GoType: "untyped rune",
  127. },
  128. "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
  129. },
  130. } {
  131. tt := test
  132. t.Run(tt.override.GoType, func(t *testing.T) {
  133. err := tt.override.Parse()
  134. if err == nil {
  135. t.Fatalf("expected pars to fail; got nil")
  136. }
  137. if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
  138. t.Errorf("error mismatch;\n%s", diff)
  139. }
  140. })
  141. }
  142. }