mirror of
https://github.com/mjl-/mox.git
synced 2024-12-27 17:03:47 +03:00
d1b87cdb0d
since we are now at go1.21 as minimum.
320 lines
9.6 KiB
Go
320 lines
9.6 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
cryptorand "crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"runtime/debug"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/mjl-/bstore"
|
|
|
|
"github.com/mjl-/mox/metrics"
|
|
"github.com/mjl-/mox/mlog"
|
|
"github.com/mjl-/mox/mox-"
|
|
)
|
|
|
|
// Limits and timing for login sessions.
const (
	// Maximum number of sessions kept per account. When an account already has
	// this many sessions, adding another first removes expired sessions and, if
	// still needed, the oldest one.
	sessionsPerAccount = 100

	// How long a session stays valid. Extended automatically by use.
	sessionLifetime = 24 * time.Hour

	// Delay before extended expirations are written to the database. Per
	// account, for coalescing writes.
	sessionWriteDelay = 5 * time.Minute
)
|
|
|
|
var sessions = struct {
|
|
sync.Mutex
|
|
|
|
// For each account, we keep all sessions (with fixed maximum number) in memory. If
|
|
// the map for an account is nil, it is initialized from the database on first use.
|
|
accounts map[string]map[SessionToken]LoginSession
|
|
|
|
// We flush sessions with extended expiration timestamp to disk with a delay, to
|
|
// coalesce potentially many changes. The delay is short enough that we don't have
|
|
// to care about flushing to disk on shutdown.
|
|
pendingFlushes map[string]map[SessionToken]struct{}
|
|
}{
|
|
accounts: map[string]map[SessionToken]LoginSession{},
|
|
pendingFlushes: map[string]map[SessionToken]struct{}{},
|
|
}
|
|
|
|
// Ensure sessions for account are initialized from database. If the sessions were
|
|
// initialized from the database, or when alwaysOpenAccount is true, an open
|
|
// account is returned (assuming no error occurred).
|
|
//
|
|
// must be called with sessions lock held.
|
|
func ensureAccountSessions(ctx context.Context, log mlog.Log, accountName string, alwaysOpenAccount bool) (*Account, error) {
|
|
var acc *Account
|
|
accSessions := sessions.accounts[accountName]
|
|
if accSessions == nil {
|
|
var err error
|
|
acc, err = OpenAccount(log, accountName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// We still hold the lock, not great...
|
|
|
|
accSessions = map[SessionToken]LoginSession{}
|
|
err = bstore.QueryDB[LoginSession](ctx, acc.DB).ForEach(func(ls LoginSession) error {
|
|
// We keep strings around for easy comparison.
|
|
ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
|
|
ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
|
|
|
|
accSessions[ls.sessionToken] = ls
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sessions.accounts[accountName] = accSessions
|
|
}
|
|
if acc == nil && alwaysOpenAccount {
|
|
return OpenAccount(log, accountName)
|
|
}
|
|
return acc, nil
|
|
}
|
|
|
|
// SessionUse checks if a session is valid. If csrfToken is the empty string, no
|
|
// CSRF check is done. Otherwise it must be the csrf token associated with the
|
|
// session token.
|
|
func SessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
|
|
sessions.Lock()
|
|
defer sessions.Unlock()
|
|
|
|
acc, err := ensureAccountSessions(ctx, log, accountName, false)
|
|
if err != nil {
|
|
return LoginSession{}, err
|
|
} else if acc != nil {
|
|
if err := acc.Close(); err != nil {
|
|
return LoginSession{}, fmt.Errorf("closing account: %w", err)
|
|
}
|
|
}
|
|
|
|
return sessionUse(ctx, log, accountName, sessionToken, csrfToken)
|
|
}
|
|
|
|
// must be called with sessions lock held.
|
|
func sessionUse(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken, csrfToken CSRFToken) (LoginSession, error) {
|
|
// Check if valid.
|
|
ls, ok := sessions.accounts[accountName][sessionToken]
|
|
if !ok {
|
|
return LoginSession{}, fmt.Errorf("unknown session token")
|
|
} else if time.Until(ls.Expires) < 0 {
|
|
return LoginSession{}, fmt.Errorf("session expired")
|
|
} else if csrfToken != "" && csrfToken != ls.csrfToken {
|
|
return LoginSession{}, fmt.Errorf("mismatch between csrf and session tokens")
|
|
}
|
|
|
|
// Extend lifetime.
|
|
ls.Expires = time.Now().Add(sessionLifetime)
|
|
sessions.accounts[accountName][sessionToken] = ls
|
|
|
|
// If we haven't scheduled a flush to database yet, schedule one now.
|
|
if sessions.pendingFlushes[accountName] == nil {
|
|
sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
|
|
go func() {
|
|
pkglog := mlog.New("store", nil)
|
|
|
|
defer func() {
|
|
x := recover()
|
|
if x != nil {
|
|
pkglog.Error("recover from panic", slog.Any("panic", x))
|
|
debug.PrintStack()
|
|
metrics.PanicInc(metrics.Store)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(sessionWriteDelay)
|
|
sessionsDelayedFlush(pkglog, accountName)
|
|
}()
|
|
}
|
|
sessions.pendingFlushes[accountName][ls.sessionToken] = struct{}{}
|
|
|
|
return ls, nil
|
|
}
|
|
|
|
// wait, then flush all changed sessions for an account.
|
|
func sessionsDelayedFlush(log mlog.Log, accountName string) {
|
|
sessions.Lock()
|
|
defer sessions.Unlock()
|
|
|
|
sessionTokens := sessions.pendingFlushes[accountName]
|
|
delete(sessions.pendingFlushes, accountName)
|
|
|
|
_, ok := sessions.accounts[accountName]
|
|
if !ok {
|
|
// Account may have been removed. Nothing to flush.
|
|
return
|
|
}
|
|
|
|
acc, err := OpenAccount(log, accountName)
|
|
if err != nil && errors.Is(err, ErrAccountUnknown) {
|
|
// Account may have been removed. Nothing to flush.
|
|
log.Infox("flushing sessions for account", err, slog.String("account", accountName))
|
|
return
|
|
}
|
|
if err != nil {
|
|
log.Errorx("open account for flushing changed session tokens", err, slog.String("account", accountName))
|
|
return
|
|
}
|
|
defer func() {
|
|
err := acc.Close()
|
|
log.Check(err, "closing account")
|
|
}()
|
|
|
|
err = acc.DB.Write(mox.Context, func(tx *bstore.Tx) error {
|
|
for sessionToken := range sessionTokens {
|
|
ls, ok := sessions.accounts[accountName][sessionToken]
|
|
if !ok {
|
|
return fmt.Errorf("unknown session token to flush")
|
|
}
|
|
if err := tx.Update(&ls); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
log.Check(err, "flushing changed sessions for account", slog.String("account", accountName))
|
|
}
|
|
|
|
// SessionAddTokens adds a prepared or pre-existing LoginSession to the database and
|
|
// cache. Can be used to restore a session token that was used to reset a password.
|
|
func SessionAddToken(ctx context.Context, log mlog.Log, ls *LoginSession) error {
|
|
sessions.Lock()
|
|
defer sessions.Unlock()
|
|
|
|
acc, err := ensureAccountSessions(ctx, log, ls.AccountName, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
err := acc.Close()
|
|
log.Check(err, "closing account after adding session token")
|
|
}()
|
|
|
|
return sessionAddToken(ctx, log, acc, ls)
|
|
}
|
|
|
|
// caller must hold sessions lock.
|
|
func sessionAddToken(ctx context.Context, log mlog.Log, acc *Account, ls *LoginSession) error {
|
|
ls.ID = 0
|
|
|
|
err := acc.DB.Write(ctx, func(tx *bstore.Tx) error {
|
|
// Remove sessions if we have too many, starting with expired sessions, and
|
|
// removing the oldest if needed.
|
|
if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
|
|
var oldest LoginSession
|
|
for _, ols := range sessions.accounts[ls.AccountName] {
|
|
if time.Until(ols.Expires) < 0 {
|
|
if err := tx.Delete(&ols); err != nil {
|
|
return err
|
|
}
|
|
delete(sessions.accounts[ls.AccountName], ols.sessionToken)
|
|
continue
|
|
}
|
|
if oldest.ID == 0 || ols.Expires.Before(oldest.Expires) {
|
|
oldest = ols
|
|
}
|
|
}
|
|
if len(sessions.accounts[ls.AccountName]) >= sessionsPerAccount {
|
|
if err := tx.Delete(&oldest); err != nil {
|
|
return err
|
|
}
|
|
delete(sessions.accounts[ls.AccountName], oldest.sessionToken)
|
|
}
|
|
}
|
|
|
|
if err := tx.Insert(ls); err != nil {
|
|
return fmt.Errorf("insert: %v", err)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sessions.accounts[ls.AccountName][ls.sessionToken] = *ls
|
|
return nil
|
|
}
|
|
|
|
// SessionAdd creates a new session token, with csrf token, and adds it to the
|
|
// database and in-memory session cache. If there are too many sessions, the oldest
|
|
// is removed.
|
|
func SessionAdd(ctx context.Context, log mlog.Log, accountName, loginAddress string) (session SessionToken, csrf CSRFToken, rerr error) {
|
|
// Prepare new LoginSession.
|
|
ls := LoginSession{0, time.Time{}, time.Now().Add(sessionLifetime), [16]byte{}, [16]byte{}, accountName, loginAddress, "", ""}
|
|
if _, err := cryptorand.Read(ls.SessionTokenBinary[:]); err != nil {
|
|
return "", "", err
|
|
}
|
|
if _, err := cryptorand.Read(ls.CSRFTokenBinary[:]); err != nil {
|
|
return "", "", err
|
|
}
|
|
ls.sessionToken = SessionToken(base64.RawURLEncoding.EncodeToString(ls.SessionTokenBinary[:]))
|
|
ls.csrfToken = CSRFToken(base64.RawURLEncoding.EncodeToString(ls.CSRFTokenBinary[:]))
|
|
|
|
sessions.Lock()
|
|
defer sessions.Unlock()
|
|
|
|
acc, err := ensureAccountSessions(ctx, log, accountName, true)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
defer func() {
|
|
err := acc.Close()
|
|
log.Check(err, "closing account")
|
|
}()
|
|
|
|
if err := sessionAddToken(ctx, log, acc, &ls); err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
return ls.sessionToken, ls.csrfToken, nil
|
|
}
|
|
|
|
// SessionRemove removes a session from the database and in-memory cache. Future
|
|
// operations using the session token will fail.
|
|
func SessionRemove(ctx context.Context, log mlog.Log, accountName string, sessionToken SessionToken) error {
|
|
sessions.Lock()
|
|
defer sessions.Unlock()
|
|
|
|
acc, err := ensureAccountSessions(ctx, log, accountName, true)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer acc.Close()
|
|
|
|
ls, ok := sessions.accounts[accountName][sessionToken]
|
|
if !ok {
|
|
return fmt.Errorf("unknown session token")
|
|
}
|
|
|
|
if err := acc.DB.Delete(ctx, &ls); err != nil {
|
|
return err
|
|
}
|
|
|
|
delete(sessions.accounts[accountName], sessionToken)
|
|
pf := sessions.pendingFlushes[accountName]
|
|
if pf != nil {
|
|
delete(pf, sessionToken)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// sessionRemoveAll removes all session tokens for an account. Useful after a password reset.
|
|
func sessionRemoveAll(ctx context.Context, log mlog.Log, tx *bstore.Tx, accountName string) error {
|
|
sessions.Lock()
|
|
defer sessions.Unlock()
|
|
|
|
if _, err := bstore.QueryTx[LoginSession](tx).Delete(); err != nil {
|
|
return err
|
|
}
|
|
|
|
sessions.accounts[accountName] = map[SessionToken]LoginSession{}
|
|
if sessions.pendingFlushes[accountName] != nil {
|
|
sessions.pendingFlushes[accountName] = map[SessionToken]struct{}{}
|
|
}
|
|
|
|
return nil
|
|
}
|