diff --git a/ctl.go b/ctl.go index 0efeb2a..81f73e4 100644 --- a/ctl.go +++ b/ctl.go @@ -1348,6 +1348,8 @@ func servectlcmd(ctx context.Context, ctl *ctl, shutdown func()) { } }() + // todo: can we retrain an account without holding a write lock? perhaps by writing a junkfilter to a new location, and staying informed of message changes while we go through all messages in the account? + acc.WithWLock(func() { conf, _ := acc.Conf() if conf.JunkFilter == nil { diff --git a/junk/filter.go b/junk/filter.go index 387ea49..d37e53c 100644 --- a/junk/filter.go +++ b/junk/filter.go @@ -298,7 +298,7 @@ func (f *Filter) Save() error { } else if err != nil { return err } - return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam}) + return tx.Update(&wordscore{w, ham, spam}) } if err := update("-", f.hams, f.spams); err != nil { return fmt.Errorf("storing total ham/spam message count: %s", err) @@ -621,10 +621,16 @@ func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{ // Modify the message count. f.modified = true + var fv *uint32 if ham { - f.hams-- + fv = &f.hams } else { - f.spams-- + fv = &f.spams + } + if *fv == 0 { + f.log.Error("attempt to decrease ham/spam message count while already zero", slog.Bool("ham", ham)) + } else { + *fv -= 1 } // Decrease the word counts. @@ -633,10 +639,16 @@ func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{ if !ok { continue } + var v *uint32 if ham { - c.Ham-- + v = &c.Ham } else { - c.Spam-- + v = &c.Spam + } + if *v == 0 { + f.log.Error("attempt to decrease ham/spam word count while already zero", slog.String("word", w), slog.Bool("ham", ham)) + } else { + *v -= 1 } f.cache[w] = c f.changed[w] = c diff --git a/junk/filter_test.go b/junk/filter_test.go index bb67c0f..6aea0ef 100644 --- a/junk/filter_test.go +++ b/junk/filter_test.go @@ -126,7 +126,7 @@ func TestFilter(t *testing.T) { tcheck(t, err, "train spam message") _, err = spamf.Seek(0, 0) tcheck(t, err, "seek spam message") - err = f.TrainMessage(ctxbg, spamf, spamsize, true) + err = f.TrainMessage(ctxbg, spamf, spamsize, false) tcheck(t, err, "train spam message") if !f.modified { @@ -166,16 +166,16 @@ func TestFilter(t *testing.T) { tcheck(t, err, "untrain ham message") _, err = hamf.Seek(0, 0) tcheck(t, err, "seek ham message") - err = f.UntrainMessage(ctxbg, hamf, spamsize, true) + err = f.UntrainMessage(ctxbg, hamf, hamsize, true) tcheck(t, err, "untrain ham message") _, err = spamf.Seek(0, 0) tcheck(t, err, "seek spam message") - err = f.UntrainMessage(ctxbg, spamf, spamsize, true) + err = f.UntrainMessage(ctxbg, spamf, spamsize, false) tcheck(t, err, "untrain spam message") _, err = spamf.Seek(0, 0) tcheck(t, err, "seek spam message") - err = f.UntrainMessage(ctxbg, spamf, spamsize, true) + err = f.UntrainMessage(ctxbg, spamf, spamsize, false) tcheck(t, err, "untrain spam message") if !f.modified {