mirror of
https://github.com/mjl-/mox.git
synced 2025-01-15 18:06:27 +03:00
202 lines
5.4 KiB
Go
202 lines
5.4 KiB
Go
|
package junk
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"math"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/mjl-/mox/mlog"
|
||
|
)
|
||
|
|
||
|
// tcheck aborts the running test when err is non-nil, reporting msg followed
// by the error text. It is marked as a test helper so the failure is
// attributed to the caller's line.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
|
||
|
|
||
|
func tlistdir(t *testing.T, name string) []string {
|
||
|
t.Helper()
|
||
|
l, err := os.ReadDir(name)
|
||
|
tcheck(t, err, "readdir")
|
||
|
names := make([]string, len(l))
|
||
|
for i, e := range l {
|
||
|
names[i] = e.Name()
|
||
|
}
|
||
|
return names
|
||
|
}
|
||
|
|
||
|
func TestFilter(t *testing.T) {
|
||
|
log := mlog.New("junk")
|
||
|
params := Params{
|
||
|
Onegrams: true,
|
||
|
Twograms: true,
|
||
|
Threegrams: false,
|
||
|
MaxPower: 0.1,
|
||
|
TopWords: 10,
|
||
|
IgnoreWords: 0.1,
|
||
|
RareWords: 1,
|
||
|
}
|
||
|
dbPath := "../testdata/junk/filter.db"
|
||
|
bloomPath := "../testdata/junk/filter.bloom"
|
||
|
os.Remove(dbPath)
|
||
|
os.Remove(bloomPath)
|
||
|
f, err := NewFilter(log, params, dbPath, bloomPath)
|
||
|
tcheck(t, err, "new filter")
|
||
|
err = f.Close()
|
||
|
tcheck(t, err, "close filter")
|
||
|
|
||
|
f, err = OpenFilter(log, params, dbPath, bloomPath, true)
|
||
|
tcheck(t, err, "open filter")
|
||
|
|
||
|
// Ensure these dirs exist. Developers should bring their own ham/spam example
|
||
|
// emails.
|
||
|
os.MkdirAll("../testdata/train/ham", 0770)
|
||
|
os.MkdirAll("../testdata/train/spam", 0770)
|
||
|
|
||
|
hamdir := "../testdata/train/ham"
|
||
|
spamdir := "../testdata/train/spam"
|
||
|
hamfiles := tlistdir(t, hamdir)
|
||
|
if len(hamfiles) > 100 {
|
||
|
hamfiles = hamfiles[:100]
|
||
|
}
|
||
|
spamfiles := tlistdir(t, spamdir)
|
||
|
if len(spamfiles) > 100 {
|
||
|
spamfiles = spamfiles[:100]
|
||
|
}
|
||
|
|
||
|
err = f.TrainDirs(hamdir, "", spamdir, hamfiles, nil, spamfiles)
|
||
|
tcheck(t, err, "train dirs")
|
||
|
|
||
|
if len(hamfiles) == 0 || len(spamfiles) == 0 {
|
||
|
fmt.Println("not training, no ham and/or spam messages, add them to testdata/train/ham and testdata/train/spam")
|
||
|
return
|
||
|
}
|
||
|
|
||
|
prob, _, _, _, err := f.ClassifyMessagePath(filepath.Join(hamdir, hamfiles[0]))
|
||
|
tcheck(t, err, "classify ham message")
|
||
|
if prob > 0.1 {
|
||
|
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
|
||
|
}
|
||
|
|
||
|
prob, _, _, _, err = f.ClassifyMessagePath(filepath.Join(spamdir, spamfiles[0]))
|
||
|
tcheck(t, err, "classify spam message")
|
||
|
if prob < 0.9 {
|
||
|
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
|
||
|
}
|
||
|
|
||
|
err = f.Close()
|
||
|
tcheck(t, err, "close filter")
|
||
|
|
||
|
// Start again with empty filter. We'll train a few messages and check they are
|
||
|
// classified as ham/spam. Then we untrain to see they are no longer classified.
|
||
|
os.Remove(dbPath)
|
||
|
os.Remove(bloomPath)
|
||
|
f, err = NewFilter(log, params, dbPath, bloomPath)
|
||
|
tcheck(t, err, "open filter")
|
||
|
|
||
|
hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
|
||
|
tcheck(t, err, "open hamfile")
|
||
|
defer hamf.Close()
|
||
|
hamstat, err := hamf.Stat()
|
||
|
tcheck(t, err, "stat hamfile")
|
||
|
hamsize := hamstat.Size()
|
||
|
|
||
|
spamf, err := os.Open(filepath.Join(spamdir, spamfiles[0]))
|
||
|
tcheck(t, err, "open spamfile")
|
||
|
defer spamf.Close()
|
||
|
spamstat, err := spamf.Stat()
|
||
|
tcheck(t, err, "stat spamfile")
|
||
|
spamsize := spamstat.Size()
|
||
|
|
||
|
// Train each message twice, to prevent single occurrences from being ignored.
|
||
|
err = f.TrainMessage(hamf, hamsize, true)
|
||
|
tcheck(t, err, "train ham message")
|
||
|
_, err = hamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek ham message")
|
||
|
err = f.TrainMessage(hamf, hamsize, true)
|
||
|
tcheck(t, err, "train ham message")
|
||
|
|
||
|
err = f.TrainMessage(spamf, spamsize, false)
|
||
|
tcheck(t, err, "train spam message")
|
||
|
_, err = spamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek spam message")
|
||
|
err = f.TrainMessage(spamf, spamsize, true)
|
||
|
tcheck(t, err, "train spam message")
|
||
|
|
||
|
if !f.modified {
|
||
|
t.Fatalf("filter not modified after training")
|
||
|
}
|
||
|
if !f.bloom.Modified() {
|
||
|
t.Fatalf("bloom filter not modified after training")
|
||
|
}
|
||
|
|
||
|
err = f.Save()
|
||
|
tcheck(t, err, "save filter")
|
||
|
if f.modified || f.bloom.Modified() {
|
||
|
t.Fatalf("filter or bloom filter still modified after save")
|
||
|
}
|
||
|
|
||
|
// Classify and verify.
|
||
|
_, err = hamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek ham message")
|
||
|
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
|
||
|
tcheck(t, err, "classify ham")
|
||
|
if prob > 0.1 {
|
||
|
t.Fatalf("got prob %v, expected <= 0.1", prob)
|
||
|
}
|
||
|
|
||
|
_, err = spamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek spam message")
|
||
|
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
|
||
|
tcheck(t, err, "classify spam")
|
||
|
if prob < 0.9 {
|
||
|
t.Fatalf("got prob %v, expected >= 0.9", prob)
|
||
|
}
|
||
|
|
||
|
// Untrain ham & spam.
|
||
|
_, err = hamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek ham message")
|
||
|
err = f.UntrainMessage(hamf, hamsize, true)
|
||
|
tcheck(t, err, "untrain ham message")
|
||
|
_, err = hamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek ham message")
|
||
|
err = f.UntrainMessage(hamf, spamsize, true)
|
||
|
tcheck(t, err, "untrain ham message")
|
||
|
|
||
|
_, err = spamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek spam message")
|
||
|
err = f.UntrainMessage(spamf, spamsize, true)
|
||
|
tcheck(t, err, "untrain spam message")
|
||
|
_, err = spamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek spam message")
|
||
|
err = f.UntrainMessage(spamf, spamsize, true)
|
||
|
tcheck(t, err, "untrain spam message")
|
||
|
|
||
|
if !f.modified {
|
||
|
t.Fatalf("filter not modified after untraining")
|
||
|
}
|
||
|
|
||
|
// Classify again, should be unknown.
|
||
|
_, err = hamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek ham message")
|
||
|
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
|
||
|
tcheck(t, err, "classify ham")
|
||
|
if math.Abs(prob-0.5) > 0.1 {
|
||
|
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
||
|
}
|
||
|
|
||
|
_, err = spamf.Seek(0, 0)
|
||
|
tcheck(t, err, "seek spam message")
|
||
|
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
|
||
|
tcheck(t, err, "classify spam")
|
||
|
if math.Abs(prob-0.5) > 0.1 {
|
||
|
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
||
|
}
|
||
|
|
||
|
err = f.Close()
|
||
|
tcheck(t, err, "close filter")
|
||
|
}
|