1
0
Fork 0
mirror of https://github.com/mjl-/mox.git synced 2025-01-18 19:35:37 +03:00
mox/mtastsdb/refresh_test.go

234 lines
6.2 KiB
Go

package mtastsdb
import (
"context"
"crypto/ed25519"
cryptorand "crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/mjl-/bstore"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtasts"
)
// ctxbg is the background context (no deadline, never canceled) that the tests
// in this file pass to database and refresh calls.
var ctxbg = context.Background()
// TestRefresh seeds a fresh policy database with records in different states,
// serves MTA-STS policy files from an in-process HTTPS server that is reached
// through a pipeListener instead of the network, and then checks that:
//
//   - refresh1 returns n == 3 without error and calls the sleep function twice,
//   - refresh returns 0 once the record with the failing policy fetch has been
//     deleted and mox.Shutdown has been canceled.
//
// The test mutates package-level state in mox (Shutdown, ShutdownCancel,
// Context, ConfigStaticPath, Conf) and mtasts.HTTPClient.Transport.
func TestRefresh(t *testing.T) {
	mox.Shutdown = ctxbg
	mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
	mox.Conf.Static.DataDir = "."

	dbpath := mox.DataDirPath("mtasts.db")
	if err := os.MkdirAll(filepath.Dir(dbpath), 0770); err != nil {
		t.Fatalf("creating directory for database: %v", err)
	}
	// Start from a clean database file and leave none behind. The first removal
	// may fail because the file does not exist, so its result is ignored.
	os.Remove(dbpath)
	defer os.Remove(dbpath)

	if err := Init(false); err != nil {
		t.Fatalf("init database: %s", err)
	}
	defer Close()

	db, err := database(ctxbg)
	if err != nil {
		t.Fatalf("database: %s", err)
	}

	cert := fakeCert(t, false)
	// Restore the default transport of the shared HTTP client when done.
	defer func() {
		mtasts.HTTPClient.Transport = nil
	}()

	// insert stores a policy record with a single non-wildcard MX for domain,
	// failing the test on any error.
	insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
		t.Helper()

		mxd, err := dns.ParseDomain(mx)
		if err != nil {
			t.Fatalf("parsing mx domain %q: %s", mx, err)
		}
		policy := mtasts.Policy{
			Version:       "STSv1",
			Mode:          mode,
			MX:            []mtasts.STSMX{{Wildcard: false, Domain: mxd}},
			MaxAgeSeconds: maxAge,
			Extensions:    nil,
		}

		pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy}
		if err := db.Insert(ctxbg, &pr); err != nil {
			t.Fatalf("insert policy: %s", err)
		}
	}

	now := time.Now()
	// Updated just now.
	insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be removed.
	insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be refreshed, same id.
	insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be refreshed and succeed.
	insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
	// To be refreshed and fail to fetch.
	insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")

	// TXT records as returned by the mock resolver: same id for
	// refresh.mox.example, a new id for the two policy* domains.
	resolver := dns.MockResolver{
		TXT: map[string][]string{
			"_mta-sts.refresh.mox.example.":   {"v=STSv1; id=1"},
			"_mta-sts.policyok.mox.example.":  {"v=STSv1; id=2"},
			"_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
		},
	}

	// The client trusts exactly the self-signed test certificate.
	pool := x509.NewCertPool()
	pool.AddCert(cert.Leaf)

	l := newPipeListener()
	defer l.Close()
	// HTTPS server handing out policies. It answers 500 for the policybad host
	// and a fixed enforce policy for every other host. ServeTLS returns when
	// the listener is closed.
	go func() {
		mux := &http.ServeMux{}
		mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
			if r.Host == "mta-sts.policybad.mox.example" {
				w.WriteHeader(500)
				return
			}
			fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
		})
		s := &http.Server{
			Handler: mux,
			TLSConfig: &tls.Config{
				Certificates: []tls.Certificate{cert},
			},
			ErrorLog: log.New(io.Discard, "", 0),
		}
		s.ServeTLS(l, "", "")
	}()

	// Route every outgoing connection of the policy fetcher to the in-process
	// server. DialContext replaces the deprecated Transport.Dial field.
	mtasts.HTTPClient.Transport = &http.Transport{
		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
			return l.Dial()
		},
		TLSClientConfig: &tls.Config{
			RootCAs: pool,
		},
	}

	// sleep replaces real sleeping. The k-th call must request a duration within
	// half an interval of k*interval, with an interval of 1.5 hours.
	slept := 0
	sleep := func(d time.Duration) {
		slept++
		interval := 3 * time.Hour / 2
		if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
			t.Fatalf("bad sleep duration %v", d)
		}
	}
	// %v rather than %s: err is nil when only n is wrong.
	if n, err := refresh1(ctxbg, resolver, sleep); err != nil || n != 3 {
		t.Fatalf("refresh1: err %v, n %d, expected no error, 3", err, n)
	}
	if slept != 2 {
		t.Fatalf("bad sleeps, %d instead of 2", slept)
	}
	time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.

	// Should not do any more refreshes and return immediately.
	q := bstore.QueryDB[PolicyRecord](ctxbg, db)
	q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
	if _, err := q.Delete(); err != nil {
		t.Fatalf("delete record that would be refreshed: %v", err)
	}
	mox.Context = ctxbg
	mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
	mox.ShutdownCancel()
	n := refresh()
	if n != 0 {
		t.Fatalf("refresh found unexpected work, n %d", n)
	}
	// Install a fresh, uncanceled shutdown context again.
	mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
}
// pipeListener is an in-memory net.Listener for tests. Dial creates a net.Pipe
// and hands its server end to a pending Accept, so a client and a server can
// talk without opening a network socket.
type pipeListener struct {
	sync.Mutex               // Guards closed.
	closed     bool          // Set by the first call to Close.
	C          chan net.Conn // Unbuffered; carries server ends of pipes from Dial to Accept.
	done       chan struct{} // Closed by Close to release blocked Dial and Accept calls.
}

