1
0

config_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package config
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/google/go-cmp/cmp"
  6. )
  7. const missingVersion = `{
  8. }`
  9. const missingPackages = `{
  10. "version": "1"
  11. }`
  12. const unknownVersion = `{
  13. "version": "foo"
  14. }`
  15. const unknownFields = `{
  16. "version": "1",
  17. "foo": "bar"
  18. }`
  19. func TestBadConfigs(t *testing.T) {
  20. for _, test := range []struct {
  21. name string
  22. err string
  23. json string
  24. }{
  25. {
  26. "missing version",
  27. "no version number",
  28. missingVersion,
  29. },
  30. {
  31. "missing packages",
  32. "no packages",
  33. missingPackages,
  34. },
  35. {
  36. "unknown version",
  37. "invalid version number",
  38. unknownVersion,
  39. },
  40. {
  41. "unknown fields",
  42. `yaml: unmarshal errors:
  43. line 3: field foo not found in type config.V1GenerateSettings`,
  44. unknownFields,
  45. },
  46. } {
  47. tt := test
  48. t.Run(tt.name, func(t *testing.T) {
  49. _, err := ParseConfig(strings.NewReader(tt.json))
  50. if err == nil {
  51. t.Fatalf("expected err; got nil")
  52. }
  53. if diff := cmp.Diff(err.Error(), tt.err); diff != "" {
  54. t.Errorf("differed (-want +got):\n%s", diff)
  55. }
  56. })
  57. }
  58. }
  59. func TestInvalidConfig(t *testing.T) {
  60. err := Validate(&Config{
  61. SQL: []SQL{{
  62. Gen: SQLGen{
  63. Go: &SQLGo{
  64. EmitMethodsWithDBArgument: true,
  65. EmitPreparedQueries: true,
  66. },
  67. },
  68. }}})
  69. if err == nil {
  70. t.Errorf("expected err; got nil")
  71. }
  72. }
  73. func TestTypeOverrides(t *testing.T) {
  74. for _, test := range []struct {
  75. override Override
  76. pkg string
  77. typeName string
  78. basic bool
  79. }{
  80. {
  81. Override{
  82. DBType: "uuid",
  83. GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"},
  84. },
  85. "github.com/segmentio/ksuid",
  86. "ksuid.KSUID",
  87. false,
  88. },
  89. // TODO: Add test for struct pointers
  90. //
  91. // {
  92. // Override{
  93. // DBType: "uuid",
  94. // GoType: "github.com/segmentio/*ksuid.KSUID",
  95. // },
  96. // "github.com/segmentio/ksuid",
  97. // "*ksuid.KSUID",
  98. // false,
  99. // },
  100. {
  101. Override{
  102. DBType: "citext",
  103. GoType: GoType{Spec: "string"},
  104. },
  105. "",
  106. "string",
  107. true,
  108. },
  109. } {
  110. tt := test
  111. t.Run(tt.override.GoType.Spec, func(t *testing.T) {
  112. if err := tt.override.Parse(); err != nil {
  113. t.Fatalf("override parsing failed; %s", err)
  114. }
  115. if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" {
  116. t.Errorf("package mismatch;\n%s", diff)
  117. }
  118. if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
  119. t.Errorf("type name mismatch;\n%s", diff)
  120. }
  121. if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
  122. t.Errorf("basic mismatch;\n%s", diff)
  123. }
  124. })
  125. }
  126. for _, test := range []struct {
  127. override Override
  128. err string
  129. }{
  130. {
  131. Override{
  132. DBType: "uuid",
  133. GoType: GoType{Spec: "Pointer"},
  134. },
  135. "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
  136. },
  137. {
  138. Override{
  139. DBType: "uuid",
  140. GoType: GoType{Spec: "untyped rune"},
  141. },
  142. "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
  143. },
  144. } {
  145. tt := test
  146. t.Run(tt.override.GoType.Spec, func(t *testing.T) {
  147. err := tt.override.Parse()
  148. if err == nil {
  149. t.Fatalf("expected pars to fail; got nil")
  150. }
  151. if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
  152. t.Errorf("error mismatch;\n%s", diff)
  153. }
  154. })
  155. }
  156. }