package mtastsdb

import (
	"context"
	"crypto/ed25519"
	cryptorand "crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	golog "log"
	"math/big"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"github.com/mjl-/bstore"
	"github.com/mjl-/mox/dns"
	"github.com/mjl-/mox/mlog"
	"github.com/mjl-/mox/mox-"
	"github.com/mjl-/mox/mtasts"
)

// ctxbg is the background context shared by the tests in this file.
var ctxbg = context.Background()

// TestRefresh fills the policy database with records in various states, runs
// refresh1 against a mock DNS resolver and an in-memory HTTPS server, and
// checks the number of refreshed records and the sleeps requested in between.
// It then verifies that refresh finds no work once shutdown has been
// requested.
func TestRefresh(t *testing.T) {
	mox.Shutdown = ctxbg
	mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
	mox.Conf.Static.DataDir = "."

	dbpath := mox.DataDirPath("mtasts.db")
	if err := os.MkdirAll(filepath.Dir(dbpath), 0770); err != nil {
		t.Fatalf("creating directory for database: %v", err)
	}
	// Best-effort removal of a database from an earlier run, it may not exist.
	os.Remove(dbpath)
	defer os.Remove(dbpath)

	log := mlog.New("mtastsdb", nil)

	err := Init(false)
	tcheckf(t, err, "init database")
	defer func() {
		err := Close()
		tcheckf(t, err, "close database")
	}()

	cert := fakeCert(t, false)
	// Restore the default transport after the test replaces it below.
	defer func() {
		mtasts.HTTPClient.Transport = nil
	}()

	// insert stores a policy record for domain with a single non-wildcard MX
	// host, failing the test on any error.
	insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
		t.Helper()

		mxd, err := dns.ParseDomain(mx)
		if err != nil {
			t.Fatalf("parsing mx domain %q: %s", mx, err)
		}
		policy := mtasts.Policy{
			Version:       "STSv1",
			Mode:          mode,
			MX:            []mtasts.MX{{Wildcard: false, Domain: mxd}},
			MaxAgeSeconds: maxAge,
			Extensions:    nil,
		}

		pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy, policy.String()}
		if err := DB.Insert(ctxbg, &pr); err != nil {
			t.Fatalf("insert policy: %s", err)
		}
	}

	now := time.Now()
	// Updated just now.
	insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be removed.
	insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be refreshed, same id.
	insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be refreshed and succeed.
	insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be refreshed and fail to fetch.
	insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")

	resolver := dns.MockResolver{
		TXT: map[string][]string{
			"_mta-sts.refresh.mox.example.":   {"v=STSv1; id=1"},
			"_mta-sts.policyok.mox.example.":  {"v=STSv1; id=2"},
			"_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
		},
	}

	pool := x509.NewCertPool()
	pool.AddCert(cert.Leaf)

	// Serve policies over TLS on an in-memory listener. The server stops when
	// the listener is closed at the end of the test.
	l := newPipeListener()
	defer l.Close()
	go func() {
		mux := &http.ServeMux{}
		mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
			// Simulate a failing policy fetch for this one host.
			if r.Host == "mta-sts.policybad.mox.example" {
				w.WriteHeader(500)
				return
			}
			fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
		})
		s := &http.Server{
			Handler: mux,
			TLSConfig: &tls.Config{
				Certificates: []tls.Certificate{cert},
			},
			ErrorLog: golog.New(io.Discard, "", 0),
		}
		s.ServeTLS(l, "", "")
	}()

	// Route all policy fetches to the in-memory server, trusting only the fake
	// certificate.
	mtasts.HTTPClient.Transport = &http.Transport{
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			return l.Dial()
		},
		TLSClientConfig: &tls.Config{
			RootCAs: pool,
		},
	}

	// sleep replaces real sleeping. It counts calls and checks that the n'th
	// requested duration is within half an interval of n intervals of 1.5 hours.
	slept := 0
	sleep := func(d time.Duration) {
		slept++
		interval := 3 * time.Hour / 2
		if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
			t.Fatalf("bad sleep duration %v", d)
		}
	}
	if n, err := refresh1(ctxbg, log, resolver, sleep); err != nil || n != 3 {
		// Use %v: err can be nil here when only the count is wrong.
		t.Fatalf("refresh1: err %v, n %d, expected no error, 3", err, n)
	}
	if slept != 2 {
		t.Fatalf("bad sleeps, %d instead of 2", slept)
	}
	time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.

	// Should not do any more refreshes and return immediately.
	q := bstore.QueryDB[PolicyRecord](ctxbg, DB)
	q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
	if _, err := q.Delete(); err != nil {
		t.Fatalf("delete record that would be refreshed: %v", err)
	}
	mox.Context = ctxbg
	mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
	mox.ShutdownCancel()
	n := refresh()
	if n != 0 {
		t.Fatalf("refresh found unexpected work, n %d", n)
	}
	// Leave a fresh, uncanceled shutdown context behind.
	mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
}

// pipeListener is an in-memory net.Listener. Dial creates a net.Pipe and hands
// the server end to Accept through channel C.
type pipeListener struct {
	sync.Mutex
	closed bool
	C      chan net.Conn
}

var _ net.Listener = &pipeListener{}

// newPipeListener returns a pipeListener with an unbuffered connection channel.
func newPipeListener() *pipeListener {
	return &pipeListener{C: make(chan net.Conn)}
}

// Dial returns the client end of a new pipe after passing the server end to
// Accept, or an error if the listener has been closed. The mutex is held while
// sending, so Close cannot close the channel during the send.
func (l *pipeListener) Dial() (net.Conn, error) {
	l.Lock()
	defer l.Unlock()
	if l.closed {
		return nil, errors.New("closed")
	}
	c, s := net.Pipe()
	l.C <- s
	return c, nil
}

// Accept waits for the next connection created by Dial. It returns io.EOF once
// the listener has been closed.
func (l *pipeListener) Accept() (net.Conn, error) {
	// Receiving from the closed channel yields a nil connection.
	conn := <-l.C
	if conn == nil {
		return nil, io.EOF
	}
	return conn, nil
}

// Close marks the listener closed and closes the connection channel. Calling
// it again is a no-op. It always returns nil.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.C)
	}
	return nil
}

// Addr returns a placeholder address.
func (l *pipeListener) Addr() net.Addr {
	return pipeAddr{}
}

// pipeAddr is the net.Addr of a pipeListener.
type pipeAddr struct{}

// Network returns "pipe".
func (a pipeAddr) Network() string { return "pipe" }

// String returns "pipe".
func (a pipeAddr) String() string { return "pipe" }

// fakeCert returns a self-signed certificate for the MTA-STS policy hosts used
// in TestRefresh, with its parsed form in Leaf. The certificate is valid from
// an hour ago until an hour from now, or until an hour ago if expired is set.
// The test fails if the certificate cannot be created or parsed.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	t.Helper()

	notAfter := time.Now()
	if expired {
		notAfter = notAfter.Add(-time.Hour)
	} else {
		notAfter = notAfter.Add(time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     notAfter,
	}
	// Template and parent are the same, making the certificate self-signed.
	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	cert, err := x509.ParseCertificate(localCertBuf)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	c := tls.Certificate{
		Certificate: [][]byte{localCertBuf},
		PrivateKey:  privKey,
		Leaf:        cert,
	}
	return c
}