1
0

mysql.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package sqltest
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "os"
  6. "path/filepath"
  7. "testing"
  8. "github.com/kyleconroy/sqlc/internal/sql/sqlpath"
  9. _ "github.com/go-sql-driver/mysql"
  10. )
  11. func MySQL(t *testing.T, migrations []string) (*sql.DB, func()) {
  12. t.Helper()
  13. data := os.Getenv("MYSQL_DATABASE")
  14. host := os.Getenv("MYSQL_HOST")
  15. pass := os.Getenv("MYSQL_ROOT_PASSWORD")
  16. port := os.Getenv("MYSQL_PORT")
  17. user := os.Getenv("MYSQL_USER")
  18. if user == "" {
  19. user = "root"
  20. }
  21. if pass == "" {
  22. pass = "mysecretpassword"
  23. }
  24. if port == "" {
  25. port = "3306"
  26. }
  27. if host == "" {
  28. host = "127.0.0.1"
  29. }
  30. if data == "" {
  31. data = "dinotest"
  32. }
  33. source := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, data)
  34. t.Logf("db: %s", source)
  35. db, err := sql.Open("mysql", source)
  36. if err != nil {
  37. t.Fatal(err)
  38. }
  39. // For each test, pick a new database name at random.
  40. dbName := "sqltest_mysql_" + id()
  41. if _, err := db.Exec("CREATE DATABASE " + dbName); err != nil {
  42. t.Fatal(err)
  43. }
  44. source = fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?multiStatements=true&parseTime=true", user, pass, host, port, dbName)
  45. sdb, err := sql.Open("mysql", source)
  46. if err != nil {
  47. t.Fatal(err)
  48. }
  49. files, err := sqlpath.Glob(migrations)
  50. if err != nil {
  51. t.Fatal(err)
  52. }
  53. for _, f := range files {
  54. blob, err := os.ReadFile(f)
  55. if err != nil {
  56. t.Fatal(err)
  57. }
  58. if _, err := sdb.Exec(string(blob)); err != nil {
  59. t.Fatalf("%s: %s", filepath.Base(f), err)
  60. }
  61. }
  62. return sdb, func() {
  63. // Drop the test db after test runs
  64. if _, err := db.Exec("DROP DATABASE " + dbName); err != nil {
  65. t.Fatal(err)
  66. }
  67. }
  68. }