123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- package config
- import (
- "strings"
- "testing"
- "github.com/google/go-cmp/cmp"
- )
// Malformed configuration documents fed to ParseConfig by TestBadConfigs.
// The raw-string contents are deliberately kept flush-left: they are runtime
// data, and the "unknown fields" case asserts on a line number inside one.
const (
	// missingVersion is an empty object, so it carries no "version" key.
	missingVersion = `{
}`

	// missingPackages contains a "version" key and nothing else.
	missingPackages = `{
"version": "1"
}`

	// unknownVersion sets "version" to the value "foo".
	unknownVersion = `{
"version": "foo"
}`

	// unknownFields adds a "foo" key, on the third line of the document,
	// next to "version".
	unknownFields = `{
"version": "1",
"foo": "bar"
}`
)
- func TestBadConfigs(t *testing.T) {
- for _, test := range []struct {
- name string
- err string
- json string
- }{
- {
- "missing version",
- "no version number",
- missingVersion,
- },
- {
- "missing packages",
- "no packages",
- missingPackages,
- },
- {
- "unknown version",
- "invalid version number",
- unknownVersion,
- },
- {
- "unknown fields",
- `yaml: unmarshal errors:
- line 3: field foo not found in type config.V1GenerateSettings`,
- unknownFields,
- },
- } {
- tt := test
- t.Run(tt.name, func(t *testing.T) {
- _, err := ParseConfig(strings.NewReader(tt.json))
- if err == nil {
- t.Fatalf("expected err; got nil")
- }
- if diff := cmp.Diff(err.Error(), tt.err); diff != "" {
- t.Errorf("differed (-want +got):\n%s", diff)
- }
- })
- }
- }
- func TestInvalidConfig(t *testing.T) {
- err := Validate(&Config{
- SQL: []SQL{{
- Gen: SQLGen{
- Go: &SQLGo{
- EmitMethodsWithDBArgument: true,
- EmitPreparedQueries: true,
- },
- },
- }}})
- if err == nil {
- t.Errorf("expected err; got nil")
- }
- }
- func TestTypeOverrides(t *testing.T) {
- for _, test := range []struct {
- override Override
- pkg string
- typeName string
- basic bool
- }{
- {
- Override{
- DBType: "uuid",
- GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"},
- },
- "github.com/segmentio/ksuid",
- "ksuid.KSUID",
- false,
- },
- // TODO: Add test for struct pointers
- //
- // {
- // Override{
- // DBType: "uuid",
- // GoType: "github.com/segmentio/*ksuid.KSUID",
- // },
- // "github.com/segmentio/ksuid",
- // "*ksuid.KSUID",
- // false,
- // },
- {
- Override{
- DBType: "citext",
- GoType: GoType{Spec: "string"},
- },
- "",
- "string",
- true,
- },
- } {
- tt := test
- t.Run(tt.override.GoType.Spec, func(t *testing.T) {
- if err := tt.override.Parse(); err != nil {
- t.Fatalf("override parsing failed; %s", err)
- }
- if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" {
- t.Errorf("package mismatch;\n%s", diff)
- }
- if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
- t.Errorf("type name mismatch;\n%s", diff)
- }
- if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
- t.Errorf("basic mismatch;\n%s", diff)
- }
- })
- }
- for _, test := range []struct {
- override Override
- err string
- }{
- {
- Override{
- DBType: "uuid",
- GoType: GoType{Spec: "Pointer"},
- },
- "Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
- },
- {
- Override{
- DBType: "uuid",
- GoType: GoType{Spec: "untyped rune"},
- },
- "Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
- },
- } {
- tt := test
- t.Run(tt.override.GoType.Spec, func(t *testing.T) {
- err := tt.override.Parse()
- if err == nil {
- t.Fatalf("expected pars to fail; got nil")
- }
- if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
- t.Errorf("error mismatch;\n%s", diff)
- }
- })
- }
- }
|