package mtastsdb

import (
	"context"
	"errors"
	"fmt"
	mathrand "math/rand"
	"runtime/debug"
	"time"

	"golang.org/x/exp/slog"

	"github.com/mjl-/bstore"

	"github.com/mjl-/mox/dns"
	"github.com/mjl-/mox/metrics"
	"github.com/mjl-/mox/mlog"
	"github.com/mjl-/mox/mox-"
	"github.com/mjl-/mox/mtasts"
)

// refresh runs refresh rounds for cached MTA-STS policies until mox.Shutdown is
// done. The first round starts immediately; each following round starts 24 hours
// after the previous one started. It returns the total number of policies for
// which a refresh was started across all rounds.
func refresh() int {
	interval := 24 * time.Hour
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	var refreshed int

	// Pro-actively refresh policies every 24 hours. ../rfc/8461:583
	for {
		// Reset before the (possibly hours-long) round, so the next tick is a full
		// interval after the start of this round.
		ticker.Reset(interval)

		log := mlog.New("mtastsdb", nil).WithCid(mox.Cid())
		n, err := refresh1(mox.Context, log, dns.StrictResolver{Pkg: "mtastsdb"}, time.Sleep)
		log.Check(err, "periodic refresh of cached mtasts policies")
		if n > 0 {
			refreshed += n
		}

		select {
		case <-mox.Shutdown.Done():
			return refreshed
		case <-ticker.C:
		}
	}
}

// refresh policies that have not been updated in the past 12 hours and remove
// policies not used for 180 days. We start with the first domain immediately, so
// an admin can see any (configuration) issues that are logged. We spread the
// refreshes evenly over the next 3 hours, randomizing the domains, and we add some
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
// refresh doesn't mess up the timing.
//
// sleep is called to wait between launching refreshes. The returned count is the
// number of refreshes started. If ctx is done before all refreshes were started,
// the remaining ones are skipped and an error wrapping ctx.Err() is returned along
// with the count started so far.
func refresh1(ctx context.Context, log mlog.Log, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
	db, err := database(ctx)
	if err != nil {
		return 0, err
	}

	now := timeNow()
	qdel := bstore.QueryDB[PolicyRecord](ctx, db)
	qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
	if _, err := qdel.Delete(); err != nil {
		return 0, fmt.Errorf("deleting old unused policies: %w", err)
	}

	qup := bstore.QueryDB[PolicyRecord](ctx, db)
	qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
	prs, err := qup.List()
	if err != nil {
		return 0, fmt.Errorf("querying policies to refresh: %w", err)
	}

	if len(prs) == 0 {
		// Nothing to do.
		return 0, nil
	}

	// Randomize list.
	rand := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
	rand.Shuffle(len(prs), func(i, j int) {
		prs[i], prs[j] = prs[j], prs[i]
	})

	// Spacing between launches so the last one starts about 3 hours after the
	// first. Only used when there is more than one policy.
	var interval int64
	if len(prs) > 1 {
		interval = 3 * int64(time.Hour) / int64(len(prs)-1)
	}

	// Launch goroutine with the refresh.
	log.Debug("will refresh mta-sts policies over next 3 hours", slog.Int("count", len(prs)))
	start := timeNow()
	for i, pr := range prs {
		// Don't start new refreshes once the context is done, e.g. during
		// shutdown. They would only fail and log errors.
		if err := ctx.Err(); err != nil {
			return i, fmt.Errorf("refreshing mta-sts policies: %w", err)
		}

		go refreshDomain(ctx, log, db, resolver, pr)

		if i < len(prs)-1 && interval > 0 {
			// Jitter of up to half an interval in either direction.
			extra := time.Duration(rand.Int63n(interval) - interval/2)
			next := start.Add(time.Duration(int64(i+1)*interval) + extra)
			if d := next.Sub(timeNow()); d > 0 {
				sleep(d)
			}
		}
	}
	return len(prs), nil
}

