Mechiel Lukkien ec967ef321
use new sherpadoc rename mechanism to remove some typename stuttering
the stuttering was introduced to make type names that are declared in multiple
packages, and used in the admin sherpa api, unique. with sherpadoc's new
rename mechanism, we can make them unique when generating the api
definition/docs, and the Go code can use nicer names.
2024-04-19 10:51:24 +02:00

160 lines
5.1 KiB

package mtastsdb
import (
// tcheckf fails the test immediately when err is non-nil. The failure message
// is the formatted description (format and args, as with fmt.Sprintf),
// followed by ": " and the error. It is marked as a test helper so the failure
// is reported at the caller's line rather than inside tcheckf.
func tcheckf(t *testing.T, err error, format string, args ...any) {
	t.Helper()
	if err != nil {
		t.Fatalf("%s: %s", fmt.Sprintf(format, args...), err)
	}
}
// TestDB exercises the MTA-STS policy database in this package. It covers:
//   - lookup on a freshly initialized database,
//   - storing and replacing policies with Upsert,
//   - listing them with PolicyRecords,
//   - storing a nil policy and the backoff that follows,
//   - fetching policies through Get against a mock DNS resolver and a failing
//     HTTP transport.
//
// The steps share one database file, the mocked clock and the mock resolver.
// Each step depends on the state left by the previous ones, so they must run
// in this order.
func TestDB(t *testing.T) {
// Point mox at a test configuration and use the current directory as the data
// directory; the database file at dbpath is removed when the test ends.
mox.Shutdown = ctxbg
mox.ConfigStaticPath = filepath.FromSlash("../testdata/mtasts/fake.conf")
mox.Conf.Static.DataDir = "."
dbpath := mox.DataDirPath("mtasts.db")
// NOTE(review): the error from MkdirAll is ignored; a failure would only
// surface later, e.g. as an error from Init.
os.MkdirAll(filepath.Dir(dbpath), 0770)
defer os.Remove(dbpath)
log := mlog.New("mtastsdb", nil)
if err := Init(false); err != nil {
t.Fatalf("init database: %s", err)
defer Close()
// Mock time.
// Every timestamp the package takes via timeNow is the fixed value now.
// Round(0) strips the monotonic clock reading. This lets reflect.DeepEqual
// match now against times that carry no monotonic reading (presumably the
// ones read back from the database).
now := time.Now().Round(0)
timeNow = func() time.Time { return now }
defer func() { timeNow = time.Now }()
// On a freshly initialized database, lookup is expected to return ErrNotFound.
// NOTE(review): compared with != here, while testGet below uses errors.Is.
if p, err := lookup(ctxbg, log, dns.Domain{ASCII: ""}); err != ErrNotFound {
t.Fatalf("expected not found, got %v, %#v", err, p)
// Store a first policy (ID "123") and check that lookup returns an identical
// Policy.
policy1 := mtasts.Policy{
Version: "STSv1",
Mode: mtasts.ModeTesting,
MX: []mtasts.MX{
{Domain: dns.Domain{ASCII: ""}},
{Domain: dns.Domain{ASCII: ""}},
{Domain: dns.Domain{ASCII: ""}},
MaxAgeSeconds: 1296000,
if err := Upsert(ctxbg, dns.Domain{ASCII: ""}, "123", &policy1, policy1.String()); err != nil {
t.Fatalf("upsert record: %s", err)
if got, err := lookup(ctxbg, log, dns.Domain{ASCII: ""}); err != nil {
t.Fatalf("lookup after insert: %s", err)
} else if !reflect.DeepEqual(got.Policy, policy1) {
t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
// Upsert a second policy (ID "124") and check that lookup now returns it.
policy2 := mtasts.Policy{
Version: "STSv1",
Mode: mtasts.ModeEnforce,
MX: []mtasts.MX{
{Domain: dns.Domain{ASCII: ""}},
MaxAgeSeconds: 360000,
if err := Upsert(ctxbg, dns.Domain{ASCII: ""}, "124", &policy2, policy2.String()); err != nil {
t.Fatalf("upsert record: %s", err)
if got, err := lookup(ctxbg, log, dns.Domain{ASCII: ""}); err != nil {
t.Fatalf("lookup after insert: %s", err)
} else if !reflect.DeepEqual(got.Policy, policy2) {
t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
// Check if database holds expected record.
// Only the record for policy2 (ID "124") is expected, i.e. the second Upsert
// presumably replaced the first. Every expected timestamp is the mocked now,
// except one, which is now plus policy2.MaxAgeSeconds.
records, err := PolicyRecords(ctxbg)
tcheckf(t, err, "policyrecords")
expRecords := []PolicyRecord{
{"", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2, policy2.String()},
// Policy is cleared on both sides, so only the remaining fields are compared.
// The policy content itself was already checked through lookup above.
// NOTE(review): unclear why Policy is excluded from this comparison — confirm.
// NOTE(review): records[0] panics, rather than failing cleanly, if
// PolicyRecords returns no records.
records[0].Policy = mtasts.Policy{}
expRecords[0].Policy = mtasts.Policy{}
if !reflect.DeepEqual(records, expRecords) {
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
// Store a nil policy with an empty ID (presumably modelling a failed policy
// fetch). It is expected to appear as an mtasts.ModeNone placeholder with a
// 5-minute MaxAgeSeconds and its boolean field set to true. It is listed
// before the policy2 record; that order presumably comes from how
// PolicyRecords sorts its results.
if err := Upsert(ctxbg, dns.Domain{ASCII: ""}, "", nil, ""); err != nil {
t.Fatalf("upsert record: %s", err)
records, err = PolicyRecords(ctxbg)
tcheckf(t, err, "policyrecords")
policyNone := mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}
expRecords = []PolicyRecord{
{"", now, now.Add(5 * 60 * time.Second), now, now, true, "", policyNone, ""},
{"", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2, policy2.String()},
if !reflect.DeepEqual(records, expRecords) {
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
// With the placeholder stored, lookup is expected to return ErrBackoff.
if _, err := lookup(ctxbg, log, dns.Domain{ASCII: ""}); err != ErrBackoff {
t.Fatalf("got %#v, expected ErrBackoff", err)
// Mock DNS resolver for Get. TXT maps names to the TXT records returned for
// them; names listed in Fail presumably make the lookup return an error
// (see the mtasts.ErrDNS expectation below).
resolver := dns.MockResolver{
TXT: map[string][]string{
"": {"v=STSv1; id=124"},
"": {"v=STSv1; id=1"},
"": {""},
Fail: []string{
// testGet calls Get for domain and fails the test unless:
//   - the error matches expErr (both nil, or errors.Is(err, expErr)), and
//   - the returned policy and fresh flag equal expPolicy and expFresh.
testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
p, _, fresh, err := Get(ctxbg, log.Logger, resolver, dns.Domain{ASCII: domain})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("got err %v, expected %v", err, expErr)
if !reflect.DeepEqual(p, expPolicy) || fresh != expFresh {
t.Fatalf("got policy %#v, fresh %v, expected %#v, %v", p, fresh, expPolicy, expFresh)
// The first call presumably is served from the database: DNS advertises
// id=124, which matches the stored policy2 record.
testGet("", &policy2, true, nil)
testGet("", nil, false, nil) // Back off, already in database.
testGet("", nil, true, nil) // No MTA-STS.
testGet("", nil, false, mtasts.ErrDNS)
// Force refetch of policy, that will fail.
// The HTTP transport is replaced with one whose Dial always errors, and DNS
// is changed to advertise a new policy ID (125). Get is then expected to
// fall back to the cached policy2, with fresh=false.
// NOTE(review): http.Transport.Dial is deprecated in favor of DialContext.
mtasts.HTTPClient.Transport = &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return nil, fmt.Errorf("bad")
// Reset the transport to nil when the test ends.
defer func() {
mtasts.HTTPClient.Transport = nil
resolver.TXT[""] = []string{"v=STSv1; id=125"}
testGet("", &policy2, false, nil)
// Cached policy but no longer a DNS record.
// Get is still expected to return the cached policy2, with fresh=false.
delete(resolver.TXT, "")
testGet("", &policy2, false, nil)