postgres.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package sqltest
  import (
  	"database/sql"
  	"fmt"
  	"io/ioutil"
  	"math/rand"
  	"net"
  	"net/url"
  	"os"
  	"path/filepath"
  	"testing"
  	"time"

  	"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
  	_ "github.com/lib/pq"
  )
  // init seeds the package-level math/rand source with the current wall-clock
  // time in nanoseconds, so id yields different values on each test run.
  func init() {
  	seed := time.Now().UnixNano()
  	rand.Seed(seed)
  }
  // letterRunes is the alphabet id draws from: the 26 ASCII letters in lower
  // case followed by the same 26 letters in upper case.
  var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

  // id returns a 10-character string. Each character is picked from
  // letterRunes with rand.Intn, using the package-level math/rand source.
  func id() string {
  	const size = 10
  	out := make([]rune, size)
  	for pos := 0; pos < size; pos++ {
  		out[pos] = letterRunes[rand.Intn(len(letterRunes))]
  	}
  	return string(out)
  }
  25. func PostgreSQL(t *testing.T, migrations []string) (*sql.DB, func()) {
  26. t.Helper()
  27. pgUser := os.Getenv("PG_USER")
  28. pgHost := os.Getenv("PG_HOST")
  29. pgPort := os.Getenv("PG_PORT")
  30. pgPass := os.Getenv("PG_PASSWORD")
  31. pgDB := os.Getenv("PG_DATABASE")
  32. if pgUser == "" {
  33. pgUser = "postgres"
  34. }
  35. if pgPass == "" {
  36. pgPass = "mysecretpassword"
  37. }
  38. if pgPort == "" {
  39. pgPort = "5432"
  40. }
  41. if pgHost == "" {
  42. pgHost = "127.0.0.1"
  43. }
  44. if pgDB == "" {
  45. pgDB = "dinotest"
  46. }
  47. source := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", pgUser, pgPass, pgHost, pgPort, pgDB)
  48. t.Logf("db: %s", source)
  49. db, err := sql.Open("postgres", source)
  50. if err != nil {
  51. t.Fatal(err)
  52. }
  53. schema := "sqltest_" + id()
  54. // For each test, pick a new schema name at random.
  55. // `foo` is used here only as an example
  56. if _, err := db.Exec("CREATE SCHEMA " + schema); err != nil {
  57. t.Fatal(err)
  58. }
  59. sdb, err := sql.Open("postgres", source+"&search_path="+schema)
  60. if err != nil {
  61. t.Fatal(err)
  62. }
  63. files, err := sqlpath.Glob(migrations)
  64. if err != nil {
  65. t.Fatal(err)
  66. }
  67. for _, f := range files {
  68. blob, err := ioutil.ReadFile(f)
  69. if err != nil {
  70. t.Fatal(err)
  71. }
  72. if _, err := sdb.Exec(string(blob)); err != nil {
  73. t.Fatalf("%s: %s", filepath.Base(f), err)
  74. }
  75. }
  76. return sdb, func() {
  77. if _, err := db.Exec("DROP SCHEMA " + schema + " CASCADE"); err != nil {
  78. t.Fatal(err)
  79. }
  80. }
  81. }