// refreshDomain refreshes the cached policy pr, with a one minute timeout. It is
// meant to run in its own goroutine: errors are logged, and panics are recovered
// and counted in metrics.
//
// If the MTA-STS DNS record still has the same ID as pr.RecordID, only
// LastUpdate and ValidEnd are extended. Otherwise the policy is fetched again
// and stored, with Backoff cleared. A mode "none" policy whose DNS record no
// longer exists is removed. Database updates only apply if the stored record
// still has the Domain and LastUpdate seen in pr, so a record changed in the
// meantime is left alone.
func refreshDomain(ctx context.Context, log mlog.Log, db *bstore.DB, resolver dns.Resolver, pr PolicyRecord) {
	defer func() {
		x := recover()
		if x != nil {
			// Should not happen, but make sure errors don't take down the application.
			log.Error("refresh1", slog.Any("panic", x), slog.String("domain", pr.Domain))
			debug.PrintStack()
			metrics.PanicInc(metrics.Mtastsdb)
		}
	}()

	ctx, cancel := context.WithTimeout(ctx, time.Minute)
	defer cancel()

	d, err := dns.ParseDomain(pr.Domain)
	if err != nil {
		// d is the zero value on error; log the raw domain as stored instead.
		log.Errorx("refreshing mta-sts policy: parsing policy domain", err, slog.String("domain", pr.Domain))
		return
	}
	log.Debug("refreshing mta-sts policy for domain", slog.Any("domain", d))
	record, _, err := mtasts.LookupRecord(ctx, log.Logger, resolver, d)
	if err == nil && record.ID == pr.RecordID {
		// Record unchanged: just extend the validity of the policy we have.
		qup := bstore.QueryDB[PolicyRecord](ctx, db)
		qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
		now := timeNow()
		update := PolicyRecord{
			LastUpdate: now,
			ValidEnd:   now.Add(time.Duration(pr.MaxAgeSeconds) * time.Second),
		}
		if n, err := qup.UpdateNonzero(update); err != nil {
			log.Errorx("updating refreshed, unmodified policy in database", err)
		} else if n != 1 {
			log.Info("expected to update 1 policy after refresh", slog.Int("count", n))
		}
		return
	}
	if err != nil && pr.Mode == mtasts.ModeNone {
		if errors.Is(err, mtasts.ErrNoRecord) {
			// Policy was in mode "none". Now it doesn't have a policy anymore. Remove from our
			// database so we don't keep refreshing it.
			derr := db.Delete(ctx, &pr)
			log.Check(derr, "removing mta-sts policy with mode none, dns record is gone")
		}
		// Else, don't bother operator with temporary error about policy none.
		// ../rfc/8461:587
		return
	} else if err != nil {
		log.Errorx("looking up mta-sts record for domain", err, slog.Any("domain", d))
		// Try to fetch new policy. It could be just DNS that is down. We don't want to let our policy expire.
	}

	p, _, err := mtasts.FetchPolicy(ctx, log.Logger, d)
	if err != nil {
		if !errors.Is(err, mtasts.ErrNoPolicy) || pr.Mode != mtasts.ModeNone {
			log.Errorx("refreshing mtasts policy for domain", err, slog.Any("domain", d))
		}
		return
	}
	now := timeNow()
	update := map[string]any{
		"LastUpdate": now,
		"ValidEnd":   now.Add(time.Duration(p.MaxAgeSeconds) * time.Second),
		"Backoff":    false,
		"Policy":     *p,
	}
	// record is nil when the DNS lookup failed; keep the stored RecordID then.
	if record != nil {
		update["RecordID"] = record.ID
	}

	qup := bstore.QueryDB[PolicyRecord](ctx, db)
	qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
	if n, err := qup.UpdateFields(update); err != nil {
		log.Errorx("updating refreshed, modified policy in database", err)
	} else if n != 1 {
		log.Info("updating refreshed, did not update 1 policy", slog.Int("count", n))
	}
}