1
0

postgres.go 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package sqltest
  import (
  	"database/sql"
  	"fmt"
  	"math/rand"
  	"net"
  	"net/url"
  	"os"
  	"path/filepath"
  	"testing"
  	"time"

  	"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
  	_ "github.com/lib/pq"
  )
  // init seeds the package-level math/rand source from the current wall-clock
  // time, so the random names produced by id differ between test runs.
  //
  // NOTE(review): rand.Seed is deprecated in newer Go releases; confirm the
  // module's Go version before removing this.
  func init() {
  	now := time.Now()
  	rand.Seed(now.UnixNano())
  }
  // letterRunes is the alphabet id draws from: the 26 lowercase followed by the
  // 26 uppercase ASCII letters.
  var letterRunes = []rune("abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ")

  // id returns a 10-rune string whose runes are picked uniformly at random from
  // letterRunes using the package-level math/rand source.
  func id() string {
  	const size = 10
  	out := make([]rune, 0, size)
  	for len(out) < size {
  		pick := rand.Intn(len(letterRunes))
  		out = append(out, letterRunes[pick])
  	}
  	return string(out)
  }
  24. func PostgreSQL(t *testing.T, migrations []string) (*sql.DB, func()) {
  25. t.Helper()
  26. pgUser := os.Getenv("PG_USER")
  27. pgHost := os.Getenv("PG_HOST")
  28. pgPort := os.Getenv("PG_PORT")
  29. pgPass := os.Getenv("PG_PASSWORD")
  30. pgDB := os.Getenv("PG_DATABASE")
  31. if pgUser == "" {
  32. pgUser = "postgres"
  33. }
  34. if pgPass == "" {
  35. pgPass = "mysecretpassword"
  36. }
  37. if pgPort == "" {
  38. pgPort = "5432"
  39. }
  40. if pgHost == "" {
  41. pgHost = "127.0.0.1"
  42. }
  43. if pgDB == "" {
  44. pgDB = "dinotest"
  45. }
  46. source := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", pgUser, pgPass, pgHost, pgPort, pgDB)
  47. t.Logf("db: %s", source)
  48. db, err := sql.Open("postgres", source)
  49. if err != nil {
  50. t.Fatal(err)
  51. }
  52. // For each test, pick a new schema name at random.
  53. schema := "sqltest_postgresql_" + id()
  54. if _, err := db.Exec("CREATE SCHEMA " + schema); err != nil {
  55. t.Fatal(err)
  56. }
  57. sdb, err := sql.Open("postgres", source+"&search_path="+schema)
  58. if err != nil {
  59. t.Fatal(err)
  60. }
  61. files, err := sqlpath.Glob(migrations)
  62. if err != nil {
  63. t.Fatal(err)
  64. }
  65. for _, f := range files {
  66. blob, err := os.ReadFile(f)
  67. if err != nil {
  68. t.Fatal(err)
  69. }
  70. if _, err := sdb.Exec(string(blob)); err != nil {
  71. t.Fatalf("%s: %s", filepath.Base(f), err)
  72. }
  73. }
  74. return sdb, func() {
  75. if _, err := db.Exec("DROP SCHEMA " + schema + " CASCADE"); err != nil {
  76. t.Fatal(err)
  77. }
  78. }
  79. }