two_factors_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package database
  5. import (
  6. "context"
  7. "testing"
  8. "time"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "gorm.io/gorm"
  12. "gogs.io/gogs/internal/errutil"
  13. )
  14. func TestTwoFactor_BeforeCreate(t *testing.T) {
  15. now := time.Now()
  16. db := &gorm.DB{
  17. Config: &gorm.Config{
  18. SkipDefaultTransaction: true,
  19. NowFunc: func() time.Time {
  20. return now
  21. },
  22. },
  23. }
  24. t.Run("CreatedUnix has been set", func(t *testing.T) {
  25. tf := &TwoFactor{
  26. CreatedUnix: 1,
  27. }
  28. _ = tf.BeforeCreate(db)
  29. assert.Equal(t, int64(1), tf.CreatedUnix)
  30. })
  31. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  32. tf := &TwoFactor{}
  33. _ = tf.BeforeCreate(db)
  34. assert.Equal(t, db.NowFunc().Unix(), tf.CreatedUnix)
  35. })
  36. }
  37. func TestTwoFactor_AfterFind(t *testing.T) {
  38. now := time.Now()
  39. db := &gorm.DB{
  40. Config: &gorm.Config{
  41. SkipDefaultTransaction: true,
  42. NowFunc: func() time.Time {
  43. return now
  44. },
  45. },
  46. }
  47. tf := &TwoFactor{
  48. CreatedUnix: now.Unix(),
  49. }
  50. _ = tf.AfterFind(db)
  51. assert.Equal(t, tf.CreatedUnix, tf.Created.Unix())
  52. }
  53. func TestTwoFactors(t *testing.T) {
  54. if testing.Short() {
  55. t.Skip()
  56. }
  57. t.Parallel()
  58. ctx := context.Background()
  59. s := &TwoFactorsStore{
  60. db: newTestDB(t, "TwoFactorsStore"),
  61. }
  62. for _, tc := range []struct {
  63. name string
  64. test func(t *testing.T, ctx context.Context, s *TwoFactorsStore)
  65. }{
  66. {"Create", twoFactorsCreate},
  67. {"GetByUserID", twoFactorsGetByUserID},
  68. {"IsEnabled", twoFactorsIsEnabled},
  69. } {
  70. t.Run(tc.name, func(t *testing.T) {
  71. t.Cleanup(func() {
  72. err := clearTables(t, s.db)
  73. require.NoError(t, err)
  74. })
  75. tc.test(t, ctx, s)
  76. })
  77. if t.Failed() {
  78. break
  79. }
  80. }
  81. }
  82. func twoFactorsCreate(t *testing.T, ctx context.Context, s *TwoFactorsStore) {
  83. // Create a 2FA token
  84. err := s.Create(ctx, 1, "secure-key", "secure-secret")
  85. require.NoError(t, err)
  86. // Get it back and check the Created field
  87. tf, err := s.GetByUserID(ctx, 1)
  88. require.NoError(t, err)
  89. assert.Equal(t, s.db.NowFunc().Format(time.RFC3339), tf.Created.UTC().Format(time.RFC3339))
  90. // Verify there are 10 recover codes generated
  91. var count int64
  92. err = s.db.Model(new(TwoFactorRecoveryCode)).Count(&count).Error
  93. require.NoError(t, err)
  94. assert.Equal(t, int64(10), count)
  95. }
  96. func twoFactorsGetByUserID(t *testing.T, ctx context.Context, s *TwoFactorsStore) {
  97. // Create a 2FA token for user 1
  98. err := s.Create(ctx, 1, "secure-key", "secure-secret")
  99. require.NoError(t, err)
  100. // We should be able to get it back
  101. _, err = s.GetByUserID(ctx, 1)
  102. require.NoError(t, err)
  103. // Try to get a non-existent 2FA token
  104. _, err = s.GetByUserID(ctx, 2)
  105. wantErr := ErrTwoFactorNotFound{args: errutil.Args{"userID": int64(2)}}
  106. assert.Equal(t, wantErr, err)
  107. }
  108. func twoFactorsIsEnabled(t *testing.T, ctx context.Context, s *TwoFactorsStore) {
  109. // Create a 2FA token for user 1
  110. err := s.Create(ctx, 1, "secure-key", "secure-secret")
  111. require.NoError(t, err)
  112. assert.True(t, s.IsEnabled(ctx, 1))
  113. assert.False(t, s.IsEnabled(ctx, 2))
  114. }