public_keys.go 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. // Copyright 2023 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package database
import (
	"bufio"
	"os"
	"path/filepath"

	"github.com/pkg/errors"
	"gorm.io/gorm"

	"gogs.io/gogs/internal/conf"
	"gogs.io/gogs/internal/osutil"
)
  13. // PublicKeysStore is the storage layer for public keys.
  14. type PublicKeysStore struct {
  15. db *gorm.DB
  16. }
  17. func newPublicKeysStore(db *gorm.DB) *PublicKeysStore {
  18. return &PublicKeysStore{db: db}
  19. }
  20. func authorizedKeysPath() string {
  21. return filepath.Join(conf.SSH.RootPath, "authorized_keys")
  22. }
  23. // RewriteAuthorizedKeys rewrites the "authorized_keys" file under the SSH root
  24. // path with all public keys stored in the database.
  25. func (s *PublicKeysStore) RewriteAuthorizedKeys() error {
  26. sshOpLocker.Lock()
  27. defer sshOpLocker.Unlock()
  28. err := os.MkdirAll(conf.SSH.RootPath, os.ModePerm)
  29. if err != nil {
  30. return errors.Wrap(err, "create SSH root path")
  31. }
  32. fpath := authorizedKeysPath()
  33. tempPath := fpath + ".tmp"
  34. f, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
  35. if err != nil {
  36. return errors.Wrap(err, "create temporary file")
  37. }
  38. defer func() {
  39. _ = f.Close()
  40. _ = os.Remove(tempPath)
  41. }()
  42. // NOTE: More recently updated keys are more likely to be used more frequently,
  43. // putting them in the earlier lines could speed up the key lookup by SSHD.
  44. rows, err := s.db.Model(&PublicKey{}).Order("updated_unix DESC").Rows()
  45. if err != nil {
  46. return errors.Wrap(err, "iterate public keys")
  47. }
  48. defer func() { _ = rows.Close() }()
  49. for rows.Next() {
  50. var key PublicKey
  51. err = s.db.ScanRows(rows, &key)
  52. if err != nil {
  53. return errors.Wrap(err, "scan rows")
  54. }
  55. _, err = f.WriteString(key.AuthorizedString())
  56. if err != nil {
  57. return errors.Wrapf(err, "write key %d", key.ID)
  58. }
  59. }
  60. if err = rows.Err(); err != nil {
  61. return errors.Wrap(err, "check rows.Err")
  62. }
  63. err = f.Close()
  64. if err != nil {
  65. return errors.Wrap(err, "close temporary file")
  66. }
  67. if osutil.IsExist(fpath) {
  68. err = os.Remove(fpath)
  69. if err != nil {
  70. return errors.Wrap(err, "remove")
  71. }
  72. }
  73. err = os.Rename(tempPath, fpath)
  74. if err != nil {
  75. return errors.Wrap(err, "rename")
  76. }
  77. return nil
  78. }