var _ net.Listener = (*pipeListener)(nil)

// newPipeListener returns a listener that is ready for Dial and Accept.
func newPipeListener() *pipeListener {
	return &pipeListener{
		C:    make(chan net.Conn),
		done: make(chan struct{}),
	}
}

// Dial returns the client end of a new pipe once an Accept call has taken the
// server end. It returns an error if the listener is closed, either before the
// call or while waiting for Accept.
//
// The mutex is not held while waiting, so that Close can always proceed and
// wake up a Dial for which no Accept will ever come.
func (l *pipeListener) Dial() (net.Conn, error) {
	l.Lock()
	closed := l.closed
	l.Unlock()
	if closed {
		return nil, errors.New("closed")
	}

	c, s := net.Pipe()
	select {
	case l.C <- s:
		return c, nil
	case <-l.done:
		// Nobody will use this pipe, release both ends.
		c.Close()
		s.Close()
		return nil, errors.New("closed")
	}
}

// Accept waits for a Dial and returns the server end of its pipe. It returns
// io.EOF once the listener is closed.
func (l *pipeListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.C:
		if conn == nil {
			return nil, io.EOF
		}
		return conn, nil
	case <-l.done:
		return nil, io.EOF
	}
}

// Close marks the listener as closed and releases blocked Dial and Accept
// calls. Calling it more than once is fine. It always returns nil.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.done)
	}
	return nil
}

// Addr returns a placeholder address, a pipe has no network address.
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }

// pipeAddr is the net.Addr reported by pipeListener.
type pipeAddr struct{}

// Network returns "pipe".
func (a pipeAddr) Network() string { return "pipe" }

// String returns "pipe".
func (a pipeAddr) String() string { return "pipe" }
// fakeCert returns a self-signed ed25519 certificate for the DNS names
// mta-sts.policybad.mox.example and mta-sts.policyok.mox.example, with Leaf set
// to the parsed certificate. The private key is derived from an all-zero seed
// and must only be used in tests.
//
// If expired is false, the certificate is valid from one hour ago until one
// hour from now. If expired is true, it was valid from two hours ago until one
// hour ago, so NotBefore is always before NotAfter.
//
// Any failure aborts the test through t.Fatalf.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	t.Helper()

	// Single time snapshot so both ends of the validity window are consistent.
	now := time.Now()
	notBefore := now.Add(-time.Hour)
	notAfter := now.Add(time.Hour)
	if expired {
		notBefore = now.Add(-2 * time.Hour)
		notAfter = now.Add(-time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
		NotBefore:    notBefore,
		NotAfter:     notAfter,
	}
	// The template is its own parent, making the certificate self-signed.
	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	cert, err := x509.ParseCertificate(localCertBuf)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{localCertBuf},
		PrivateKey:  privKey,
		Leaf:        cert,
	}
}