add debug logging about bstore db schema upgrades

bstore was updated to v0.0.6 to add this logging.
this simplifies some of the db-handling code in mtastsdb,tlsrptdb,dmarcdb. we
now call the package-level Init() and Close() in all tests properly.
This commit is contained in:
Mechiel Lukkien 2024-05-10 14:44:37 +02:00
parent 3e4cce826e
commit bf8cfd9724
No known key found for this signature in database
31 changed files with 298 additions and 428 deletions

View file

@ -313,7 +313,8 @@ func backupctl(ctx context.Context, ctl *ctl) {
} }
dstdbpath := filepath.Join(dstDataDir, path) dstdbpath := filepath.Join(dstDataDir, path)
db, err := bstore.Open(ctx, dstdbpath, &bstore.Options{MustExist: true}, queue.DBTypes...) opts := bstore.Options{MustExist: true, RegisterLogger: ctl.log.Logger}
db, err := bstore.Open(ctx, dstdbpath, &opts, queue.DBTypes...)
if err != nil { if err != nil {
xerrx("open copied queue database", err, slog.String("dstpath", dstdbpath), slog.Duration("duration", time.Since(tmQueue))) xerrx("open copied queue database", err, slog.String("dstpath", dstdbpath), slog.Duration("duration", time.Since(tmQueue)))
return return
@ -419,7 +420,8 @@ func backupctl(ctx context.Context, ctl *ctl) {
} }
dstdbpath := filepath.Join(dstDataDir, dbpath) dstdbpath := filepath.Join(dstDataDir, dbpath)
db, err := bstore.Open(ctx, dstdbpath, &bstore.Options{MustExist: true}, store.DBTypes...) opts := bstore.Options{MustExist: true, RegisterLogger: ctl.log.Logger}
db, err := bstore.Open(ctx, dstdbpath, &opts, store.DBTypes...)
if err != nil { if err != nil {
xerrx("open copied account database", err, slog.String("dstpath", dstdbpath), slog.Duration("duration", time.Since(tmAccount))) xerrx("open copied account database", err, slog.String("dstpath", dstdbpath), slog.Duration("duration", time.Since(tmAccount)))
return return

View file

@ -11,6 +11,15 @@
package dmarcdb package dmarcdb
import ( import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/mjl-/bstore"
"github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/mox-" "github.com/mjl-/mox/mox-"
) )
@ -19,11 +28,49 @@ import (
// The incoming reports and evaluations for outgoing reports are in separate // The incoming reports and evaluations for outgoing reports are in separate
// databases for simpler file-based handling of the databases. // databases for simpler file-based handling of the databases.
func Init() error { func Init() error {
if _, err := reportsDB(mox.Shutdown); err != nil { if ReportsDB != nil || EvalDB != nil {
return err return fmt.Errorf("already initialized")
} }
if _, err := evalDB(mox.Shutdown); err != nil {
return err log := mlog.New("dmarcdb", nil)
var err error
ReportsDB, err = openReportsDB(mox.Shutdown, log)
if err != nil {
return fmt.Errorf("open reports db: %v", err)
} }
EvalDB, err = openEvalDB(mox.Shutdown, log)
if err != nil {
return fmt.Errorf("open eval db: %v", err)
}
return nil return nil
} }
func Close() error {
if err := ReportsDB.Close(); err != nil {
return fmt.Errorf("closing reports db: %w", err)
}
ReportsDB = nil
if err := EvalDB.Close(); err != nil {
return fmt.Errorf("closing eval db: %w", err)
}
EvalDB = nil
return nil
}
func openReportsDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) {
p := mox.DataDirPath("dmarcrpt.db")
os.MkdirAll(filepath.Dir(p), 0770)
opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
return bstore.Open(ctx, p, &opts, ReportsDBTypes...)
}
func openEvalDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) {
p := mox.DataDirPath("dmarceval.db")
os.MkdirAll(filepath.Dir(p), 0770)
opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
return bstore.Open(ctx, p, &opts, EvalDBTypes...)
}

View file

@ -15,7 +15,6 @@ import (
"net/textproto" "net/textproto"
"net/url" "net/url"
"os" "os"
"path/filepath"
"runtime/debug" "runtime/debug"
"slices" "slices"
"sort" "sort"
@ -67,7 +66,6 @@ var (
// to the database. Every hour, a goroutine wakes up that gathers evaluations from // to the database. Every hour, a goroutine wakes up that gathers evaluations from
// the last hour(s), sends a report, and removes the evaluations from the database. // the last hour(s), sends a report, and removes the evaluations from the database.
EvalDB *bstore.DB EvalDB *bstore.DB
evalMutex sync.Mutex
) )
// Evaluation is the result of an evaluation of a DMARC policy, to be included // Evaluation is the result of an evaluation of a DMARC policy, to be included
@ -162,21 +160,6 @@ func (e Evaluation) ReportRecord(count int) dmarcrpt.ReportRecord {
} }
} }
func evalDB(ctx context.Context) (rdb *bstore.DB, rerr error) {
evalMutex.Lock()
defer evalMutex.Unlock()
if EvalDB == nil {
p := mox.DataDirPath("dmarceval.db")
os.MkdirAll(filepath.Dir(p), 0770)
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, EvalDBTypes...)
if err != nil {
return nil, err
}
EvalDB = db
}
return EvalDB, nil
}
var intervalOpts = []int{24, 12, 8, 6, 4, 3, 2} var intervalOpts = []int{24, 12, 8, 6, 4, 3, 2}
func intervalHours(seconds int) int { func intervalHours(seconds int) int {
@ -197,23 +180,13 @@ func intervalHours(seconds int) int {
func AddEvaluation(ctx context.Context, aggregateReportingIntervalSeconds int, e *Evaluation) error { func AddEvaluation(ctx context.Context, aggregateReportingIntervalSeconds int, e *Evaluation) error {
e.IntervalHours = intervalHours(aggregateReportingIntervalSeconds) e.IntervalHours = intervalHours(aggregateReportingIntervalSeconds)
db, err := evalDB(ctx)
if err != nil {
return err
}
e.ID = 0 e.ID = 0
return db.Insert(ctx, e) return EvalDB.Insert(ctx, e)
} }
// Evaluations returns all evaluations in the database. // Evaluations returns all evaluations in the database.
func Evaluations(ctx context.Context) ([]Evaluation, error) { func Evaluations(ctx context.Context) ([]Evaluation, error) {
db, err := evalDB(ctx) q := bstore.QueryDB[Evaluation](ctx, EvalDB)
if err != nil {
return nil, err
}
q := bstore.QueryDB[Evaluation](ctx, db)
q.SortAsc("Evaluated") q.SortAsc("Evaluated")
return q.List() return q.List()
} }
@ -229,14 +202,9 @@ type EvaluationStat struct {
// EvaluationStats returns evaluation counts and report-sending status per domain. // EvaluationStats returns evaluation counts and report-sending status per domain.
func EvaluationStats(ctx context.Context) (map[string]EvaluationStat, error) { func EvaluationStats(ctx context.Context) (map[string]EvaluationStat, error) {
db, err := evalDB(ctx)
if err != nil {
return nil, err
}
r := map[string]EvaluationStat{} r := map[string]EvaluationStat{}
err = bstore.QueryDB[Evaluation](ctx, db).ForEach(func(e Evaluation) error { err := bstore.QueryDB[Evaluation](ctx, EvalDB).ForEach(func(e Evaluation) error {
if stat, ok := r[e.PolicyDomain]; ok { if stat, ok := r[e.PolicyDomain]; ok {
if !slices.Contains(stat.Dispositions, string(e.Disposition)) { if !slices.Contains(stat.Dispositions, string(e.Disposition)) {
stat.Dispositions = append(stat.Dispositions, string(e.Disposition)) stat.Dispositions = append(stat.Dispositions, string(e.Disposition))
@ -263,12 +231,7 @@ func EvaluationStats(ctx context.Context) (map[string]EvaluationStat, error) {
// EvaluationsDomain returns all evaluations for a domain. // EvaluationsDomain returns all evaluations for a domain.
func EvaluationsDomain(ctx context.Context, domain dns.Domain) ([]Evaluation, error) { func EvaluationsDomain(ctx context.Context, domain dns.Domain) ([]Evaluation, error) {
db, err := evalDB(ctx) q := bstore.QueryDB[Evaluation](ctx, EvalDB)
if err != nil {
return nil, err
}
q := bstore.QueryDB[Evaluation](ctx, db)
q.FilterNonzero(Evaluation{PolicyDomain: domain.Name()}) q.FilterNonzero(Evaluation{PolicyDomain: domain.Name()})
q.SortAsc("Evaluated") q.SortAsc("Evaluated")
return q.List() return q.List()
@ -277,14 +240,9 @@ func EvaluationsDomain(ctx context.Context, domain dns.Domain) ([]Evaluation, er
// RemoveEvaluationsDomain removes evaluations for domain so they won't be sent in // RemoveEvaluationsDomain removes evaluations for domain so they won't be sent in
// an aggregate report. // an aggregate report.
func RemoveEvaluationsDomain(ctx context.Context, domain dns.Domain) error { func RemoveEvaluationsDomain(ctx context.Context, domain dns.Domain) error {
db, err := evalDB(ctx) q := bstore.QueryDB[Evaluation](ctx, EvalDB)
if err != nil {
return err
}
q := bstore.QueryDB[Evaluation](ctx, db)
q.FilterNonzero(Evaluation{PolicyDomain: domain.Name()}) q.FilterNonzero(Evaluation{PolicyDomain: domain.Name()})
_, err = q.Delete() _, err := q.Delete()
return err return err
} }
@ -318,12 +276,6 @@ func Start(resolver dns.Resolver) {
ctx := mox.Shutdown ctx := mox.Shutdown
db, err := evalDB(ctx)
if err != nil {
log.Errorx("opening dmarc evaluations database for sending dmarc aggregate reports, not sending reports", err)
return
}
for { for {
now := time.Now() now := time.Now()
nextEnd := nextWholeHour(now) nextEnd := nextWholeHour(now)
@ -355,12 +307,12 @@ func Start(resolver dns.Resolver) {
// 24 hour interval). They should have been processed by now. We may have kept them // 24 hour interval). They should have been processed by now. We may have kept them
// during temporary errors, but persistent temporary errors shouldn't fill up our // during temporary errors, but persistent temporary errors shouldn't fill up our
// database. This also cleans up evaluations that were all optional for a domain. // database. This also cleans up evaluations that were all optional for a domain.
_, err := bstore.QueryDB[Evaluation](ctx, db).FilterLess("Evaluated", nextEnd.Add(-48*time.Hour)).Delete() _, err := bstore.QueryDB[Evaluation](ctx, EvalDB).FilterLess("Evaluated", nextEnd.Add(-48*time.Hour)).Delete()
log.Check(err, "removing stale dmarc evaluations from database") log.Check(err, "removing stale dmarc evaluations from database")
clog := log.WithCid(mox.Cid()) clog := log.WithCid(mox.Cid())
clog.Info("sending dmarc aggregate reports", slog.Time("end", nextEnd.UTC()), slog.Any("intervals", intervals)) clog.Info("sending dmarc aggregate reports", slog.Time("end", nextEnd.UTC()), slog.Any("intervals", intervals))
if err := sendReports(ctx, clog, resolver, db, nextEnd, intervals); err != nil { if err := sendReports(ctx, clog, resolver, EvalDB, nextEnd, intervals); err != nil {
clog.Errorx("sending dmarc aggregate reports", err) clog.Errorx("sending dmarc aggregate reports", err)
metricReportError.Inc() metricReportError.Inc()
} else { } else {
@ -1091,46 +1043,26 @@ func dkimSign(ctx context.Context, log mlog.Log, fromAddr smtp.Address, smtputf8
// SuppressAdd adds an address to the suppress list. // SuppressAdd adds an address to the suppress list.
func SuppressAdd(ctx context.Context, ba *SuppressAddress) error { func SuppressAdd(ctx context.Context, ba *SuppressAddress) error {
db, err := evalDB(ctx) return EvalDB.Insert(ctx, ba)
if err != nil {
return err
}
return db.Insert(ctx, ba)
} }
// SuppressList returns all reporting addresses on the suppress list. // SuppressList returns all reporting addresses on the suppress list.
func SuppressList(ctx context.Context) ([]SuppressAddress, error) { func SuppressList(ctx context.Context) ([]SuppressAddress, error) {
db, err := evalDB(ctx) return bstore.QueryDB[SuppressAddress](ctx, EvalDB).SortDesc("ID").List()
if err != nil {
return nil, err
}
return bstore.QueryDB[SuppressAddress](ctx, db).SortDesc("ID").List()
} }
// SuppressRemove removes a reporting address record from the suppress list. // SuppressRemove removes a reporting address record from the suppress list.
func SuppressRemove(ctx context.Context, id int64) error { func SuppressRemove(ctx context.Context, id int64) error {
db, err := evalDB(ctx) return EvalDB.Delete(ctx, &SuppressAddress{ID: id})
if err != nil {
return err
}
return db.Delete(ctx, &SuppressAddress{ID: id})
} }
// SuppressUpdate updates the until field of a reporting address record. // SuppressUpdate updates the until field of a reporting address record.
func SuppressUpdate(ctx context.Context, id int64, until time.Time) error { func SuppressUpdate(ctx context.Context, id int64, until time.Time) error {
db, err := evalDB(ctx)
if err != nil {
return err
}
ba := SuppressAddress{ID: id} ba := SuppressAddress{ID: id}
err = db.Get(ctx, &ba) err := EvalDB.Get(ctx, &ba)
if err != nil { if err != nil {
return err return err
} }
ba.Until = until ba.Until = until
return db.Update(ctx, &ba) return EvalDB.Update(ctx, &ba)
} }

View file

@ -41,13 +41,13 @@ func TestEvaluations(t *testing.T) {
mox.Context = ctxbg mox.Context = ctxbg
mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf") mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf")
mox.MustLoadConfig(true, false) mox.MustLoadConfig(true, false)
EvalDB = nil
_, err := evalDB(ctxbg) os.Remove(mox.DataDirPath("dmarceval.db"))
tcheckf(t, err, "database") err := Init()
tcheckf(t, err, "init")
defer func() { defer func() {
EvalDB.Close() err := Close()
EvalDB = nil tcheckf(t, err, "close")
}() }()
parseJSON := func(s string) (e Evaluation) { parseJSON := func(s string) (e Evaluation) {
@ -163,13 +163,13 @@ func TestSendReports(t *testing.T) {
mox.Context = ctxbg mox.Context = ctxbg
mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf") mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf")
mox.MustLoadConfig(true, false) mox.MustLoadConfig(true, false)
EvalDB = nil
db, err := evalDB(ctxbg) os.Remove(mox.DataDirPath("dmarceval.db"))
tcheckf(t, err, "database") err := Init()
tcheckf(t, err, "init")
defer func() { defer func() {
EvalDB.Close() err := Close()
EvalDB = nil tcheckf(t, err, "close")
}() }()
resolver := dns.MockResolver{ resolver := dns.MockResolver{
@ -288,7 +288,7 @@ func TestSendReports(t *testing.T) {
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg) mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
for _, e := range evals { for _, e := range evals {
err := db.Insert(ctxbg, &e) err := EvalDB.Insert(ctxbg, &e)
tcheckf(t, err, "inserting evaluation") tcheckf(t, err, "inserting evaluation")
} }
@ -359,13 +359,13 @@ func TestSendReports(t *testing.T) {
// Address is suppressed. // Address is suppressed.
sa := SuppressAddress{ReportingAddress: "dmarcrpt@sender.example", Until: time.Now().Add(time.Minute)} sa := SuppressAddress{ReportingAddress: "dmarcrpt@sender.example", Until: time.Now().Add(time.Minute)}
err = db.Insert(ctxbg, &sa) err = EvalDB.Insert(ctxbg, &sa)
tcheckf(t, err, "insert suppress address") tcheckf(t, err, "insert suppress address")
test([]Evaluation{eval}, map[string]struct{}{}, map[string]struct{}{}, nil) test([]Evaluation{eval}, map[string]struct{}{}, map[string]struct{}{}, nil)
// Suppression has expired. // Suppression has expired.
sa.Until = time.Now().Add(-time.Minute) sa.Until = time.Now().Add(-time.Minute)
err = db.Update(ctxbg, &sa) err = EvalDB.Update(ctxbg, &sa)
tcheckf(t, err, "update suppress address") tcheckf(t, err, "update suppress address")
test([]Evaluation{eval}, map[string]struct{}{"dmarcrpt@sender.example": {}}, map[string]struct{}{}, expFeedback) test([]Evaluation{eval}, map[string]struct{}{"dmarcrpt@sender.example": {}}, map[string]struct{}{}, expFeedback)

View file

@ -3,9 +3,6 @@ package dmarcdb
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath"
"sync"
"time" "time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -15,13 +12,11 @@ import (
"github.com/mjl-/mox/dmarcrpt" "github.com/mjl-/mox/dmarcrpt"
"github.com/mjl-/mox/dns" "github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mox-"
) )
var ( var (
ReportsDBTypes = []any{DomainFeedback{}} // Types stored in DB. ReportsDBTypes = []any{DomainFeedback{}} // Types stored in DB.
ReportsDB *bstore.DB // Exported for backups. ReportsDB *bstore.DB // Exported for backups.
reportsMutex sync.Mutex
) )
var ( var (
@ -59,38 +54,18 @@ type DomainFeedback struct {
dmarcrpt.Feedback dmarcrpt.Feedback
} }
func reportsDB(ctx context.Context) (rdb *bstore.DB, rerr error) {
reportsMutex.Lock()
defer reportsMutex.Unlock()
if ReportsDB == nil {
p := mox.DataDirPath("dmarcrpt.db")
os.MkdirAll(filepath.Dir(p), 0770)
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, ReportsDBTypes...)
if err != nil {
return nil, err
}
ReportsDB = db
}
return ReportsDB, nil
}
// AddReport adds a DMARC aggregate feedback report from an email to the database, // AddReport adds a DMARC aggregate feedback report from an email to the database,
// and updates prometheus metrics. // and updates prometheus metrics.
// //
// fromDomain is the domain in the report message From header. // fromDomain is the domain in the report message From header.
func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain) error { func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain) error {
db, err := reportsDB(ctx)
if err != nil {
return err
}
d, err := dns.ParseDomain(f.PolicyPublished.Domain) d, err := dns.ParseDomain(f.PolicyPublished.Domain)
if err != nil { if err != nil {
return fmt.Errorf("parsing domain in report: %v", err) return fmt.Errorf("parsing domain in report: %v", err)
} }
df := DomainFeedback{0, d.Name(), fromDomain.Name(), *f} df := DomainFeedback{0, d.Name(), fromDomain.Name(), *f}
if err := db.Insert(ctx, &df); err != nil { if err := ReportsDB.Insert(ctx, &df); err != nil {
return err return err
} }
@ -129,38 +104,23 @@ func AddReport(ctx context.Context, f *dmarcrpt.Feedback, fromDomain dns.Domain)
// Records returns all reports in the database. // Records returns all reports in the database.
func Records(ctx context.Context) ([]DomainFeedback, error) { func Records(ctx context.Context) ([]DomainFeedback, error) {
db, err := reportsDB(ctx) return bstore.QueryDB[DomainFeedback](ctx, ReportsDB).List()
if err != nil {
return nil, err
}
return bstore.QueryDB[DomainFeedback](ctx, db).List()
} }
// RecordID returns the report for the ID. // RecordID returns the report for the ID.
func RecordID(ctx context.Context, id int64) (DomainFeedback, error) { func RecordID(ctx context.Context, id int64) (DomainFeedback, error) {
db, err := reportsDB(ctx)
if err != nil {
return DomainFeedback{}, err
}
e := DomainFeedback{ID: id} e := DomainFeedback{ID: id}
err = db.Get(ctx, &e) err := ReportsDB.Get(ctx, &e)
return e, err return e, err
} }
// RecordsPeriodDomain returns the reports overlapping start and end, for the given // RecordsPeriodDomain returns the reports overlapping start and end, for the given
// domain. If domain is empty, all records match for domain. // domain. If domain is empty, all records match for domain.
func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]DomainFeedback, error) { func RecordsPeriodDomain(ctx context.Context, start, end time.Time, domain string) ([]DomainFeedback, error) {
db, err := reportsDB(ctx)
if err != nil {
return nil, err
}
s := start.Unix() s := start.Unix()
e := end.Unix() e := end.Unix()
q := bstore.QueryDB[DomainFeedback](ctx, db) q := bstore.QueryDB[DomainFeedback](ctx, ReportsDB)
if domain != "" { if domain != "" {
q.FilterNonzero(DomainFeedback{Domain: domain}) q.FilterNonzero(DomainFeedback{Domain: domain})
} }

View file

@ -20,16 +20,12 @@ func TestDMARCDB(t *testing.T) {
mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf") mox.ConfigStaticPath = filepath.FromSlash("../testdata/dmarcdb/mox.conf")
mox.MustLoadConfig(true, false) mox.MustLoadConfig(true, false)
dbpath := mox.DataDirPath("dmarcrpt.db") os.Remove(mox.DataDirPath("dmarcrpt.db"))
os.MkdirAll(filepath.Dir(dbpath), 0770) err := Init()
tcheckf(t, err, "init")
if err := Init(); err != nil {
t.Fatalf("init database: %s", err)
}
defer os.Remove(dbpath)
defer func() { defer func() {
ReportsDB.Close() err := Close()
ReportsDB = nil tcheckf(t, err, "close")
}() }()
feedback := &dmarcrpt.Feedback{ feedback := &dmarcrpt.Feedback{

View file

@ -62,7 +62,8 @@ func xcmdExport(mbox, single bool, args []string, c *cmd) {
} }
dbpath := filepath.Join(accountDir, "index.db") dbpath := filepath.Join(accountDir, "index.db")
db, err := bstore.Open(context.Background(), dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, store.DBTypes...) opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: c.log.Logger}
db, err := bstore.Open(context.Background(), dbpath, &opts, store.DBTypes...)
xcheckf(err, "open database %q", dbpath) xcheckf(err, "open database %q", dbpath)
defer func() { defer func() {
if err := db.Close(); err != nil { if err := db.Close(); err != nil {

2
go.mod
View file

@ -5,7 +5,7 @@ go 1.21
require ( require (
github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af
github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05
github.com/mjl-/bstore v0.0.5 github.com/mjl-/bstore v0.0.6
github.com/mjl-/sconf v0.0.6 github.com/mjl-/sconf v0.0.6
github.com/mjl-/sherpa v0.6.7 github.com/mjl-/sherpa v0.6.7
github.com/mjl-/sherpadoc v0.0.16 github.com/mjl-/sherpadoc v0.0.16

4
go.sum
View file

@ -28,8 +28,8 @@ github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af h1:sEDWZPIi5K1qKk7JQoAZy
github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af/go.mod h1:v47qUMJnipnmDTRGaHwpCwzE6oypa5K33mUvBfzZBn8= github.com/mjl-/adns v0.0.0-20240509092456-2dc8715bf4af/go.mod h1:v47qUMJnipnmDTRGaHwpCwzE6oypa5K33mUvBfzZBn8=
github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 h1:s6ay4bh4tmpPLdxjyeWG45mcwHfEluBMuGPkqxHWUJ4= github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 h1:s6ay4bh4tmpPLdxjyeWG45mcwHfEluBMuGPkqxHWUJ4=
github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05/go.mod h1:taMFU86abMxKLPV4Bynhv8enbYmS67b8LG80qZv2Qus= github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05/go.mod h1:taMFU86abMxKLPV4Bynhv8enbYmS67b8LG80qZv2Qus=
github.com/mjl-/bstore v0.0.5 h1:Cx+LWEBnFBsqSxZNMxeVujkfc0kG10lUJaAU4vWSRHo= github.com/mjl-/bstore v0.0.6 h1:ntlu9MkfCkpm2XfBY4+Ws4KK9YzXzewr3+lCueFB+9c=
github.com/mjl-/bstore v0.0.5/go.mod h1:/cD25FNBaDfvL/plFRxI3Ba3E+wcB0XVOS8nJDqndg0= github.com/mjl-/bstore v0.0.6/go.mod h1:/cD25FNBaDfvL/plFRxI3Ba3E+wcB0XVOS8nJDqndg0=
github.com/mjl-/sconf v0.0.6 h1:5Dt58488ZOoVx680zgK2K3vUrokLsp5mXDUACrJlrUc= github.com/mjl-/sconf v0.0.6 h1:5Dt58488ZOoVx680zgK2K3vUrokLsp5mXDUACrJlrUc=
github.com/mjl-/sconf v0.0.6/go.mod h1:uF8OdWtLT8La3i4ln176i1pB0ps9pXGCaABEU55ZkE0= github.com/mjl-/sconf v0.0.6/go.mod h1:uF8OdWtLT8La3i4ln176i1pB0ps9pXGCaABEU55ZkE0=
github.com/mjl-/sherpa v0.6.7 h1:C5F8XQdV5nCuS4fvB+ye/ziUQrajEhOoj/t2w5T14BY= github.com/mjl-/sherpa v0.6.7 h1:C5F8XQdV5nCuS4fvB+ye/ziUQrajEhOoj/t2w5T14BY=

View file

@ -125,7 +125,7 @@ func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomP
} }
} }
db, err := openDB(ctx, dbPath) db, err := openDB(ctx, log, dbPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("open database: %s", err) return nil, fmt.Errorf("open database: %s", err)
} }
@ -230,18 +230,20 @@ func newDB(ctx context.Context, log mlog.Log, path string) (db *bstore.DB, rerr
} }
}() }()
db, err := bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
db, err := bstore.Open(ctx, path, &opts, DBTypes...)
if err != nil { if err != nil {
return nil, fmt.Errorf("open new database: %w", err) return nil, fmt.Errorf("open new database: %w", err)
} }
return db, nil return db, nil
} }
func openDB(ctx context.Context, path string) (*bstore.DB, error) { func openDB(ctx context.Context, log mlog.Log, path string) (*bstore.DB, error) {
if _, err := os.Stat(path); err != nil { if _, err := os.Stat(path); err != nil {
return nil, fmt.Errorf("stat db file: %w", err) return nil, fmt.Errorf("stat db file: %w", err)
} }
return bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
return bstore.Open(ctx, path, &opts, DBTypes...)
} }
// Save stores modifications, e.g. from training, to the database and bloom // Save stores modifications, e.g. from training, to the database and bloom

View file

@ -14,7 +14,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"time" "time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -68,27 +67,17 @@ var (
var DBTypes = []any{PolicyRecord{}} // Types stored in DB. var DBTypes = []any{PolicyRecord{}} // Types stored in DB.
var DB *bstore.DB // Exported for backups. var DB *bstore.DB // Exported for backups.
var mutex sync.Mutex
func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
mutex.Lock()
defer mutex.Unlock()
if DB == nil {
p := mox.DataDirPath("mtasts.db")
os.MkdirAll(filepath.Dir(p), 0770)
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...)
if err != nil {
return nil, err
}
DB = db
}
return DB, nil
}
// Init opens the database and starts a goroutine that refreshes policies in // Init opens the database and starts a goroutine that refreshes policies in
// the database, and keeps doing so periodically. // the database, and keeps doing so periodically.
func Init(refresher bool) error { func Init(refresher bool) error {
_, err := database(mox.Shutdown) log := mlog.New("mtastsdb", nil)
p := mox.DataDirPath("mtasts.db")
os.MkdirAll(filepath.Dir(p), 0770)
opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
var err error
DB, err = bstore.Open(mox.Shutdown, p, &opts, DBTypes...)
if err != nil { if err != nil {
return err return err
} }
@ -102,14 +91,12 @@ func Init(refresher bool) error {
} }
// Close closes the database. // Close closes the database.
func Close() { func Close() error {
mutex.Lock() if err := DB.Close(); err != nil {
defer mutex.Unlock() return fmt.Errorf("close db: %w", err)
if DB != nil {
err := DB.Close()
mlog.New("mtastsdb", nil).Check(err, "closing database")
DB = nil
} }
DB = nil
return nil
} }
// lookup looks up a policy for the domain in the database. // lookup looks up a policy for the domain in the database.
@ -119,16 +106,11 @@ func Close() {
// Returns ErrNotFound if record is not present. // Returns ErrNotFound if record is not present.
// Returns ErrBackoff if a recent attempt to fetch a record failed. // Returns ErrBackoff if a recent attempt to fetch a record failed.
func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord, error) { func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord, error) {
db, err := database(ctx)
if err != nil {
return nil, err
}
if domain.IsZero() { if domain.IsZero() {
return nil, fmt.Errorf("empty domain") return nil, fmt.Errorf("empty domain")
} }
now := timeNow() now := timeNow()
q := bstore.QueryDB[PolicyRecord](ctx, db) q := bstore.QueryDB[PolicyRecord](ctx, DB)
q.FilterNonzero(PolicyRecord{Domain: domain.Name()}) q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
q.FilterGreater("ValidEnd", now) q.FilterGreater("ValidEnd", now)
pr, err := q.Get() pr, err := q.Get()
@ -139,7 +121,7 @@ func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord
} }
pr.LastUse = now pr.LastUse = now
if err := db.Update(ctx, &pr); err != nil { if err := DB.Update(ctx, &pr); err != nil {
log.Errorx("marking cached mta-sts policy as used in database", err) log.Errorx("marking cached mta-sts policy as used in database", err)
} }
if pr.Backoff { if pr.Backoff {
@ -151,12 +133,7 @@ func lookup(ctx context.Context, log mlog.Log, domain dns.Domain) (*PolicyRecord
// Upsert adds the policy to the database, overwriting an existing policy for the domain. // Upsert adds the policy to the database, overwriting an existing policy for the domain.
// Policy can be nil, indicating a failure to fetch the policy. // Policy can be nil, indicating a failure to fetch the policy.
func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy, policyText string) error { func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy, policyText string) error {
db, err := database(ctx) return DB.Write(ctx, func(tx *bstore.Tx) error {
if err != nil {
return err
}
return db.Write(ctx, func(tx *bstore.Tx) error {
pr := PolicyRecord{Domain: domain.Name()} pr := PolicyRecord{Domain: domain.Name()}
err := tx.Get(&pr) err := tx.Get(&pr)
if err != nil && err != bstore.ErrAbsent { if err != nil && err != bstore.ErrAbsent {
@ -195,11 +172,7 @@ func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mta
// PolicyRecords returns all policies in the database, sorted descending by last // PolicyRecords returns all policies in the database, sorted descending by last
// use, domain. // use, domain.
func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) { func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
db, err := database(ctx) return bstore.QueryDB[PolicyRecord](ctx, DB).SortDesc("LastUse", "Domain").List()
if err != nil {
return nil, err
}
return bstore.QueryDB[PolicyRecord](ctx, db).SortDesc("LastUse", "Domain").List()
} }
// Get retrieves an MTA-STS policy for domain and whether it is fresh. // Get retrieves an MTA-STS policy for domain and whether it is fresh.

View file

@ -51,19 +51,14 @@ func refresh() int {
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow // jitter to the timing. Each refresh is done in a new goroutine, so a single slow
// refresh doesn't mess up the timing. // refresh doesn't mess up the timing.
func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) { func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
db, err := database(ctx)
if err != nil {
return 0, err
}
now := timeNow() now := timeNow()
qdel := bstore.QueryDB[PolicyRecord](ctx, db) qdel := bstore.QueryDB[PolicyRecord](ctx, DB)
qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour)) qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
if _, err := qdel.Delete(); err != nil { if _, err := qdel.Delete(); err != nil {
return 0, fmt.Errorf("deleting old unused policies: %s", err) return 0, fmt.Errorf("deleting old unused policies: %s", err)
} }
qup := bstore.QueryDB[PolicyRecord](ctx, db) qup := bstore.QueryDB[PolicyRecord](ctx, DB)
qup.FilterLess("LastUpdate", now.Add(-12*time.Hour)) qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
prs, err := qup.List() prs, err := qup.List()
if err != nil { if err != nil {
@ -89,7 +84,7 @@ func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep fu
log.Debug("will refresh mta-sts policies over next 3 hours", slog.Int("count", len(prs))) log.Debug("will refresh mta-sts policies over next 3 hours", slog.Int("count", len(prs)))
start := timeNow() start := timeNow()
for i, pr := range prs { for i, pr := range prs {
go refreshDomain(ctx, log, db, resolver, pr) go refreshDomain(ctx, log, DB, resolver, pr)
if i < len(prs)-1 { if i < len(prs)-1 {
interval := 3 * int64(time.Hour) / int64(len(prs)-1) interval := 3 * int64(time.Hour) / int64(len(prs)-1)
extra := time.Duration(rand.Int63n(interval) - interval/2) extra := time.Duration(rand.Int63n(interval) - interval/2)

View file

@ -9,7 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" golog "log"
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
@ -39,15 +39,14 @@ func TestRefresh(t *testing.T) {
os.Remove(dbpath) os.Remove(dbpath)
defer os.Remove(dbpath) defer os.Remove(dbpath)
if err := Init(false); err != nil { log := mlog.New("mtastsdb", nil)
t.Fatalf("init database: %s", err)
}
defer Close()
db, err := database(ctxbg) err := Init(false)
if err != nil { tcheckf(t, err, "init database")
t.Fatalf("database: %s", err) defer func() {
} err := Close()
tcheckf(t, err, "close database")
}()
cert := fakeCert(t, false) cert := fakeCert(t, false)
defer func() { defer func() {
@ -70,7 +69,7 @@ func TestRefresh(t *testing.T) {
} }
pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()} pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()}
if err := db.Insert(ctxbg, &pr); err != nil { if err := DB.Insert(ctxbg, &pr); err != nil {
t.Fatalf("insert policy: %s", err) t.Fatalf("insert policy: %s", err)
} }
} }
@ -114,7 +113,7 @@ func TestRefresh(t *testing.T) {
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
}, },
ErrorLog: log.New(io.Discard, "", 0), ErrorLog: golog.New(io.Discard, "", 0),
} }
s.ServeTLS(l, "", "") s.ServeTLS(l, "", "")
}() }()
@ -136,7 +135,6 @@ func TestRefresh(t *testing.T) {
t.Fatalf("bad sleep duration %v", d) t.Fatalf("bad sleep duration %v", d)
} }
} }
log := mlog.New("mtastsdb", nil)
if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 { if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 {
t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n) t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
} }
@ -146,7 +144,7 @@ func TestRefresh(t *testing.T) {
time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database. time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
// Should not do any more refreshes and return immediately. // Should not do any more refreshes and return immediately.
q := bstore.QueryDB[PolicyRecord](ctxbg, db) q := bstore.QueryDB[PolicyRecord](ctxbg, DB)
q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"}) q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
if _, err := q.Delete(); err != nil { if _, err := q.Delete(); err != nil {
t.Fatalf("delete record that would be refreshed: %v", err) t.Fatalf("delete record that would be refreshed: %v", err)

View file

@ -24,8 +24,6 @@ import (
func TestHookIncoming(t *testing.T) { func TestHookIncoming(t *testing.T) {
acc, cleanup := setup(t) acc, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
accret, err := store.OpenAccount(pkglog, "retired") accret, err := store.OpenAccount(pkglog, "retired")
tcheck(t, err, "open account for retired") tcheck(t, err, "open account for retired")
@ -119,8 +117,6 @@ func TestHookIncoming(t *testing.T) {
func TestFromIDIncomingDelivery(t *testing.T) { func TestFromIDIncomingDelivery(t *testing.T) {
acc, cleanup := setup(t) acc, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
accret, err := store.OpenAccount(pkglog, "retired") accret, err := store.OpenAccount(pkglog, "retired")
tcheck(t, err, "open account for retired") tcheck(t, err, "open account for retired")
@ -525,8 +521,6 @@ func TestFromIDIncomingDelivery(t *testing.T) {
func TestHookListFilterSort(t *testing.T) { func TestHookListFilterSort(t *testing.T) {
_, cleanup := setup(t) _, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
now := time.Now().Round(0) now := time.Now().Round(0)
h := Hook{0, 0, "fromid", "messageid", "subj", nil, "mjl", "http://localhost", "", false, "delivered", "", now, 0, now, []HookResult{}} h := Hook{0, 0, "fromid", "messageid", "subj", nil, "mjl", "http://localhost", "", false, "delivered", "", now, 0, now, []HookResult{}}
@ -534,7 +528,7 @@ func TestHookListFilterSort(t *testing.T) {
h1.Submitted = now.Add(-time.Second) h1.Submitted = now.Add(-time.Second)
h1.NextAttempt = now.Add(time.Minute) h1.NextAttempt = now.Add(time.Minute)
hl := []Hook{h, h, h, h, h, h1} hl := []Hook{h, h, h, h, h, h1}
err = DB.Write(ctxbg, func(tx *bstore.Tx) error { err := DB.Write(ctxbg, func(tx *bstore.Tx) error {
for i := range hl { for i := range hl {
err := hookInsert(tx, &hl[i], now, time.Minute) err := hookInsert(tx, &hl[i], now, time.Minute)
tcheck(t, err, "insert hook") tcheck(t, err, "insert hook")

View file

@ -350,7 +350,9 @@ func Init() error {
} }
var err error var err error
DB, err = bstore.Open(mox.Shutdown, qpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) log := mlog.New("queue", nil)
opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
DB, err = bstore.Open(mox.Shutdown, qpath, &opts, DBTypes...)
if err == nil { if err == nil {
err = DB.Read(mox.Shutdown, func(tx *bstore.Tx) error { err = DB.Read(mox.Shutdown, func(tx *bstore.Tx) error {
return metricHoldUpdate(tx) return metricHoldUpdate(tx)

View file

@ -28,6 +28,7 @@ import (
"github.com/mjl-/mox/dns" "github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mlog" "github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/mox-" "github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtastsdb"
"github.com/mjl-/mox/smtp" "github.com/mjl-/mox/smtp"
"github.com/mjl-/mox/smtpclient" "github.com/mjl-/mox/smtpclient"
"github.com/mjl-/mox/store" "github.com/mjl-/mox/store"
@ -60,6 +61,12 @@ func setup(t *testing.T) (*store.Account, func()) {
mox.Context = ctxbg mox.Context = ctxbg
mox.ConfigStaticPath = filepath.FromSlash("../testdata/queue/mox.conf") mox.ConfigStaticPath = filepath.FromSlash("../testdata/queue/mox.conf")
mox.MustLoadConfig(true, false) mox.MustLoadConfig(true, false)
err := Init()
tcheck(t, err, "queue init")
err = mtastsdb.Init(false)
tcheck(t, err, "mtastsdb init")
err = tlsrptdb.Init()
tcheck(t, err, "tlsrptdb init")
acc, err := store.OpenAccount(log, "mjl") acc, err := store.OpenAccount(log, "mjl")
tcheck(t, err, "open account") tcheck(t, err, "open account")
err = acc.SetPassword(log, "testtest") err = acc.SetPassword(log, "testtest")
@ -72,6 +79,10 @@ func setup(t *testing.T) (*store.Account, func()) {
mox.ShutdownCancel() mox.ShutdownCancel()
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg) mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
Shutdown() Shutdown()
err := mtastsdb.Close()
tcheck(t, err, "mtastsdb close")
err = tlsrptdb.Close()
tcheck(t, err, "tlsrptdb close")
switchStop() switchStop()
} }
} }
@ -95,8 +106,6 @@ func prepareFile(t *testing.T) *os.File {
func TestQueue(t *testing.T) { func TestQueue(t *testing.T) {
acc, cleanup := setup(t) acc, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
idfilter := func(msgID int64) Filter { idfilter := func(msgID int64) Filter {
return Filter{IDs: []int64{msgID}} return Filter{IDs: []int64{msgID}}
@ -951,8 +960,6 @@ func checkTLSResults(t *testing.T, policyDomain, expRecipientDomain string, expI
func TestRetiredHooks(t *testing.T) { func TestRetiredHooks(t *testing.T) {
_, cleanup := setup(t) _, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
addr, err := smtp.ParseAddress("mjl@mox.example") addr, err := smtp.ParseAddress("mjl@mox.example")
tcheck(t, err, "parse address") tcheck(t, err, "parse address")
@ -1193,6 +1200,7 @@ func TestQueueStart(t *testing.T) {
<-done <-done
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg) mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
}() }()
Shutdown() // DB was opened already. Start will open it again. Just close it before.
err := Start(resolver, done) err := Start(resolver, done)
tcheck(t, err, "queue start") tcheck(t, err, "queue start")
@ -1284,8 +1292,6 @@ func TestQueueStart(t *testing.T) {
func TestListFilterSort(t *testing.T) { func TestListFilterSort(t *testing.T) {
_, cleanup := setup(t) _, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
// insert Msgs. insert RetiredMsgs based on that. call list with filters and sort. filter to select a single. filter to paginate one by one, and in reverse. // insert Msgs. insert RetiredMsgs based on that. call list with filters and sort. filter to select a single. filter to paginate one by one, and in reverse.
@ -1301,7 +1307,7 @@ func TestListFilterSort(t *testing.T) {
qm1.Queued = now.Add(-time.Second) qm1.Queued = now.Add(-time.Second)
qm1.NextAttempt = now.Add(time.Minute) qm1.NextAttempt = now.Add(time.Minute)
qml := []Msg{qm, qm, qm, qm, qm, qm1} qml := []Msg{qm, qm, qm, qm, qm, qm1}
err = Add(ctxbg, pkglog, "mjl", mf, qml...) err := Add(ctxbg, pkglog, "mjl", mf, qml...)
tcheck(t, err, "add messages to queue") tcheck(t, err, "add messages to queue")
qm1 = qml[len(qml)-1] qm1 = qml[len(qml)-1]

View file

@ -10,8 +10,6 @@ import (
func TestSuppression(t *testing.T) { func TestSuppression(t *testing.T) {
_, cleanup := setup(t) _, cleanup := setup(t)
defer cleanup() defer cleanup()
err := Init()
tcheck(t, err, "queue init")
l, err := SuppressionList(ctxbg, "bogus") l, err := SuppressionList(ctxbg, "bogus")
tcheck(t, err, "listing suppressions for unknown account") tcheck(t, err, "listing suppressions for unknown account")

View file

@ -71,11 +71,15 @@ func start(mtastsdbRefresher, sendDMARCReports, sendTLSReports, skipForkExec boo
} }
if err := mtastsdb.Init(mtastsdbRefresher); err != nil { if err := mtastsdb.Init(mtastsdbRefresher); err != nil {
return fmt.Errorf("mtasts init: %s", err) return fmt.Errorf("mtastsdb init: %s", err)
} }
if err := tlsrptdb.Init(); err != nil { if err := tlsrptdb.Init(); err != nil {
return fmt.Errorf("tlsrpt init: %s", err) return fmt.Errorf("tlsrptdb init: %s", err)
}
if err := dmarcdb.Init(); err != nil {
return fmt.Errorf("dmarcdb init: %s", err)
} }
done := make(chan struct{}) // Goroutines for messages and webhooks, and cleaners. done := make(chan struct{}) // Goroutines for messages and webhooks, and cleaners.
@ -83,10 +87,6 @@ func start(mtastsdbRefresher, sendDMARCReports, sendTLSReports, skipForkExec boo
return fmt.Errorf("queue start: %s", err) return fmt.Errorf("queue start: %s", err)
} }
// dmarcdb starts after queue because it may start sending reports through the queue.
if err := dmarcdb.Init(); err != nil {
return fmt.Errorf("dmarc init: %s", err)
}
if sendDMARCReports { if sendDMARCReports {
dmarcdb.Start(dns.StrictResolver{Pkg: "dmarcdb"}) dmarcdb.Start(dns.StrictResolver{Pkg: "dmarcdb"})
} }

View file

@ -108,7 +108,8 @@ func TestReputation(t *testing.T) {
p := filepath.FromSlash("../testdata/smtpserver-reputation.db") p := filepath.FromSlash("../testdata/smtpserver-reputation.db")
defer os.Remove(p) defer os.Remove(p)
db, err := bstore.Open(ctxbg, p, &bstore.Options{Timeout: 5 * time.Second}, store.DBTypes...) opts := bstore.Options{Timeout: 5 * time.Second, RegisterLogger: log.Logger}
db, err := bstore.Open(ctxbg, p, &opts, store.DBTypes...)
tcheck(t, err, "open db") tcheck(t, err, "open db")
defer db.Close() defer db.Close()

View file

@ -105,11 +105,6 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test
ts := testserver{t: t, cid: 1, resolver: resolver, tlsmode: smtpclient.TLSOpportunistic} ts := testserver{t: t, cid: 1, resolver: resolver, tlsmode: smtpclient.TLSOpportunistic}
if dmarcdb.EvalDB != nil {
dmarcdb.EvalDB.Close()
dmarcdb.EvalDB = nil
}
log := mlog.New("smtpserver", nil) log := mlog.New("smtpserver", nil)
mox.Context = ctxbg mox.Context = ctxbg
mox.ConfigStaticPath = configPath mox.ConfigStaticPath = configPath
@ -117,7 +112,11 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test
dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir) dataDir := mox.ConfigDirPath(mox.Conf.Static.DataDir)
os.RemoveAll(dataDir) os.RemoveAll(dataDir)
var err error err := dmarcdb.Init()
tcheck(t, err, "dmarcdb init")
err = tlsrptdb.Init()
tcheck(t, err, "tlsrptdb init")
ts.acc, err = store.OpenAccount(log, "mjl") ts.acc, err = store.OpenAccount(log, "mjl")
tcheck(t, err, "open account") tcheck(t, err, "open account")
err = ts.acc.SetPassword(log, password0) err = ts.acc.SetPassword(log, password0)
@ -136,10 +135,14 @@ func (ts *testserver) close() {
if ts.acc == nil { if ts.acc == nil {
return return
} }
err := dmarcdb.Close()
tcheck(ts.t, err, "dmarcdb close")
err = tlsrptdb.Close()
tcheck(ts.t, err, "tlsrptdb close")
ts.comm.Unregister() ts.comm.Unregister()
queue.Shutdown() queue.Shutdown()
ts.switchStop() ts.switchStop()
err := ts.acc.Close() err = ts.acc.Close()
tcheck(ts.t, err, "closing account") tcheck(ts.t, err, "closing account")
ts.acc.CheckClosed() ts.acc.CheckClosed()
ts.acc = nil ts.acc = nil

View file

@ -907,7 +907,8 @@ func OpenAccountDB(log mlog.Log, accountDir, accountName string) (a *Account, re
os.MkdirAll(accountDir, 0770) os.MkdirAll(accountDir, 0770)
} }
db, err := bstore.Open(context.TODO(), dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
db, err := bstore.Open(context.TODO(), dbpath, &opts, DBTypes...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -121,7 +121,8 @@ func TestThreadingUpgrade(t *testing.T) {
// Now clear the threading upgrade, and the threading fields and close the account. // Now clear the threading upgrade, and the threading fields and close the account.
// We open the database file directly, so we don't trigger the consistency checker. // We open the database file directly, so we don't trigger the consistency checker.
db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, DBTypes...) opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
db, err := bstore.Open(ctxbg, dbpath, &opts, DBTypes...)
err = db.Write(ctxbg, func(tx *bstore.Tx) error { err = db.Write(ctxbg, func(tx *bstore.Tx) error {
up := Upgrade{ID: 1} up := Upgrade{ID: 1}
err := tx.Delete(&up) err := tx.Delete(&up)

View file

@ -1,7 +1,11 @@
package tlsrptdb package tlsrptdb
import ( import (
"sync" "context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/mjl-/bstore" "github.com/mjl-/bstore"
@ -12,7 +16,6 @@ import (
var ( var (
ReportDBTypes = []any{Record{}} ReportDBTypes = []any{Record{}}
ReportDB *bstore.DB ReportDB *bstore.DB
mutex sync.Mutex
// Accessed directly by tlsrptsend. // Accessed directly by tlsrptsend.
ResultDBTypes = []any{TLSResult{}, SuppressAddress{}} ResultDBTypes = []any{TLSResult{}, SuppressAddress{}}
@ -21,29 +24,48 @@ var (
// Init opens and possibly initializes the databases. // Init opens and possibly initializes the databases.
func Init() error { func Init() error {
if _, err := reportDB(mox.Shutdown); err != nil { if ReportDB != nil || ResultDB != nil {
return err return fmt.Errorf("already initialized")
} }
if _, err := resultDB(mox.Shutdown); err != nil {
return err log := mlog.New("tlsrptdb", nil)
var err error
ReportDB, err = openReportDB(mox.Shutdown, log)
if err != nil {
return fmt.Errorf("opening report db: %v", err)
}
ResultDB, err = openResultDB(mox.Shutdown, log)
if err != nil {
return fmt.Errorf("opening result db: %v", err)
} }
return nil return nil
} }
// Close closes the database connections. func openReportDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) {
func Close() { p := mox.DataDirPath("tlsrpt.db")
log := mlog.New("tlsrptdb", nil) os.MkdirAll(filepath.Dir(p), 0770)
if ResultDB != nil { opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
err := ResultDB.Close() return bstore.Open(ctx, p, &opts, ReportDBTypes...)
log.Check(err, "closing result database")
ResultDB = nil
} }
mutex.Lock() func openResultDB(ctx context.Context, log mlog.Log) (*bstore.DB, error) {
defer mutex.Unlock() p := mox.DataDirPath("tlsrptresult.db")
if ReportDB != nil { os.MkdirAll(filepath.Dir(p), 0770)
err := ReportDB.Close() opts := bstore.Options{Timeout: 5 * time.Second, Perm: 0660, RegisterLogger: log.Logger}
log.Check(err, "closing report database") return bstore.Open(ctx, p, &opts, ResultDBTypes...)
}
// Close closes the database connections.
func Close() error {
if err := ResultDB.Close(); err != nil {
return fmt.Errorf("closing result db: %w", err)
}
ResultDB = nil
if err := ReportDB.Close(); err != nil {
return fmt.Errorf("closing report db: %w", err)
}
ReportDB = nil ReportDB = nil
} return nil
} }

View file

@ -5,8 +5,6 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"os"
"path/filepath"
"time" "time"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
@ -55,21 +53,6 @@ type Record struct {
Report tlsrpt.Report Report tlsrpt.Report
} }
func reportDB(ctx context.Context) (rdb *bstore.DB, rerr error) {
mutex.Lock()
defer mutex.Unlock()
if ReportDB == nil {
p := mox.DataDirPath("tlsrpt.db")
os.MkdirAll(filepath.Dir(p), 0770)
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, ReportDBTypes...)
if err != nil {
return nil, err
}
ReportDB = db
}
return ReportDB, nil
}
// AddReport adds a TLS report to the database. // AddReport adds a TLS report to the database.
// //
// The report should have come in over SMTP, with a DKIM-validated // The report should have come in over SMTP, with a DKIM-validated
@ -82,17 +65,12 @@ func reportDB(ctx context.Context) (rdb *bstore.DB, rerr error) {
// //
// Prometheus metrics are updated only for configured domains. // Prometheus metrics are updated only for configured domains.
func AddReport(ctx context.Context, log mlog.Log, verifiedFromDomain dns.Domain, mailFrom string, hostReport bool, r *tlsrpt.Report) error { func AddReport(ctx context.Context, log mlog.Log, verifiedFromDomain dns.Domain, mailFrom string, hostReport bool, r *tlsrpt.Report) error {
db, err := reportDB(ctx)
if err != nil {
return err
}
if len(r.Policies) == 0 { if len(r.Policies) == 0 {
return fmt.Errorf("no policies in report") return fmt.Errorf("no policies in report")
} }
var inserted int var inserted int
return db.Write(ctx, func(tx *bstore.Tx) error { return ReportDB.Write(ctx, func(tx *bstore.Tx) error {
for _, p := range r.Policies { for _, p := range r.Policies {
pp := p.Policy pp := p.Policy
@ -132,22 +110,13 @@ func AddReport(ctx context.Context, log mlog.Log, verifiedFromDomain dns.Domain,
// Records returns all TLS reports in the database. // Records returns all TLS reports in the database.
func Records(ctx context.Context) ([]Record, error) { func Records(ctx context.Context) ([]Record, error) {
db, err := reportDB(ctx) return bstore.QueryDB[Record](ctx, ReportDB).List()
if err != nil {
return nil, err
}
return bstore.QueryDB[Record](ctx, db).List()
} }
// RecordID returns the report for the ID. // RecordID returns the report for the ID.
func RecordID(ctx context.Context, id int64) (Record, error) { func RecordID(ctx context.Context, id int64) (Record, error) {
db, err := reportDB(ctx)
if err != nil {
return Record{}, err
}
e := Record{ID: id} e := Record{ID: id}
err = db.Get(ctx, &e) err := ReportDB.Get(ctx, &e)
return e, err return e, err
} }
@ -155,12 +124,7 @@ func RecordID(ctx context.Context, id int64) (Record, error) {
// given policy domain. If policy domain is empty, records for all domains are // given policy domain. If policy domain is empty, records for all domains are
// returned. // returned.
func RecordsPeriodDomain(ctx context.Context, start, end time.Time, policyDomain dns.Domain) ([]Record, error) { func RecordsPeriodDomain(ctx context.Context, start, end time.Time, policyDomain dns.Domain) ([]Record, error) {
db, err := reportDB(ctx) q := bstore.QueryDB[Record](ctx, ReportDB)
if err != nil {
return nil, err
}
q := bstore.QueryDB[Record](ctx, db)
var zerodom dns.Domain var zerodom dns.Domain
if policyDomain != zerodom { if policyDomain != zerodom {
q.FilterNonzero(Record{Domain: policyDomain.Name()}) q.FilterNonzero(Record{Domain: policyDomain.Name()})

View file

@ -3,14 +3,11 @@ package tlsrptdb
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath"
"time" "time"
"github.com/mjl-/bstore" "github.com/mjl-/bstore"
"github.com/mjl-/mox/dns" "github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/tlsrpt" "github.com/mjl-/mox/tlsrpt"
) )
@ -70,33 +67,13 @@ type SuppressAddress struct {
Comment string Comment string
} }
func resultDB(ctx context.Context) (rdb *bstore.DB, rerr error) {
mutex.Lock()
defer mutex.Unlock()
if ResultDB == nil {
p := mox.DataDirPath("tlsrptresult.db")
os.MkdirAll(filepath.Dir(p), 0770)
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, ResultDBTypes...)
if err != nil {
return nil, err
}
ResultDB = db
}
return ResultDB, nil
}
// AddTLSResults adds or merges all tls results for delivering to a policy domain, // AddTLSResults adds or merges all tls results for delivering to a policy domain,
// on its UTC day to a recipient domain to the database. Results may cause multiple // on its UTC day to a recipient domain to the database. Results may cause multiple
// separate reports to be sent. // separate reports to be sent.
func AddTLSResults(ctx context.Context, results []TLSResult) error { func AddTLSResults(ctx context.Context, results []TLSResult) error {
db, err := resultDB(ctx)
if err != nil {
return err
}
now := time.Now() now := time.Now()
err = db.Write(ctx, func(tx *bstore.Tx) error { err := ResultDB.Write(ctx, func(tx *bstore.Tx) error {
for _, result := range results { for _, result := range results {
// Ensure all slices are non-nil. We do this now so all readers will marshal to // Ensure all slices are non-nil. We do this now so all readers will marshal to
// compliant with the JSON schema. And also for consistent equality checks when // compliant with the JSON schema. And also for consistent equality checks when
@ -148,102 +125,57 @@ func AddTLSResults(ctx context.Context, results []TLSResult) error {
// Results returns all TLS results in the database, for all policy domains each // Results returns all TLS results in the database, for all policy domains each
// with potentially multiple days. Sorted by RecipientDomain and day. // with potentially multiple days. Sorted by RecipientDomain and day.
func Results(ctx context.Context) ([]TLSResult, error) { func Results(ctx context.Context) ([]TLSResult, error) {
db, err := resultDB(ctx) return bstore.QueryDB[TLSResult](ctx, ResultDB).SortAsc("PolicyDomain", "DayUTC", "RecipientDomain").List()
if err != nil {
return nil, err
}
return bstore.QueryDB[TLSResult](ctx, db).SortAsc("PolicyDomain", "DayUTC", "RecipientDomain").List()
} }
// ResultsDomain returns all TLSResults for a policy domain, potentially for // ResultsDomain returns all TLSResults for a policy domain, potentially for
// multiple days. // multiple days.
func ResultsPolicyDomain(ctx context.Context, policyDomain dns.Domain) ([]TLSResult, error) { func ResultsPolicyDomain(ctx context.Context, policyDomain dns.Domain) ([]TLSResult, error) {
db, err := resultDB(ctx) return bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name()}).SortAsc("DayUTC", "RecipientDomain").List()
if err != nil {
return nil, err
}
return bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name()}).SortAsc("DayUTC", "RecipientDomain").List()
} }
// ResultsRecipientDomain returns all TLSResults for a recipient domain, // ResultsRecipientDomain returns all TLSResults for a recipient domain,
// potentially for multiple days. // potentially for multiple days.
func ResultsRecipientDomain(ctx context.Context, recipientDomain dns.Domain) ([]TLSResult, error) { func ResultsRecipientDomain(ctx context.Context, recipientDomain dns.Domain) ([]TLSResult, error) {
db, err := resultDB(ctx) return bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name()}).SortAsc("DayUTC", "PolicyDomain").List()
if err != nil {
return nil, err
}
return bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name()}).SortAsc("DayUTC", "PolicyDomain").List()
} }
// RemoveResultsPolicyDomain removes all TLSResults for the policy domain on the // RemoveResultsPolicyDomain removes all TLSResults for the policy domain on the
// day from the database. // day from the database.
func RemoveResultsPolicyDomain(ctx context.Context, policyDomain dns.Domain, dayUTC string) error { func RemoveResultsPolicyDomain(ctx context.Context, policyDomain dns.Domain, dayUTC string) error {
db, err := resultDB(ctx) _, err := bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name(), DayUTC: dayUTC}).Delete()
if err != nil {
return err
}
_, err = bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{PolicyDomain: policyDomain.Name(), DayUTC: dayUTC}).Delete()
return err return err
} }
// RemoveResultsRecipientDomain removes all TLSResults for the recipient domain on // RemoveResultsRecipientDomain removes all TLSResults for the recipient domain on
// the day from the database. // the day from the database.
func RemoveResultsRecipientDomain(ctx context.Context, recipientDomain dns.Domain, dayUTC string) error { func RemoveResultsRecipientDomain(ctx context.Context, recipientDomain dns.Domain, dayUTC string) error {
db, err := resultDB(ctx) _, err := bstore.QueryDB[TLSResult](ctx, ResultDB).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name(), DayUTC: dayUTC}).Delete()
if err != nil {
return err
}
_, err = bstore.QueryDB[TLSResult](ctx, db).FilterNonzero(TLSResult{RecipientDomain: recipientDomain.Name(), DayUTC: dayUTC}).Delete()
return err return err
} }
// SuppressAdd adds an address to the suppress list. // SuppressAdd adds an address to the suppress list.
func SuppressAdd(ctx context.Context, ba *SuppressAddress) error { func SuppressAdd(ctx context.Context, ba *SuppressAddress) error {
db, err := resultDB(ctx) return ResultDB.Insert(ctx, ba)
if err != nil {
return err
}
return db.Insert(ctx, ba)
} }
// SuppressList returns all reporting addresses on the suppress list. // SuppressList returns all reporting addresses on the suppress list.
func SuppressList(ctx context.Context) ([]SuppressAddress, error) { func SuppressList(ctx context.Context) ([]SuppressAddress, error) {
db, err := resultDB(ctx) return bstore.QueryDB[SuppressAddress](ctx, ResultDB).SortDesc("ID").List()
if err != nil {
return nil, err
}
return bstore.QueryDB[SuppressAddress](ctx, db).SortDesc("ID").List()
} }
// SuppressRemove removes a reporting address record from the suppress list. // SuppressRemove removes a reporting address record from the suppress list.
func SuppressRemove(ctx context.Context, id int64) error { func SuppressRemove(ctx context.Context, id int64) error {
db, err := resultDB(ctx) return ResultDB.Delete(ctx, &SuppressAddress{ID: id})
if err != nil {
return err
}
return db.Delete(ctx, &SuppressAddress{ID: id})
} }
// SuppressUpdate updates the until field of a reporting address record. // SuppressUpdate updates the until field of a reporting address record.
func SuppressUpdate(ctx context.Context, id int64, until time.Time) error { func SuppressUpdate(ctx context.Context, id int64, until time.Time) error {
db, err := resultDB(ctx)
if err != nil {
return err
}
ba := SuppressAddress{ID: id} ba := SuppressAddress{ID: id}
err = db.Get(ctx, &ba) err := ResultDB.Get(ctx, &ba)
if err != nil { if err != nil {
return err return err
} }
ba.Until = until ba.Until = until
return db.Update(ctx, &ba) return ResultDB.Update(ctx, &ba)
} }

View file

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"os" "os"
"reflect" "reflect"
"sort" "sort"
@ -50,6 +51,17 @@ var errSchemaCheck = errors.New("schema check")
// to "changed", an error is returned if there is no schema change. If it is set to // to "changed", an error is returned if there is no schema change. If it is set to
// "unchanged", an error is returned if there was a schema change. // "unchanged", an error is returned if there was a schema change.
func (db *DB) Register(ctx context.Context, typeValues ...any) error { func (db *DB) Register(ctx context.Context, typeValues ...any) error {
return db.register(ctx, slog.New(discardHandler{}), typeValues...)
}
type discardHandler struct{}
func (l discardHandler) Enabled(context.Context, slog.Level) bool { return false }
func (l discardHandler) Handle(context.Context, slog.Record) error { return nil }
func (l discardHandler) WithAttrs(attrs []slog.Attr) slog.Handler { return l }
func (l discardHandler) WithGroup(name string) slog.Handler { return l }
func (db *DB) register(ctx context.Context, log *slog.Logger, typeValues ...any) error {
// We will drop/create new indices as needed. For changed indices, we drop // We will drop/create new indices as needed. For changed indices, we drop
// and recreate. E.g. if an index becomes a unique index, or if a field in // and recreate. E.g. if an index becomes a unique index, or if a field in
// an index changes. These values map type and index name to their index. // an index changes. These values map type and index name to their index.
@ -148,6 +160,9 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error {
tv.Version = 1 tv.Version = 1
if st.Current != nil { if st.Current != nil {
tv.Version = st.Current.Version + 1 tv.Version = st.Current.Version + 1
log.Debug("updating schema for type", slog.String("type", tv.name), slog.Uint64("version", uint64(tv.Version)))
} else {
log.Debug("adding schema for new type", slog.String("type", tv.name), slog.Uint64("version", uint64(tv.Version)))
} }
k, v, err := packSchema(tv) k, v, err := packSchema(tv)
if err != nil { if err != nil {
@ -299,6 +314,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error {
} }
foundField = true foundField = true
log.Debug("verifying foreign key constraint for new reference", slog.String("fromtype", ntname), slog.String("fromfield", f.Name), slog.String("totype", name))
// For newly added references, check they are valid. // For newly added references, check they are valid.
b, err := tx.recordsBucket(ntname, ntv.fillPercent) b, err := tx.recordsBucket(ntname, ntv.fillPercent)
@ -396,6 +412,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error {
db.types[st.Type] = *st db.types[st.Type] = *st
db.typeNames[st.Name] = *st db.typeNames[st.Name] = *st
ntvp = &ntv ntvp = &ntv
log.Debug("updating schema for type due to new incoming reference", slog.String("type", name), slog.Uint64("version", uint64(ntv.Version)))
} }
k, v, err := packSchema(ntvp) k, v, err := packSchema(ntvp)
@ -453,6 +470,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error {
if !drop { if !drop {
continue continue
} }
log.Debug("dropping old/modified index", slog.String("type", name), slog.String("indexname", iname))
b, err := tx.typeBucket(name) b, err := tx.typeBucket(name)
if err != nil { if err != nil {
return err return err
@ -482,6 +500,7 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error {
if !create { if !create {
continue continue
} }
log.Debug("preparing for new/modified index for type", slog.String("type", name), slog.String("indexname", iname))
b, err := tx.typeBucket(name) b, err := tx.typeBucket(name)
if err != nil { if err != nil {
return err return err
@ -578,6 +597,8 @@ func (db *DB) Register(ctx context.Context, typeValues ...any) error {
return nil return nil
} }
log.Debug("creating new/modified indexes for type", slog.String("type", name))
// Now do all sorts + inserts. // Now do all sorts + inserts.
for i, ib := range ibs { for i, ib := range ibs {
idx := idxs[i] idx := idxs[i]

View file

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"log/slog"
"os" "os"
"reflect" "reflect"
"sync" "sync"
@ -285,6 +286,7 @@ type Options struct {
Timeout time.Duration // Abort if opening DB takes longer than Timeout. If not set, the deadline from the context is used. Timeout time.Duration // Abort if opening DB takes longer than Timeout. If not set, the deadline from the context is used.
Perm fs.FileMode // Permissions for new file if created. If zero, 0600 is used. Perm fs.FileMode // Permissions for new file if created. If zero, 0600 is used.
MustExist bool // Before opening, check that file exists. If not, io/fs.ErrNotExist is returned. MustExist bool // Before opening, check that file exists. If not, io/fs.ErrNotExist is returned.
RegisterLogger *slog.Logger // For debug logging about schema upgrades.
} }
// Open opens a bstore database and registers types by calling Register. // Open opens a bstore database and registers types by calling Register.
@ -326,7 +328,16 @@ func Open(ctx context.Context, path string, opts *Options, typeValues ...any) (*
typeNames := map[string]storeType{} typeNames := map[string]storeType{}
types := map[reflect.Type]storeType{} types := map[reflect.Type]storeType{}
db := &DB{bdb: bdb, typeNames: typeNames, types: types} db := &DB{bdb: bdb, typeNames: typeNames, types: types}
if err := db.Register(ctx, typeValues...); err != nil { var log *slog.Logger
if opts != nil {
log = opts.RegisterLogger
}
if log == nil {
log = slog.New(discardHandler{})
} else {
log = log.With("dbpath", path)
}
if err := db.register(ctx, log, typeValues...); err != nil {
bdb.Close() bdb.Close()
return nil, err return nil, err
} }

2
vendor/modules.txt vendored
View file

@ -16,7 +16,7 @@ github.com/mjl-/adns/internal/singleflight
# github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05 # github.com/mjl-/autocert v0.0.0-20231214125928-31b7400acb05
## explicit; go 1.20 ## explicit; go 1.20
github.com/mjl-/autocert github.com/mjl-/autocert
# github.com/mjl-/bstore v0.0.5 # github.com/mjl-/bstore v0.0.6
## explicit; go 1.19 ## explicit; go 1.19
github.com/mjl-/bstore github.com/mjl-/bstore
# github.com/mjl-/sconf v0.0.6 # github.com/mjl-/sconf v0.0.6

View file

@ -120,7 +120,8 @@ possibly making them potentially no longer readable by the previous version.
checkf(err, path, "reading bolt database") checkf(err, path, "reading bolt database")
bdb.Close() bdb.Close()
db, err := bstore.Open(ctxbg, path, nil, types...) opts := bstore.Options{RegisterLogger: c.log.Logger}
db, err := bstore.Open(ctxbg, path, &opts, types...)
checkf(err, path, "open database with bstore") checkf(err, path, "open database with bstore")
if err != nil { if err != nil {
return return
@ -162,7 +163,8 @@ possibly making them potentially no longer readable by the previous version.
// Check that all messages present in the database also exist on disk. // Check that all messages present in the database also exist on disk.
seen := map[string]struct{}{} seen := map[string]struct{}{}
db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{MustExist: true}, queue.DBTypes...) opts := bstore.Options{MustExist: true, RegisterLogger: c.log.Logger}
db, err := bstore.Open(ctxbg, dbpath, &opts, queue.DBTypes...)
checkf(err, dbpath, "opening queue database to check messages") checkf(err, dbpath, "opening queue database to check messages")
if err == nil { if err == nil {
err := bstore.QueryDB[queue.Msg](ctxbg, db).ForEach(func(m queue.Msg) error { err := bstore.QueryDB[queue.Msg](ctxbg, db).ForEach(func(m queue.Msg) error {
@ -237,7 +239,8 @@ possibly making them potentially no longer readable by the previous version.
// And check consistency of UIDs with the mailbox UIDNext, and check UIDValidity. // And check consistency of UIDs with the mailbox UIDNext, and check UIDValidity.
seen := map[string]struct{}{} seen := map[string]struct{}{}
dbpath := filepath.Join(accdir, "index.db") dbpath := filepath.Join(accdir, "index.db")
db, err := bstore.Open(ctxbg, dbpath, &bstore.Options{MustExist: true}, store.DBTypes...) opts := bstore.Options{MustExist: true, RegisterLogger: c.log.Logger}
db, err := bstore.Open(ctxbg, dbpath, &opts, store.DBTypes...)
checkf(err, dbpath, "opening account database to check messages") checkf(err, dbpath, "opening account database to check messages")
if err == nil { if err == nil {
uidvalidity := store.NextUIDValidity{ID: 1} uidvalidity := store.NextUIDValidity{ID: 1}

View file

@ -17,6 +17,7 @@ import (
"github.com/mjl-/mox/dns" "github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mlog" "github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/mox-" "github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtastsdb"
"github.com/mjl-/mox/queue" "github.com/mjl-/mox/queue"
"github.com/mjl-/mox/store" "github.com/mjl-/mox/store"
) )
@ -58,6 +59,8 @@ func TestAPI(t *testing.T) {
defer store.Switchboard()() defer store.Switchboard()()
log := mlog.New("webmail", nil) log := mlog.New("webmail", nil)
err := mtastsdb.Init(false)
tcheck(t, err, "mtastsdb init")
acc, err := store.OpenAccount(log, "mjl") acc, err := store.OpenAccount(log, "mjl")
tcheck(t, err, "open account") tcheck(t, err, "open account")
const pw0 = "te\u0301st \u00a0\u2002\u200a" // NFD and various unicode spaces. const pw0 = "te\u0301st \u00a0\u2002\u200a" // NFD and various unicode spaces.
@ -65,7 +68,9 @@ func TestAPI(t *testing.T) {
err = acc.SetPassword(log, pw0) err = acc.SetPassword(log, pw0)
tcheck(t, err, "set password") tcheck(t, err, "set password")
defer func() { defer func() {
err := acc.Close() err := mtastsdb.Close()
tcheck(t, err, "mtastsdb close")
err = acc.Close()
pkglog.Check(err, "closing account") pkglog.Check(err, "closing account")
acc.CheckClosed() acc.CheckClosed()
}() }()