mox/mtastsdb/db_test.go

159 lines
4.9 KiB
Go
Raw Normal View History

2023-01-30 16:27:06 +03:00
package mtastsdb
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtasts"
)
// tcheckf fails the test immediately when err is non-nil. The failure message
// is the description formatted from format and args, followed by the error.
// It is marked as a test helper so failures point at the calling line.
func tcheckf(t *testing.T, err error, format string, args ...any) {
	t.Helper()
	if err != nil {
		t.Fatalf("%s: %s", fmt.Sprintf(format, args...), err)
	}
}
// TestDB exercises the MTA-STS policy database end to end. It covers:
//   - initializing a fresh database file;
//   - inserting and replacing a policy with Upsert and reading it back with
//     lookup and PolicyRecords;
//   - storing a nil policy, which puts a domain into backoff;
//   - Get with a mock DNS resolver, including the case where refetching the
//     policy over HTTP fails.
func TestDB(t *testing.T) {
	mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
	mox.Conf.Static.DataDir = "."

	// Start from an empty database file and remove it when done.
	dbpath := mox.DataDirPath("mtasts.db")
	tcheckf(t, os.MkdirAll(filepath.Dir(dbpath), 0770), "creating data directory")
	os.Remove(dbpath) // Error ignored: the file usually does not exist yet.
	defer os.Remove(dbpath)
	if err := Init(false); err != nil {
		t.Fatalf("init database: %s", err)
	}
	defer Close()

	ctx := context.Background()

	// Mock time so stored timestamps are predictable. Round(0) strips the
	// monotonic clock reading, so now compares equal (reflect.DeepEqual) with
	// timestamps read back from the database.
	now := time.Now().Round(0)
	timeNow = func() time.Time { return now }
	defer func() { timeNow = time.Now }()

	if p, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); !errors.Is(err, ErrNotFound) {
		t.Fatalf("expected not found, got %v, %#v", err, p)
	}

	policy1 := mtasts.Policy{
		Version: "STSv1",
		Mode:    mtasts.ModeTesting,
		MX: []mtasts.STSMX{
			{Domain: dns.Domain{ASCII: "mx1.example.com"}},
			{Domain: dns.Domain{ASCII: "mx2.example.com"}},
			{Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
		},
		MaxAgeSeconds: 1296000,
	}
	if err := Upsert(dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
		t.Fatalf("upsert record: %s", err)
	}
	if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
		t.Fatalf("lookup after insert: %s", err)
	} else if !reflect.DeepEqual(got.Policy, policy1) {
		t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
	}

	// Upserting the same domain again replaces the stored policy.
	policy2 := mtasts.Policy{
		Version: "STSv1",
		Mode:    mtasts.ModeEnforce,
		MX: []mtasts.STSMX{
			{Domain: dns.Domain{ASCII: "mx1.example.com"}},
		},
		MaxAgeSeconds: 360000,
	}
	if err := Upsert(dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
		t.Fatalf("upsert record: %s", err)
	}
	if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
		t.Fatalf("lookup after insert: %s", err)
	} else if !reflect.DeepEqual(got.Policy, policy2) {
		t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
	}

	// Check if database holds expected record.
	records, err := PolicyRecords(ctx)
	tcheckf(t, err, "policyrecords")
	expRecords := []PolicyRecord{
		{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
	}
	// Guard the index below: an unexpected record count should fail the test,
	// not panic it.
	if len(records) != len(expRecords) {
		t.Fatalf("got %d records, expected %d: %#v", len(records), len(expRecords), records)
	}
	// Compare only the record metadata here. The policy itself was already
	// verified through lookup above.
	records[0].Policy = mtasts.Policy{}
	expRecords[0].Policy = mtasts.Policy{}
	if !reflect.DeepEqual(records, expRecords) {
		t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
	}

	// Upserting a nil policy is expected to store a placeholder record with mode
	// none and a 5 minute max age. After that, lookup reports ErrBackoff for
	// the domain.
	if err := Upsert(dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
		t.Fatalf("upsert record: %s", err)
	}
	records, err = PolicyRecords(ctx)
	tcheckf(t, err, "policyrecords")
	expRecords = []PolicyRecord{
		{"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}},
		{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
	}
	if !reflect.DeepEqual(records, expRecords) {
		t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
	}

	if _, err := lookup(ctx, dns.Domain{ASCII: "other.example.com"}); !errors.Is(err, ErrBackoff) {
		t.Fatalf("got %#v, expected ErrBackoff", err)
	}

	resolver := dns.MockResolver{
		TXT: map[string][]string{
			"_mta-sts.example.com.":           {"v=STSv1; id=124"},
			"_mta-sts.other.example.com.":     {"v=STSv1; id=1"},
			"_mta-sts.temperror.example.com.": {""},
		},
		Fail: map[dns.Mockreq]struct{}{
			{Type: "txt", Name: "_mta-sts.temperror.example.com."}: {},
		},
	}

	// testGet calls Get for domain and checks the returned policy, freshness
	// flag and error (matched with errors.Is) against the expectations.
	testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
		t.Helper()
		p, fresh, err := Get(ctx, resolver, dns.Domain{ASCII: domain})
		if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
			t.Fatalf("got err %v, expected %v", err, expErr)
		}
		if !reflect.DeepEqual(p, expPolicy) || fresh != expFresh {
			t.Fatalf("got policy %#v, fresh %v, expected %#v, %v", p, fresh, expPolicy, expFresh)
		}
	}

	testGet("example.com", &policy2, true, nil)
	testGet("other.example.com", nil, false, nil) // Back off, already in database.
	testGet("absent.example.com", nil, true, nil) // No MTA-STS.
	testGet("temperror.example.com", nil, false, mtasts.ErrDNS)

	// Force refetch of policy, that will fail. DNS announces a new policy id
	// while every HTTP dial errors, so Get is expected to return the cached
	// policy marked as not fresh. The original transport is saved and restored
	// so the mock does not leak into other tests.
	origTransport := mtasts.HTTPClient.Transport
	mtasts.HTTPClient.Transport = &http.Transport{
		DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
			return nil, errors.New("bad")
		},
	}
	defer func() {
		mtasts.HTTPClient.Transport = origTransport
	}()

	resolver.TXT["_mta-sts.example.com."] = []string{"v=STSv1; id=125"}
	testGet("example.com", &policy2, false, nil)

	// Cached policy but no longer a DNS record.
	delete(resolver.TXT, "_mta-sts.example.com.")
	testGet("example.com", &policy2, false, nil)
}