fmt_test.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. package main
  2. import (
  3. "bytes"
  4. "fmt"
  5. "os"
  6. "path/filepath"
  7. "strings"
  8. "testing"
  9. pg_query "github.com/pganalyze/pg_query_go/v4"
  10. "github.com/sqlc-dev/sqlc/internal/debug"
  11. "github.com/sqlc-dev/sqlc/internal/engine/postgresql"
  12. "github.com/sqlc-dev/sqlc/internal/sql/ast"
  13. )
  14. func TestFormat(t *testing.T) {
  15. t.Parallel()
  16. parse := postgresql.NewParser()
  17. for _, tc := range FindTests(t, "testdata", "base") {
  18. tc := tc
  19. if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) {
  20. continue
  21. }
  22. q := filepath.Join(tc.Path, "query.sql")
  23. if _, err := os.Stat(q); os.IsNotExist(err) {
  24. continue
  25. }
  26. t.Run(tc.Name, func(t *testing.T) {
  27. contents, err := os.ReadFile(q)
  28. if err != nil {
  29. t.Fatal(err)
  30. }
  31. for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) {
  32. if len(query) <= 1 {
  33. continue
  34. }
  35. query := query
  36. t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
  37. expected, err := pg_query.Fingerprint(string(query))
  38. if err != nil {
  39. t.Fatal(err)
  40. }
  41. stmts, err := parse.Parse(bytes.NewReader(query))
  42. if err != nil {
  43. t.Fatal(err)
  44. }
  45. if len(stmts) != 1 {
  46. t.Fatal("expected one statement")
  47. }
  48. if false {
  49. r, err := pg_query.Parse(string(query))
  50. debug.Dump(r, err)
  51. }
  52. out := ast.Format(stmts[0].Raw)
  53. actual, err := pg_query.Fingerprint(out)
  54. if err != nil {
  55. t.Error(err)
  56. }
  57. if expected != actual {
  58. debug.Dump(stmts[0].Raw)
  59. t.Errorf("- %s", expected)
  60. t.Errorf("- %s", string(query))
  61. t.Errorf("+ %s", actual)
  62. t.Errorf("+ %s", out)
  63. }
  64. })
  65. }
  66. })
  67. }
  68. }