diff --git a/backup.go b/backup.go index 9d7bfdd..eaa18d1 100644 --- a/backup.go +++ b/backup.go @@ -313,7 +313,8 @@ func backupctl(ctx context.Context, ctl *ctl) { } dstdbpath := filepath.Join(dstDataDir, path) - db, err := bstore.Open(ctx, dstdbpath, &bstore.Options{MustExist: true}, queue.DBTypes...) + opts := bstore.Options{MustExist: true, RegisterLogger: ctl.log.Logger} + db, err := bstore.Open(ctx, dstdbpath, &opts, queue.DBTypes...) if err != nil { xerrx("open copied queue database", err, slog.String("dstpath", dstdbpath), slog.Duration("duration", time.Since(tmQueue))) return @@ -419,7 +420,8 @@ func backupctl(ctx context.Context, ctl *ctl) { } dstdbpath := filepath.Join(dstDataDir, dbpath) - db, err := bstore.Open(ctx, dstdbpath, &bstore.Options{MustExist: true}, store.DBTypes...) + opts := bstore.Options{MustExist: true, RegisterLogger: ctl.log.Logger} + db, err := bstore.Open(ctx, dstdbpath, &opts, store.DBTypes...) if err != nil { xerrx("open copied account database", err, slog.String("dstpath", dstdbpath), slog.Duration("duration", time.Since(tmAccount))) return diff --git a/dmarcdb/dmarcdb.go b/dmarcdb/dmarcdb.go index 76b957d..6d073d3 100644 --- a/dmarcdb/dmarcdb.go +++ b/dmarcdb/dmarcdb.go @@ -11,6 +11,15 @@ package dmarcdb import ( + "context" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/mjl-/bstore" + + "github.com/mjl-/mox/mlog" "github.com/mjl-/mox/mox-" ) @@ -19,11 +28,49 @@ import ( // The incoming reports and evaluations for outgoing reports are in separate // databases for simpler file-based handling of the databases. func Init() error { - if _, err := reportsDB(mox.Shutdown); err != nil { - return err + if ReportsDB != nil || EvalDB != nil { + return fmt.Errorf("already initialized") } - if _, err := evalDB(mox.Shutdown); err != nil { - return err + + log := mlog.New("dmarcdb", nil) + var err error + + ReportsDB, err = openReportsDB(mox.Shutdown, log) + if err != nil { + return fmt.Errorf("open reports db: %v", err) } + + EvalDB, err = openEvalDB(mox.Shutdown, log) + if err != nil { + return fmt.Errorf("open eval db: %v", err) + } + return nil } + +func Close() error { + if err := ReportsDB.Close(); err != nil { + return fmt.Errorf("closing reports db: %w", err) + } + ReportsDB = nil + + if err := EvalDB.Close(); err != nil { + return fmt.Errorf("closing eval db: %w", err) + } + EvalDB = nil + return nil +} + +func openReportsDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) { + p := mox.DataDirPath("dmarcrpt.db") + os.MkdirAll(filepath.Dir(p), 0770) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + return bstore.Open(ctx, p, &opts, ReportsDBTypes...) +} + +func openEvalDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) { + p := mox.DataDirPath("dmarceval.db") + os.MkdirAll(filepath.Dir(p), 0770) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + return bstore.Open(ctx, p, &opts, EvalDBTypes...) +} diff --git a/dmarcdb/eval.go b/dmarcdb/eval.go index 5874089..9b893e1 100644 --- a/dmarcdb/eval.go +++ b/dmarcdb/eval.go @@ -15,7 +15,6 @@ import ( "net/textproto" "net/url" "os" - "path/filepath" "runtime/debug" "slices" "sort" @@ -66,8 +65,7 @@ var ( // Exported for backups. For incoming deliveries the SMTP server adds evaluations // to the database. Every hour, a goroutine wakes up that gathers evaluations from // the last hour(s), sends a report, and removes the evaluations from the database. - EvalDB *bstore.DB - evalMutex sync.Mutex + EvalDB *bstore.DB ) // Evaluation is the result of an evaluation of a DMARC policy, to be included @@ -162,21 +160,6 @@ func (e Evaluation) ReportRecord(count int) dmarcrpt.ReportRecord { } } -func evalDB(ctx context.Context) (rdb *bstore.DB, rerr error) { - evalMutex.Lock() - defer evalMutex.Unlock() - if EvalDB == nil { - p := mox.DataDirPath("dmarceval.db") - os.MkdirAll(filepath.Dir(p), 0770) - db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, EvalDBTypes...) - if err != nil { - return nil, err - } - EvalDB = db - } - return EvalDB, nil -} - var intervalOpts = []int{24, 12, 8, 6, 4, 3, 2} func intervalHours(seconds int) int { @@ -197,23 +180,13 @@ func intervalHours(seconds int) int { func AddEvaluation(ctx context.Context, aggregateReportingIntervalSeconds int, e *Evaluation) error { e.IntervalHours = intervalHours(aggregateReportingIntervalSeconds) - db, err := evalDB(ctx) - if err != nil { - return err - } - e.ID = 0 - return db.Insert(ctx, e) + return EvalDB.Insert(ctx, e) } // Evaluations returns all evaluations in the database. func Evaluations(ctx context.Context) ([]Evaluation, error) { - db, err := evalDB(ctx) - if err != nil { - return nil, err - } - - q := bstore.QueryDB[Evaluation](ctx, db) + q := bstore.QueryDB[Evaluation](ctx, EvalDB) q.SortAsc("Evaluated") return q.List() } @@ -229,14 +202,9 @@ type EvaluationStat struct { // EvaluationStats returns evaluation counts and report-sending status per domain. func EvaluationStats(ctx context.Context) (map[string]EvaluationStat, error) { - db, err := evalDB(ctx) - if err != nil { - return nil, err - } - r := map[string]EvaluationStat{} - err = bstore.QueryDB[Evaluation](ctx, db).ForEach(func(e Evaluation) error { + err := bstore.QueryDB[Evaluation](ctx, EvalDB).ForEach(func(e Evaluation) error { if stat, ok := r[e.PolicyDomain]; ok { if !slices.Contains(stat.Dispositions, string(e.Disposition)) { stat.Dispositions = append(stat.Dispositions, string(e.Disposition)) @@ -263,12 +231,7 @@ func EvaluationStats(ctx context.Context) (map[string]EvaluationStat, error) { // EvaluationsDomain returns all evaluations for a domain. func EvaluationsDomain(ctx context.Context, domain dns.Domain) ([]Evaluation, error) { - db, err := evalDB(ctx) - if err != nil { - return nil, err - } - - q := bstore.QueryDB[Evaluation](ctx, db) + q := bstore.QueryDB[Evaluation](ctx, EvalDB) q.FilterNonzero(Evaluation{PolicyDomain: domain.Name()}) q.SortAsc("Evaluated") return q.List() @@ -277,14 +240,9 @@ func EvaluationsDomain(ctx context.Context, domain dns.Domain) ([]Evaluation, er // RemoveEvaluationsDomain removes evaluations for domain so they won't be sent in // an aggregate report. func RemoveEvaluationsDomain(ctx context.Context, domain dns.Domain) error { - db, err := evalDB(ctx) - if err != nil { - return err - } - - q := bstore.QueryDB[Evaluation](ctx, db) + q := bstore.QueryDB[Evaluation](ctx, EvalDB) q.FilterNonzero(Evaluation{PolicyDomain: domain.Name()}) - _, err = q.Delete() + _, err := q.Delete() return err } @@ -318,12 +276,6 @@ func Start(resolver dns.Resolver) { ctx := mox.Shutdown - db, err := evalDB(ctx) - if err != nil { - log.Errorx("opening dmarc evaluations database for sending dmarc aggregate reports, not sending reports", err) - return - } - for { now := time.Now() nextEnd := nextWholeHour(now) @@ -355,12 +307,12 @@ func Start(resolver dns.Resolver) { // 24 hour interval). They should have been processed by now. We may have kept them // during temporary errors, but persistent temporary errors shouldn't fill up our // database. This also cleans up evaluations that were all optional for a domain. - _, err := bstore.QueryDB[Evaluation](ctx, db).FilterLess("Evaluated", nextEnd.Add(-48*time.Hour)).Delete() + _, err := bstore.QueryDB[Evaluation](ctx, EvalDB).FilterLess("Evaluated", nextEnd.Add(-48*time.Hour)).Delete() log.Check(err, "removing stale dmarc evaluations from database") clog := log.WithCid(mox.Cid()) clog.Info("sending dmarc aggregate reports", slog.Time("end", nextEnd.UTC()), slog.Any("intervals", intervals)) - if err := sendReports(ctx, clog, resolver, db, nextEnd, intervals); err != nil { + if err := sendReports(ctx, clog, resolver, EvalDB, nextEnd, intervals); err != nil { clog.Errorx("sending dmarc aggregate reports", err) metricReportError.Inc() } else { @@ -1091,46 +1043,26 @@ func dkimSign(ctx context.Context, log mlog.Log, fromAddr smtp.Address, smtputf8 // SuppressAdd adds an address to the suppress list. func SuppressAdd(ctx context.Context, ba *SuppressAddress) error { - db, err := evalDB(ctx) - if err != nil { - return err - } - - return db.Insert(ctx, ba) + return EvalDB.Insert(ctx, ba) } // SuppressList returns all reporting addresses on the suppress list. func SuppressList(ctx context.Context) ([]SuppressAddress, error) { - db, err := evalDB(ctx) - if err != nil { - return nil, err - } - - return bstore.QueryDB[SuppressAddress](ctx, db).SortDesc("ID").List() + return bstore.QueryDB[SuppressAddress](ctx, EvalDB).SortDesc("ID").List() } // SuppressRemove removes a reporting address record from the suppress list. func SuppressRemove(ctx context.Context, id int64) error { - db, err := evalDB(ctx) - if err != nil { - return err - } - - return db.Delete(ctx, &SuppressAddress{ID: id}) + return EvalDB.Delete(ctx, &SuppressAddress{ID: id}) } // SuppressUpdate updates the until field of a reporting address record. func SuppressUpdate(ctx context.Context, id int64, until time.Time) error { - db, err := evalDB(ctx) - if err != nil { - return err - } - ba := SuppressAddress{ID: id} - err = db.Get(ctx, &ba) + err := EvalDB.Get(ctx, &ba) if err != nil { return err } ba.Until = until - return db.Update(ctx, &ba) + return EvalDB.Update(ctx, &ba) } diff --git a/dmarcdb/eval_test.go b/dmarcdb/eval_test.go index 2db8c9a..bb963c1 100644 --- a/dmarcdb/eval_test.go +++ b/dmarcdb/eval_test.go @@ -41,13 +41,13 @@ func TestEvaluations(t *testing.T) { mox.Context = ctxbg mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf") mox.MustLoadConfig(true, false) - EvalDB = nil - _, err := evalDB(ctxbg) - tcheckf(t, err, "database") + os.Remove(mox.DataDirPath("dmarceval.db")) + err := Init() + tcheckf(t, err, "init") defer func() { - EvalDB.Close() - EvalDB = nil + err := Close() + tcheckf(t, err, "close") }() parseJSON := func(s string) (e Evaluation) { @@ -163,13 +163,13 @@ func TestSendReports(t *testing.T) { mox.Context = ctxbg mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf") mox.MustLoadConfig(true, false) - EvalDB = nil - db, err := evalDB(ctxbg) - tcheckf(t, err, "database") + os.Remove(mox.DataDirPath("dmarceval.db")) + err := Init() + tcheckf(t, err, "init") defer func() { - EvalDB.Close() - EvalDB = nil + err := Close() + tcheckf(t, err, "close") }() resolver := dns.MockResolver{ @@ -288,7 +288,7 @@ func TestSendReports(t *testing.T) { mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg) for _, e := range evals { - err := db.Insert(ctxbg, &e) + err := EvalDB.Insert(ctxbg, &e) tcheckf(t, err, "inserting evaluation") } @@ -359,13 +359,13 @@ func TestSendReports(t *testing.T) { // Address is suppressed. sa := SuppressAddress{ReportingAddress: "dmarcrpt@sender.example", Until: time.Now().Add(time.Minute)} - err = db.Insert(ctxbg, &sa) + err = EvalDB.Insert(ctxbg, &sa) tcheckf(t, err, "insert suppress address") test([]Evaluation{eval}, map[string]struct{}{}, map[string]struct{}{}, nil) // Suppression has expired. sa.Until = time.Now().Add(-time.Minute) - err = db.Update(ctxbg, &sa) + err = EvalDB.Update(ctxbg, &sa) tcheckf(t, err, "update suppress address") test([]Evaluation{eval}, map[string]struct{}{"dmarcrpt@sender.example": {}}, map[string]struct{}{}, expFeedback) diff --git a/dmarcdb/reports.go b/dmarcdb/reports.go index f7281ba..2746051 100644 --- a/dmarcdb/reports.go +++ b/dmarcdb/reports.go @@ -3,9 +3,6 @@ package dmarcdb import ( "context" "fmt" - "os" - "path/filepath" - "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -15,13 +12,11 @@ import ( "github.com/mjl-/mox/dmarcrpt" "github.com/mjl-/mox/dns" - "github.com/mjl-/mox/mox-" ) var ( ReportsDBTypes = []any{DomainFeedback{}} // Types stored in DB. ReportsDB *bstore.DB // Exported for backups. - reportsMutex sync.Mutex ) var ( @@ -59,38 +54,18 @@ type DomainFeedback struct { dmarcrpt.Feedback } -func reportsDB(ctx context.Context) (rdb *bstore.DB, rerr error) { - reportsMutex.Lock() - defer reportsMutex.Unlock() - if ReportsDB == nil { - p := mox.DataDirPath("dmarcrpt.db") - os.MkdirAll(filepath.Dir(p), 0770) - db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, ReportsDBTypes...) - if err != nil { - return nil, err - } - ReportsDB = db - } - return ReportsDB, nil -} - // AddReport adds a DMARC aggregate feedback report from an email to the database, // and updates prometheus metrics. // // fromDomain is the domain in the report message From header. func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain) error { - db, err := reportsDB(ctx) - if err != nil { - return err - } - d, err := dns.ParseDomain(f.PolicyPublished.Domain) if err != nil { return fmt.Errorf("parsing domain in report: %v", err) } df := DomainFeedback{0, d.Name(), fromDomain.Name(), *f} - if err := db.Insert(ctx, &df); err != nil { + if err := ReportsDB.Insert(ctx, &df); err != nil { return err } @@ -129,38 +104,23 @@ func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain) // Records returns all reports in the database. func Records(ctx context.Context) ([]DomainFeedback, error) { - db, err := reportsDB(ctx) - if err != nil { - return nil, err - } - - return bstore.QueryDB[DomainFeedback](ctx, db).List() + return bstore.QueryDB[DomainFeedback](ctx, ReportsDB).List() } // RecordID returns the report for the ID. func RecordID(ctx context.Context, id int64) (DomainFeedback, error) { - db, err := reportsDB(ctx) - if err != nil { - return DomainFeedback{}, err - } - e := DomainFeedback{ID: id} - err = db.Get(ctx, &e) + err := ReportsDB.Get(ctx, &e) return e, err } // RecordsPeriodDomain returns the reports overlapping start and end, for the given // domain. If domain is empty, all records match for domain. func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]DomainFeedback, error) { - db, err := reportsDB(ctx) - if err != nil { - return nil, err - } - s := start.Unix() e := end.Unix() - q := bstore.QueryDB[DomainFeedback](ctx, db) + q := bstore.QueryDB[DomainFeedback](ctx, ReportsDB) if domain != "" { q.FilterNonzero(DomainFeedback{Domain: domain}) } diff --git a/dmarcdb/reports_test.go b/dmarcdb/reports_test.go index 1e86a39..16f43c7 100644 --- a/dmarcdb/reports_test.go +++ b/dmarcdb/reports_test.go @@ -20,16 +20,12 @@ func TestDMARCDB(t *testing.T) { mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf") mox.MustLoadConfig(true, false) - dbpath := mox.DataDirPath("dmarcrpt.db") - os.MkdirAll(filepath.Dir(dbpath), 0770) - - if err := Init(); err != nil { - t.Fatalf("init database: %s", err) - } - defer os.Remove(dbpath) + os.Remove(mox.DataDirPath("dmarcrpt.db")) + err := Init() + tcheckf(t, err, "init") defer func() { - ReportsDB.Close() - ReportsDB = nil + err := Close() + tcheckf(t, err, "close") }() feedback := &dmarcrpt.Feedback{ diff --git a/export.go b/export.go index 1be5c5f..1e250b4 100644 --- a/export.go +++ b/export.go @@ -62,7 +62,8 @@ func xcmdExport(mbox, single bool, args []string, c *cmd) { } dbpath := filepath.Join(accountDir, "index.db") - db, err := bstore.Open(context.Background(), dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, store.DBTypes...) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: c.log.Logger} + db, err := bstore.Open(context.Background(), dbpath, &opts, store.DBTypes...) xcheckf(err, "open database %q", dbpath) defer func() { if err := db.Close(); err != nil { diff --git a/go.mod b/go.mod index 7a71dd9..54cf64b 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21 require ( github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 - github.com/mjl-/bstore v0.0.5 + github.com/mjl-/bstore v0.0.6 github.com/mjl-/sconf v0.0.6 github.com/mjl-/sherpa v0.6.7 github.com/mjl-/sherpadoc v0.0.16 diff --git a/go.sum b/go.sum index 3fccb33..5d0624a 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af h1:sEDWZPIi5K1qKk7JQoAZy github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af/go.mod h1:v47qUMJnipnmDTRGaHwpCwzE6oypa5K33mUvBfzZBn8= github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 h1:s6ay4bh4tmpPLdxjyeWG45mcwHfEluBMuGPkqxHWUJ4= github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05/go.mod h1:taMFU86abMxKLPV4Bynhv8enbYmS67b8LG80qZv2Qus= -github.com/mjl-/bstore v0.0.5 h1:Cx+LWEBnFBsqSxZNMxeVujkfc0kG10lUJaAU4vWSRHo= -github.com/mjl-/bstore v0.0.5/go.mod h1:/cD25FNBaDfvL/plFRxI3Ba3E+wcB0XVOS8nJDqndg0= +github.com/mjl-/bstore v0.0.6 h1:ntlu9MkfCkpm2XfBY4+Ws4KK9YzXzewr3+lCueFB+9c= +github.com/mjl-/bstore v0.0.6/go.mod h1:/cD25FNBaDfvL/plFRxI3Ba3E+wcB0XVOS8nJDqndg0= github.com/mjl-/sconf v0.0.6 h1:5Dt58488ZOoVx680zgK2K3vUrokLsp5mXDUACrJlrUc= github.com/mjl-/sconf v0.0.6/go.mod h1:uF8OdWtLT8La3i4ln176i1pB0ps9pXGCaABEU55ZkE0= github.com/mjl-/sherpa v0.6.7 h1:C5F8XQdV5nCuS4fvB+ye/ziUQrajEhOoj/t2w5T14BY= diff --git a/junk/filter.go b/junk/filter.go index 1de8059..48f8ca3 100644 --- a/junk/filter.go +++ b/junk/filter.go @@ -125,7 +125,7 @@ func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomP } } - db, err := openDB(ctx, dbPath) + db, err := openDB(ctx, log, dbPath) if err != nil { return nil, fmt.Errorf("open database: %s", err) } @@ -230,18 +230,20 @@ func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr } }() - db, err := bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + db, err := bstore.Open(ctx, path, &opts, DBTypes...) if err != nil { return nil, fmt.Errorf("open new database: %w", err) } return db, nil } -func openDB(ctx context.Context, path string) (*bstore.DB, error) { +func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) { if _, err := os.Stat(path); err != nil { return nil, fmt.Errorf("stat db file: %w", err) } - return bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + return bstore.Open(ctx, path, &opts, DBTypes...) } // Save stores modifications, e.g. from training, to the database and bloom diff --git a/mtastsdb/db.go b/mtastsdb/db.go index 9a13cd8..928d8d7 100644 --- a/mtastsdb/db.go +++ b/mtastsdb/db.go @@ -14,7 +14,6 @@ import ( "os" "path/filepath" "strings" - "sync" "time" "github.com/prometheus/client_golang/prometheus" @@ -68,27 +67,17 @@ var ( var DBTypes = []any{PolicyRecord{}} // Types stored in DB. var DB *bstore.DB // Exported for backups. -var mutex sync.Mutex - -func database(ctx context.Context) (rdb *bstore.DB, rerr error) { - mutex.Lock() - defer mutex.Unlock() - if DB == nil { - p := mox.DataDirPath("mtasts.db") - os.MkdirAll(filepath.Dir(p), 0770) - db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) - if err != nil { - return nil, err - } - DB = db - } - return DB, nil -} // Init opens the database and starts a goroutine that refreshes policies in // the database, and keeps doing so periodically. func Init(refresher bool) error { - _, err := database(mox.Shutdown) + log := mlog.New("mtastsdb", nil) + + p := mox.DataDirPath("mtasts.db") + os.MkdirAll(filepath.Dir(p), 0770) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + var err error + DB, err = bstore.Open(mox.Shutdown, p, &opts, DBTypes...) if err != nil { return err } @@ -102,14 +91,12 @@ func Init(refresher bool) error { } // Close closes the database. -func Close() { - mutex.Lock() - defer mutex.Unlock() - if DB != nil { - err := DB.Close() - mlog.New("mtastsdb", nil).Check(err, "closing database") - DB = nil +func Close() error { + if err := DB.Close(); err != nil { + return fmt.Errorf("close db: %w", err) } + DB = nil + return nil } // lookup looks up a policy for the domain in the database. @@ -119,16 +106,11 @@ func Close() { // Returns ErrNotFound if record is not present. // Returns ErrBackoff if a recent attempt to fetch a record failed. func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord, error) { - db, err := database(ctx) - if err != nil { - return nil, err - } - if domain.IsZero() { return nil, fmt.Errorf("empty domain") } now := timeNow() - q := bstore.QueryDB[PolicyRecord](ctx, db) + q := bstore.QueryDB[PolicyRecord](ctx, DB) q.FilterNonzero(PolicyRecord{Domain: domain.Name()}) q.FilterGreater("ValidEnd", now) pr, err := q.Get() @@ -139,7 +121,7 @@ func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord } pr.LastUse = now - if err := db.Update(ctx, &pr); err != nil { + if err := DB.Update(ctx, &pr); err != nil { log.Errorx("marking cached mta-sts policy as used in database", err) } if pr.Backoff { @@ -151,12 +133,7 @@ func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord // Upsert adds the policy to the database, overwriting an existing policy for the domain. // Policy can be nil, indicating a failure to fetch the policy. func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy, policyText string) error { - db, err := database(ctx) - if err != nil { - return err - } - - return db.Write(ctx, func(tx *bstore.Tx) error { + return DB.Write(ctx, func(tx *bstore.Tx) error { pr := PolicyRecord{Domain: domain.Name()} err := tx.Get(&pr) if err != nil && err != bstore.ErrAbsent { @@ -195,11 +172,7 @@ func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mta // PolicyRecords returns all policies in the database, sorted descending by last // use, domain. func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) { - db, err := database(ctx) - if err != nil { - return nil, err - } - return bstore.QueryDB[PolicyRecord](ctx, db).SortDesc("LastUse", "Domain").List() + return bstore.QueryDB[PolicyRecord](ctx, DB).SortDesc("LastUse", "Domain").List() } // Get retrieves an MTA-STS policy for domain and whether it is fresh. diff --git a/mtastsdb/refresh.go b/mtastsdb/refresh.go index d61df91..d4c7950 100644 --- a/mtastsdb/refresh.go +++ b/mtastsdb/refresh.go @@ -51,19 +51,14 @@ func refresh() int { // jitter to the timing. Each refresh is done in a new goroutine, so a single slow // refresh doesn't mess up the timing. func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) { - db, err := database(ctx) - if err != nil { - return 0, err - } - now := timeNow() - qdel := bstore.QueryDB[PolicyRecord](ctx, db) + qdel := bstore.QueryDB[PolicyRecord](ctx, DB) qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour)) if _, err := qdel.Delete(); err != nil { return 0, fmt.Errorf("deleting old unused policies: %s", err) } - qup := bstore.QueryDB[PolicyRecord](ctx, db) + qup := bstore.QueryDB[PolicyRecord](ctx, DB) qup.FilterLess("LastUpdate", now.Add(-12*time.Hour)) prs, err := qup.List() if err != nil { @@ -89,7 +84,7 @@ func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep fu log.Debug("will refresh mta-sts policies over next 3 hours", slog.Int("count", len(prs))) start := timeNow() for i, pr := range prs { - go refreshDomain(ctx, log, db, resolver, pr) + go refreshDomain(ctx, log, DB, resolver, pr) if i < len(prs)-1 { interval := 3 * int64(time.Hour) / int64(len(prs)-1) extra := time.Duration(rand.Int63n(interval) - interval/2) diff --git a/mtastsdb/refresh_test.go b/mtastsdb/refresh_test.go index 99428f3..64a353a 100644 --- a/mtastsdb/refresh_test.go +++ b/mtastsdb/refresh_test.go @@ -9,7 +9,7 @@ import ( "errors" "fmt" "io" - "log" + golog "log" "math/big" "net" "net/http" @@ -39,15 +39,14 @@ func TestRefresh(t *testing.T) { os.Remove(dbpath) defer os.Remove(dbpath) - if err := Init(false); err != nil { - t.Fatalf("init database: %s", err) - } - defer Close() + log := mlog.New("mtastsdb", nil) - db, err := database(ctxbg) - if err != nil { - t.Fatalf("database: %s", err) - } + err := Init(false) + tcheckf(t, err, "init database") + defer func() { + err := Close() + tcheckf(t, err, "close database") + }() cert := fakeCert(t, false) defer func() { @@ -70,7 +69,7 @@ func TestRefresh(t *testing.T) { } pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()} - if err := db.Insert(ctxbg, &pr); err != nil { + if err := DB.Insert(ctxbg, &pr); err != nil { t.Fatalf("insert policy: %s", err) } } @@ -114,7 +113,7 @@ func TestRefresh(t *testing.T) { TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, }, - ErrorLog: log.New(io.Discard, "", 0), + ErrorLog: golog.New(io.Discard, "", 0), } s.ServeTLS(l, "", "") }() @@ -136,7 +135,6 @@ func TestRefresh(t *testing.T) { t.Fatalf("bad sleep duration %v", d) } } - log := mlog.New("mtastsdb", nil) if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 { t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n) } @@ -146,7 +144,7 @@ func TestRefresh(t *testing.T) { time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database. // Should not do any more refreshes and return immediately. - q := bstore.QueryDB[PolicyRecord](ctxbg, db) + q := bstore.QueryDB[PolicyRecord](ctxbg, DB) q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"}) if _, err := q.Delete(); err != nil { t.Fatalf("delete record that would be refreshed: %v", err) diff --git a/queue/hook_test.go b/queue/hook_test.go index 4534691..c5d7c36 100644 --- a/queue/hook_test.go +++ b/queue/hook_test.go @@ -24,8 +24,6 @@ import ( func TestHookIncoming(t *testing.T) { acc, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") accret, err := store.OpenAccount(pkglog, "retired") tcheck(t, err, "open account for retired") @@ -119,8 +117,6 @@ func TestHookIncoming(t *testing.T) { func TestFromIDIncomingDelivery(t *testing.T) { acc, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") accret, err := store.OpenAccount(pkglog, "retired") tcheck(t, err, "open account for retired") @@ -525,8 +521,6 @@ func TestFromIDIncomingDelivery(t *testing.T) { func TestHookListFilterSort(t *testing.T) { _, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") now := time.Now().Round(0) h := Hook{0, 0, "fromid", "messageid", "subj", nil, "mjl", "http://localhost", "", false, "delivered", "", now, 0, now, []HookResult{}} @@ -534,7 +528,7 @@ func TestHookListFilterSort(t *testing.T) { h1.Submitted = now.Add(-time.Second) h1.NextAttempt = now.Add(time.Minute) hl := []Hook{h, h, h, h, h, h1} - err = DB.Write(ctxbg, func(tx *bstore.Tx) error { + err := DB.Write(ctxbg, func(tx *bstore.Tx) error { for i := range hl { err := hookInsert(tx, &hl[i], now, time.Minute) tcheck(t, err, "insert hook") diff --git a/queue/queue.go b/queue/queue.go index 3cbbb56..76b9af6 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -350,7 +350,9 @@ func Init() error { } var err error - DB, err = bstore.Open(mox.Shutdown, qpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) + log := mlog.New("queue", nil) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + DB, err = bstore.Open(mox.Shutdown, qpath, &opts, DBTypes...) if err == nil { err = DB.Read(mox.Shutdown, func(tx *bstore.Tx) error { return metricHoldUpdate(tx) diff --git a/queue/queue_test.go b/queue/queue_test.go index e2f1a63..3474ee8 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -28,6 +28,7 @@ import ( "github.com/mjl-/mox/dns" "github.com/mjl-/mox/mlog" "github.com/mjl-/mox/mox-" + "github.com/mjl-/mox/mtastsdb" "github.com/mjl-/mox/smtp" "github.com/mjl-/mox/smtpclient" "github.com/mjl-/mox/store" @@ -60,6 +61,12 @@ func setup(t *testing.T) (*store.Account, func()) { mox.Context = ctxbg mox.ConfigStaticPath = filepath.FromSlash("../testdata/queue/mox.conf") mox.MustLoadConfig(true, false) + err := Init() + tcheck(t, err, "queue init") + err = mtastsdb.Init(false) + tcheck(t, err, "mtastsdb init") + err = tlsrptdb.Init() + tcheck(t, err, "tlsrptdb init") acc, err := store.OpenAccount(log, "mjl") tcheck(t, err, "open account") err = acc.SetPassword(log, "testtest") @@ -72,6 +79,10 @@ func setup(t *testing.T) (*store.Account, func()) { mox.ShutdownCancel() mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg) Shutdown() + err := mtastsdb.Close() + tcheck(t, err, "mtastsdb close") + err = tlsrptdb.Close() + tcheck(t, err, "tlsrptdb close") switchStop() } } @@ -95,8 +106,6 @@ func prepareFile(t *testing.T) *os.File { func TestQueue(t *testing.T) { acc, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") idfilter := func(msgID int64) Filter { return Filter{IDs: []int64{msgID}} @@ -951,8 +960,6 @@ func checkTLSResults(t *testing.T, policyDomain, expRecipientDomain string, expI func TestRetiredHooks(t *testing.T) { _, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") addr, err := smtp.ParseAddress("mjl@mox.example") tcheck(t, err, "parse address") @@ -1193,6 +1200,7 @@ func TestQueueStart(t *testing.T) { <-done mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg) }() + Shutdown() // DB was opened already. Start will open it again. Just close it before. err := Start(resolver, done) tcheck(t, err, "queue start") @@ -1284,8 +1292,6 @@ func TestQueueStart(t *testing.T) { func TestListFilterSort(t *testing.T) { _, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") // insert Msgs. insert RetiredMsgs based on that. call list with filters and sort. filter to select a single. filter to paginate one by one, and in reverse. @@ -1301,7 +1307,7 @@ func TestListFilterSort(t *testing.T) { qm1.Queued = now.Add(-time.Second) qm1.NextAttempt = now.Add(time.Minute) qml := []Msg{qm, qm, qm, qm, qm, qm1} - err = Add(ctxbg, pkglog, "mjl", mf, qml...) + err := Add(ctxbg, pkglog, "mjl", mf, qml...) tcheck(t, err, "add messages to queue") qm1 = qml[len(qml)-1] diff --git a/queue/suppression_test.go b/queue/suppression_test.go index 2ac52d8..7b77198 100644 --- a/queue/suppression_test.go +++ b/queue/suppression_test.go @@ -10,8 +10,6 @@ import ( func TestSuppression(t *testing.T) { _, cleanup := setup(t) defer cleanup() - err := Init() - tcheck(t, err, "queue init") l, err := SuppressionList(ctxbg, "bogus") tcheck(t, err, "listing suppressions for unknown account") diff --git a/serve.go b/serve.go index ac7ea8c..aef5406 100644 --- a/serve.go +++ b/serve.go @@ -71,11 +71,15 @@ func start(mtastsdbRefresher, sendDMARCReports, sendTLSReports, skipForkExec boo } if err := mtastsdb.Init(mtastsdbRefresher); err != nil { - return fmt.Errorf("mtasts init: %s", err) + return fmt.Errorf("mtastsdb init: %s", err) } if err := tlsrptdb.Init(); err != nil { - return fmt.Errorf("tlsrpt init: %s", err) + return fmt.Errorf("tlsrptdb init: %s", err) + } + + if err := dmarcdb.Init(); err != nil { + return fmt.Errorf("dmarcdb init: %s", err) } done := make(chan struct{}) // Goroutines for messages and webhooks, and cleaners. @@ -83,10 +87,6 @@ func start(mtastsdbRefresher, sendDMARCReports, sendTLSReports, skipForkExec boo return fmt.Errorf("queue start: %s", err) } - // dmarcdb starts after queue because it may start sending reports through the queue. - if err := dmarcdb.Init(); err != nil { - return fmt.Errorf("dmarc init: %s", err) - } if sendDMARCReports { dmarcdb.Start(dns.StrictResolver{Pkg: "dmarcdb"}) } diff --git a/smtpserver/reputation_test.go b/smtpserver/reputation_test.go index 49e97da..0b85783 100644 --- a/smtpserver/reputation_test.go +++ b/smtpserver/reputation_test.go @@ -108,7 +108,8 @@ func TestReputation(t *testing.T) { p := filepath.FromSlash("../testdata/smtpserver-reputation.db") defer os.Remove(p) - db, err := bstore.Open(ctxbg, p, &bstore.Options{Timeout: 5 * time.Second}, store.DBTypes...) + opts := bstore.Options{Timeout: 5 * time.Second, RegisterLogger: log.Logger} + db, err := bstore.Open(ctxbg, p, &opts, store.DBTypes...) tcheck(t, err, "open db") defer db.Close() diff --git a/smtpserver/server_test.go b/smtpserver/server_test.go index db48b89..0d051bf 100644 --- a/smtpserver/server_test.go +++ b/smtpserver/server_test.go @@ -105,11 +105,6 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test ts := testserver{t: t, cid: 1, resolver: resolver, tlsmode: smtpclient.TLSOpportunistic} - if dmarcdb.EvalDB != nil { - dmarcdb.EvalDB.Close() - dmarcdb.EvalDB = nil - } - log := mlog.New("smtpserver", nil) mox.Context = ctxbg mox.ConfigStaticPath = configPath @@ -117,7 +112,11 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir) os.RemoveAll(dataDir) - var err error + err := dmarcdb.Init() + tcheck(t, err, "dmarcdb init") + err = tlsrptdb.Init() + tcheck(t, err, "tlsrptdb init") + ts.acc, err = store.OpenAccount(log, "mjl") tcheck(t, err, "open account") err = ts.acc.SetPassword(log, password0) @@ -136,10 +135,14 @@ func (ts *testserver) close() { if ts.acc == nil { return } + err := dmarcdb.Close() + tcheck(ts.t, err, "dmarcdb close") + err = tlsrptdb.Close() + tcheck(ts.t, err, "tlsrptdb close") ts.comm.Unregister() queue.Shutdown() ts.switchStop() - err := ts.acc.Close() + err = ts.acc.Close() tcheck(ts.t, err, "closing account") ts.acc.CheckClosed() ts.acc = nil diff --git a/store/account.go b/store/account.go index b2577d2..081f13f 100644 --- a/store/account.go +++ b/store/account.go @@ -907,7 +907,8 @@ func OpenAccountDB(log mlog.Log, accountDir, accountName string) (a *Account, re os.MkdirAll(accountDir, 0770) } - db, err := bstore.Open(context.TODO(), dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + db, err := bstore.Open(context.TODO(), dbpath, &opts, DBTypes...) if err != nil { return nil, err } diff --git a/store/threads_test.go b/store/threads_test.go index a273f6f..8c1f9ce 100644 --- a/store/threads_test.go +++ b/store/threads_test.go @@ -121,7 +121,8 @@ func TestThreadingUpgrade(t *testing.T) { // Now clear the threading upgrade, and the threading fields and close the account. // We open the database file directly, so we don't trigger the consistency checker. - db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + db, err := bstore.Open(ctxbg, dbpath, &opts, DBTypes...) err = db.Write(ctxbg, func(tx *bstore.Tx) error { up := Upgrade{ID: 1} err := tx.Delete(&up) diff --git a/tlsrptdb/db.go b/tlsrptdb/db.go index 4c3b42f..b2e22c9 100644 --- a/tlsrptdb/db.go +++ b/tlsrptdb/db.go @@ -1,7 +1,11 @@ package tlsrptdb import ( - "sync" + "context" + "fmt" + "os" + "path/filepath" + "time" "github.com/mjl-/bstore" @@ -12,7 +16,6 @@ import ( var ( ReportDBTypes = []any{Record{}} ReportDB *bstore.DB - mutex sync.Mutex // Accessed directly by tlsrptsend. ResultDBTypes = []any{TLSResult{}, SuppressAddress{}} @@ -21,29 +24,48 @@ var ( // Init opens and possibly initializes the databases. func Init() error { - if _, err := reportDB(mox.Shutdown); err != nil { - return err + if ReportDB != nil || ResultDB != nil { + return fmt.Errorf("already initialized") } - if _, err := resultDB(mox.Shutdown); err != nil { - return err + + log := mlog.New("tlsrptdb", nil) + var err error + + ReportDB, err = openReportDB(mox.Shutdown, log) + if err != nil { + return fmt.Errorf("opening report db: %v", err) + } + ResultDB, err = openResultDB(mox.Shutdown, log) + if err != nil { + return fmt.Errorf("opening result db: %v", err) } return nil } -// Close closes the database connections. -func Close() { - log := mlog.New("tlsrptdb", nil) - if ResultDB != nil { - err := ResultDB.Close() - log.Check(err, "closing result database") - ResultDB = nil - } - - mutex.Lock() - defer mutex.Unlock() - if ReportDB != nil { - err := ReportDB.Close() - log.Check(err, "closing report database") - ReportDB = nil - } +func openReportDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) { + p := mox.DataDirPath("tlsrpt.db") + os.MkdirAll(filepath.Dir(p), 0770) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + return bstore.Open(ctx, p, &opts, ReportDBTypes...) +} + +func openResultDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) { + p := mox.DataDirPath("tlsrptresult.db") + os.MkdirAll(filepath.Dir(p), 0770) + opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger} + return bstore.Open(ctx, p, &opts, ResultDBTypes...) +} + +// Close closes the database connections. +func Close() error { + if err := ResultDB.Close(); err != nil { + return fmt.Errorf("closing result db: %w", err) + } + ResultDB = nil + + if err := ReportDB.Close(); err != nil { + return fmt.Errorf("closing report db: %w", err) + } + ReportDB = nil + return nil } diff --git a/tlsrptdb/report.go b/tlsrptdb/report.go index 6c3dd63..3ca8847 100644 --- a/tlsrptdb/report.go +++ b/tlsrptdb/report.go @@ -5,8 +5,6 @@ import ( "context" "fmt" "log/slog" - "os" - "path/filepath" "time" "github.com/prometheus/client_golang/prometheus" @@ -55,21 +53,6 @@ type Record struct { Report tlsrpt.Report } -func reportDB(ctx context.Context) (rdb *bstore.DB, rerr error) { - mutex.Lock() - defer mutex.Unlock() - if ReportDB == nil { - p := mox.DataDirPath("tlsrpt.db") - os.MkdirAll(filepath.Dir(p), 0770) - db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, ReportDBTypes...) - if err != nil { - return nil, err - } - ReportDB = db - } - return ReportDB, nil -} - // AddReport adds a TLS report to the database. // // The report should have come in over SMTP, with a DKIM-validated @@ -82,17 +65,12 @@ func reportDB(ctx context.Context) (rdb *bstore.DB, rerr error) { // // Prometheus metrics are updated only for configured domains. func AddReport(ctx context.Context, log mlog.Log, verifiedFromDomain dns.Domain, mailFrom string, hostReport bool, r *tlsrpt.Report) error { - db, err := reportDB(ctx) - if err != nil { - return err - } - if len(r.Policies) == 0 { return fmt.Errorf("no policies in report") } var inserted int - return db.Write(ctx, func(tx *bstore.Tx) error { + return ReportDB.Write(ctx, func(tx *bstore.Tx) error { for _, p := range r.Policies { pp := p.Policy @@ -132,22 +110,13 @@ func AddReport(ctx context.Context, log mlog.Log, verifiedFromDomain dns.Domain, // Records returns all TLS reports in the database. func Records(ctx context.Context) ([]Record, error) { - db, err := reportDB(ctx) - if err != nil { - return nil, err - } - return bstore.QueryDB[Record](ctx, db).List() + return bstore.QueryDB[Record](ctx, ReportDB).List() } // RecordID returns the report for the ID. func RecordID(ctx context.Context, id int64) (Record, error) { - db, err := reportDB(ctx) - if err != nil { - return Record{}, err - } - e := Record{ID: id} - err = db.Get(ctx, &e) + err := ReportDB.Get(ctx, &e) return e, err } @@ -155,12 +124,7 @@ func RecordID(ctx context.Context, id int64) (Record, error) { // given policy domain. If policy domain is empty, records for all domains are // returned. func RecordsPeriodDomain(ctx context.Context, start, end time.Time, policyDomain dns.Domain) ([]Record, error) { - db, err := reportDB(ctx) - if err != nil { - return nil, err - } - - q := bstore.QueryDB[Record](ctx, db) + q := bstore.QueryDB[Record](ctx, ReportDB) var zerodom dns.Domain if policyDomain != zerodom { q.FilterNonzero(Record{Domain: policyDomain.Name()}) diff --git a/tlsrptdb/result.go b/tlsrptdb/result.go index e4957d7..32a3cc1 100644 --- a/tlsrptdb/result.go +++ b/tlsrptdb/result.go @@ -3,14 +3,11 @@ package tlsrptdb import ( "context" "fmt" - "os" - "path/filepath" "time" "github.com/mjl-/bstore" "github.com/mjl-/mox/dns" - "github.com/mjl-/mox/mox-" "github.com/mjl-/mox/tlsrpt" ) @@ -70,33 +67,13 @@ type SuppressAddress struct { Comment string } -func resultDB(ctx context.Context) (rdb *bstore.DB, rerr error) { - mutex.Lock() - defer mutex.Unlock() - if ResultDB == nil { - p := mox.DataDirPath("tlsrptresult.db") - os.MkdirAll(filepath.Dir(p), 0770) - db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, ResultDBTypes...) - if err != nil { - return nil, err - } - ResultDB = db - } - return ResultDB, nil -} - // AddTLSResults adds or merges all tls results for delivering to a policy domain, // on its UTC day to a recipient domain to the database. Results may cause multiple // separate reports to be sent. func AddTLSResults(ctx context.Context, results []TLSResult) error { - db, err := resultDB(ctx) - if err != nil { - return err - } - now := time.Now() - err = db.Write(ctx, func(tx *bstore.Tx) error { + err := ResultDB.Write(ctx, func(tx *bstore.Tx) error { for _, result := range results { // Ensure all slices are non-nil. We do this now so all readers will marshal to // compliant with the JSON schema. And also for consistent equality checks when @@ -148,102 +125,57 @@ func AddTLSResults(ctx context.Context, results []TLSResult) error { // Results returns all TLS results in the database, for all policy domains each // with potentially multiple days. Sorted by RecipientDomain and day. func Results(ctx context.Context) ([]TLSResult, error) { - db, err := resultDB(ctx) - if err != nil { - return nil, err - } - - return bstore.QueryDB[TLSResult](ctx, db).SortAsc("PolicyDomain", "DayUTC", "RecipientDomain").List() + return bstore.QueryDB[TLSResult](ctx, ResultDB).SortAsc("PolicyDomain", "DayUTC", "RecipientDomain").List() } // ResultsDomain returns all TLSResults for a policy domain, potentially for // multiple days. func ResultsPolicyDomain(ctx context.Context, policyDomain dns.Domain) ([]TLSResult, error) { - db, err := resultDB(ctx) - if err != nil { - return nil, err - } - - return bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name()}).SortAsc("DayUTC", "RecipientDomain").List() + return bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name()}).SortAsc("DayUTC", "RecipientDomain").List() } // ResultsRecipientDomain returns all TLSResults for a recipient domain, // potentially for multiple days. func ResultsRecipientDomain(ctx context.Context, recipientDomain dns.Domain) ([]TLSResult, error) { - db, err := resultDB(ctx) - if err != nil { - return nil, err - } - - return bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name()}).SortAsc("DayUTC", "PolicyDomain").List() + return bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name()}).SortAsc("DayUTC", "PolicyDomain").List() } // RemoveResultsPolicyDomain removes all TLSResults for the policy domain on the // day from the database. func RemoveResultsPolicyDomain(ctx context.Context, policyDomain dns.Domain, dayUTC string) error { - db, err := resultDB(ctx) - if err != nil { - return err - } - - _, err = bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name(), DayUTC: dayUTC}).Delete() + _, err := bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name(), DayUTC: dayUTC}).Delete() return err } // RemoveResultsRecipientDomain removes all TLSResults for the recipient domain on // the day from the database. func RemoveResultsRecipientDomain(ctx context.Context, recipientDomain dns.Domain, dayUTC string) error { - db, err := resultDB(ctx) - if err != nil { - return err - } - - _, err = bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name(), DayUTC: dayUTC}).Delete() + _, err := bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name(), DayUTC: dayUTC}).Delete() return err } // SuppressAdd adds an address to the suppress list. func SuppressAdd(ctx context.Context, ba *SuppressAddress) error { - db, err := resultDB(ctx) - if err != nil { - return err - } - - return db.Insert(ctx, ba) + return ResultDB.Insert(ctx, ba) } // SuppressList returns all reporting addresses on the suppress list. func SuppressList(ctx context.Context) ([]SuppressAddress, error) { - db, err := resultDB(ctx) - if err != nil { - return nil, err - } - - return bstore.QueryDB[SuppressAddress](ctx, db).SortDesc("ID").List() + return bstore.QueryDB[SuppressAddress](ctx, ResultDB).SortDesc("ID").List() } // SuppressRemove removes a reporting address record from the suppress list. func SuppressRemove(ctx context.Context, id int64) error { - db, err := resultDB(ctx) - if err != nil { - return err - } - - return db.Delete(ctx, &SuppressAddress{ID: id}) + return ResultDB.Delete(ctx, &SuppressAddress{ID: id}) } // SuppressUpdate updates the until field of a reporting address record. func SuppressUpdate(ctx context.Context, id int64, until time.Time) error { - db, err := resultDB(ctx) - if err != nil { - return err - } - ba := SuppressAddress{ID: id} - err = db.Get(ctx, &ba) + err := ResultDB.Get(ctx, &ba) if err != nil { return err } ba.Until = until - return db.Update(ctx, &ba) + return ResultDB.Update(ctx, &ba) } diff --git a/vendor/github.com/mjl-/bstore/doc.go b/vendor/github.com/mjl-/bstore/doc.go index f7e2d7e..09c30c3 100644 --- a/vendor/github.com/mjl-/bstore/doc.go +++ b/vendor/github.com/mjl-/bstore/doc.go @@ -216,24 +216,24 @@ stays trivial (though not plan9, unfortunately). Although bstore is much more limited in so many aspects than sqlite, bstore also offers some advantages as well. Some points of comparison: -- Cross-compilation and reproducibility: Trivial with bstore due to pure Go, - much harder with sqlite because of cgo. -- Code complexity: low with bstore (7k lines including comments/docs), high - with sqlite. -- Query language: mostly-type-checked function calls in bstore, free-form query - strings only checked at runtime with sqlite. -- Functionality: very limited with bstore, much more full-featured with sqlite. -- Schema management: mostly automatic based on Go type definitions in bstore, - manual with ALTER statements in sqlite. -- Types and packing/parsing: automatic/transparent in bstore based on Go types - (including maps, slices, structs and custom MarshalBinary encoding), versus - manual scanning and parameter passing with sqlite with limited set of SQL - types. -- Performance: low to good performance with bstore, high performance with - sqlite. -- Database files: single file with bstore, several files with sqlite (due to - WAL or journal files). -- Test coverage: decent coverage but limited real-world for bstore, versus - extremely thoroughly tested and with enormous real-world use. + - Cross-compilation and reproducibility: Trivial with bstore due to pure Go, + much harder with sqlite because of cgo. + - Code complexity: low with bstore (7k lines including comments/docs), high + with sqlite. + - Query language: mostly-type-checked function calls in bstore, free-form query + strings only checked at runtime with sqlite. + - Functionality: very limited with bstore, much more full-featured with sqlite. + - Schema management: mostly automatic based on Go type definitions in bstore, + manual with ALTER statements in sqlite. + - Types and packing/parsing: automatic/transparent in bstore based on Go types + (including maps, slices, structs and custom MarshalBinary encoding), versus + manual scanning and parameter passing with sqlite with limited set of SQL + types. + - Performance: low to good performance with bstore, high performance with + sqlite. + - Database files: single file with bstore, several files with sqlite (due to + WAL or journal files). + - Test coverage: decent coverage but limited real-world for bstore, versus + extremely thoroughly tested and with enormous real-world use. */ package bstore diff --git a/vendor/github.com/mjl-/bstore/register.go b/vendor/github.com/mjl-/bstore/register.go index e740846..cfa44b8 100644 --- a/vendor/github.com/mjl-/bstore/register.go +++ b/vendor/github.com/mjl-/bstore/register.go @@ -7,6 +7,7 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "os" "reflect" "sort" @@ -50,6 +51,17 @@ var errSchemaCheck = errors.New("schema check") // to "changed", an error is returned if there is no schema change. If it is set to // "unchanged", an error is returned if there was a schema change. func (db *DB) Register(ctx context.Context, typeValues ...any) error { + return db.register(ctx, slog.New(discardHandler{}), typeValues...) +} + +type discardHandler struct{} + +func (l discardHandler) Enabled(context.Context, slog.Level) bool { return false } +func (l discardHandler) Handle(context.Context, slog.Record) error { return nil } +func (l discardHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return l } +func (l discardHandler) WithGroup(name string) slog.Handler { return l } + +func (db *DB) register(ctx context.Context, log *slog.Logger, typeValues ...any) error { // We will drop/create new indices as needed. For changed indices, we drop // and recreate. E.g. if an index becomes a unique index, or if a field in // an index changes. These values map type and index name to their index. @@ -148,6 +160,9 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error { tv.Version = 1 if st.Current != nil { tv.Version = st.Current.Version + 1 + log.Debug("updating schema for type", slog.String("type", tv.name), slog.Uint64("version", uint64(tv.Version))) + } else { + log.Debug("adding schema for new type", slog.String("type", tv.name), slog.Uint64("version", uint64(tv.Version))) } k, v, err := packSchema(tv) if err != nil { @@ -299,6 +314,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error { } foundField = true + log.Debug("verifying foreign key constraint for new reference", slog.String("fromtype", ntname), slog.String("fromfield", f.Name), slog.String("totype", name)) // For newly added references, check they are valid. b, err := tx.recordsBucket(ntname, ntv.fillPercent) @@ -396,6 +412,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error { db.types[st.Type] = *st db.typeNames[st.Name] = *st ntvp = &ntv + log.Debug("updating schema for type due to new incoming reference", slog.String("type", name), slog.Uint64("version", uint64(ntv.Version))) } k, v, err := packSchema(ntvp) @@ -453,6 +470,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error { if !drop { continue } + log.Debug("dropping old/modified index", slog.String("type", name), slog.String("indexname", iname)) b, err := tx.typeBucket(name) if err != nil { return err @@ -482,6 +500,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error { if !create { continue } + log.Debug("preparing for new/modified index for type", slog.String("type", name), slog.String("indexname", iname)) b, err := tx.typeBucket(name) if err != nil { return err @@ -578,6 +597,8 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error { return nil } + log.Debug("creating new/modified indexes for type", slog.String("type", name)) + // Now do all sorts + inserts. for i, ib := range ibs { idx := idxs[i] diff --git a/vendor/github.com/mjl-/bstore/store.go b/vendor/github.com/mjl-/bstore/store.go index 87227ae..b723366 100644 --- a/vendor/github.com/mjl-/bstore/store.go +++ b/vendor/github.com/mjl-/bstore/store.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "io/fs" + "log/slog" "os" "reflect" "sync" @@ -282,9 +283,10 @@ type fieldType struct { // Options configure how a database should be opened or initialized. type Options struct { - Timeout time.Duration // Abort if opening DB takes longer than Timeout. If not set, the deadline from the context is used. - Perm fs.FileMode // Permissions for new file if created. If zero, 0600 is used. - MustExist bool // Before opening, check that file exists. If not, io/fs.ErrNotExist is returned. + Timeout time.Duration // Abort if opening DB takes longer than Timeout. If not set, the deadline from the context is used. + Perm fs.FileMode // Permissions for new file if created. If zero, 0600 is used. + MustExist bool // Before opening, check that file exists. If not, io/fs.ErrNotExist is returned. + RegisterLogger *slog.Logger // For debug logging about schema upgrades. } // Open opens a bstore database and registers types by calling Register. @@ -326,7 +328,16 @@ func Open(ctx context.Context, path string, opts *Options, typeValues ...any) (* typeNames := map[string]storeType{} types := map[reflect.Type]storeType{} db := &DB{bdb: bdb, typeNames: typeNames, types: types} - if err := db.Register(ctx, typeValues...); err != nil { + var log *slog.Logger + if opts != nil { + log = opts.RegisterLogger + } + if log == nil { + log = slog.New(discardHandler{}) + } else { + log = log.With("dbpath", path) + } + if err := db.register(ctx, log, typeValues...); err != nil { bdb.Close() return nil, err } diff --git a/vendor/modules.txt b/vendor/modules.txt index 8575b16..c11b323 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -16,7 +16,7 @@ github.com/mjl-/adns/internal/singleflight # github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 ## explicit; go 1.20 github.com/mjl-/autocert -# github.com/mjl-/bstore v0.0.5 +# github.com/mjl-/bstore v0.0.6 ## explicit; go 1.19 github.com/mjl-/bstore # github.com/mjl-/sconf v0.0.6 diff --git a/verifydata.go b/verifydata.go index a274ce1..9d01e44 100644 --- a/verifydata.go +++ b/verifydata.go @@ -120,7 +120,8 @@ possibly making them potentially no longer readable by the previous version. checkf(err, path, "reading bolt database") bdb.Close() - db, err := bstore.Open(ctxbg, path, nil, types...) + opts := bstore.Options{RegisterLogger: c.log.Logger} + db, err := bstore.Open(ctxbg, path, &opts, types...) checkf(err, path, "open database with bstore") if err != nil { return @@ -162,7 +163,8 @@ possibly making them potentially no longer readable by the previous version. // Check that all messages present in the database also exist on disk. seen := map[string]struct{}{} - db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{MustExist: true}, queue.DBTypes...) + opts := bstore.Options{MustExist: true, RegisterLogger: c.log.Logger} + db, err := bstore.Open(ctxbg, dbpath, &opts, queue.DBTypes...) checkf(err, dbpath, "opening queue database to check messages") if err == nil { err := bstore.QueryDB[queue.Msg](ctxbg, db).ForEach(func(m queue.Msg) error { @@ -237,7 +239,8 @@ possibly making them potentially no longer readable by the previous version. // And check consistency of UIDs with the mailbox UIDNext, and check UIDValidity. seen := map[string]struct{}{} dbpath := filepath.Join(accdir, "index.db") - db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{MustExist: true}, store.DBTypes...) + opts := bstore.Options{MustExist: true, RegisterLogger: c.log.Logger} + db, err := bstore.Open(ctxbg, dbpath, &opts, store.DBTypes...) checkf(err, dbpath, "opening account database to check messages") if err == nil { uidvalidity := store.NextUIDValidity{ID: 1} diff --git a/webmail/api_test.go b/webmail/api_test.go index 3bb11fc..5658878 100644 --- a/webmail/api_test.go +++ b/webmail/api_test.go @@ -17,6 +17,7 @@ import ( "github.com/mjl-/mox/dns" "github.com/mjl-/mox/mlog" "github.com/mjl-/mox/mox-" + "github.com/mjl-/mox/mtastsdb" "github.com/mjl-/mox/queue" "github.com/mjl-/mox/store" ) @@ -58,6 +59,8 @@ func TestAPI(t *testing.T) { defer store.Switchboard()() log := mlog.New("webmail", nil) + err := mtastsdb.Init(false) + tcheck(t, err, "mtastsdb init") acc, err := store.OpenAccount(log, "mjl") tcheck(t, err, "open account") const pw0 = "te\u0301st \u00a0\u2002\u200a" // NFD and various unicode spaces. @@ -65,7 +68,9 @@ func TestAPI(t *testing.T) { err = acc.SetPassword(log, pw0) tcheck(t, err, "set password") defer func() { - err := acc.Close() + err := mtastsdb.Close() + tcheck(t, err, "mtastsdb close") + err = acc.Close() pkglog.Check(err, "closing account") acc.CheckClosed() }()