do not lookup cname after looking up the txt for mta-sts, and follow cnames for mocks

because the txt would already follow cnames.
the additional cname lookup didn't hurt, it just didn't do anything.
i probably didn't realize that before looking deeper into dns.
This commit is contained in:
Mechiel Lukkien 2023-10-14 22:42:26 +02:00
parent 8ca198882e
commit 101c2703d2
No known key found for this signature in database
8 changed files with 43 additions and 76 deletions

View file

@ -184,20 +184,20 @@ func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, ad
func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) { func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) {
mr := mockReq{"ip", host} mr := mockReq{"ip", host}
_, result, err := r.result(ctx, mr) name, result, err := r.result(ctx, mr)
if err != nil { if err != nil {
return nil, result, err return nil, result, err
} }
var ips []net.IP var ips []net.IP
switch network { switch network {
case "ip", "ip4": case "ip", "ip4":
for _, ip := range r.A[host] { for _, ip := range r.A[name] {
ips = append(ips, net.ParseIP(ip)) ips = append(ips, net.ParseIP(ip))
} }
} }
switch network { switch network {
case "ip", "ip6": case "ip", "ip6":
for _, ip := range r.AAAA[host] { for _, ip := range r.AAAA[name] {
ips = append(ips, net.ParseIP(ip)) ips = append(ips, net.ParseIP(ip))
} }
} }
@ -209,7 +209,7 @@ func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net
func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) { func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) {
mr := mockReq{"mx", name} mr := mockReq{"mx", name}
_, result, err := r.result(ctx, mr) name, result, err := r.result(ctx, mr)
if err != nil { if err != nil {
return nil, result, err return nil, result, err
} }
@ -222,7 +222,7 @@ func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adn
func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) { func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) {
mr := mockReq{"txt", name} mr := mockReq{"txt", name}
_, result, err := r.result(ctx, mr) name, result, err := r.result(ctx, mr)
if err != nil { if err != nil {
return nil, result, err return nil, result, err
} }
@ -241,7 +241,7 @@ func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string,
name = fmt.Sprintf("_%d._%s.%s", port, protocol, host) name = fmt.Sprintf("_%d._%s.%s", port, protocol, host)
} }
mr := mockReq{"tlsa", name} mr := mockReq{"tlsa", name}
_, result, err := r.result(ctx, mr) name, result, err := r.result(ctx, mr)
if err != nil { if err != nil {
return nil, result, err return nil, result, err
} }

View file

@ -162,45 +162,26 @@ var (
) )
// LookupRecord looks up the MTA-STS TXT DNS record at "_mta-sts.<domain>", // LookupRecord looks up the MTA-STS TXT DNS record at "_mta-sts.<domain>",
// following CNAME records, and returns the parsed MTA-STS record, the DNS TXT // following CNAME records, and returns the parsed MTA-STS record and the DNS TXT
// record and any CNAMEs that were followed. // record.
func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rcnames []string, rerr error) { func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rerr error) {
log := xlog.WithContext(ctx) log := xlog.WithContext(ctx)
start := time.Now() start := time.Now()
defer func() { defer func() {
log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("cnames", rcnames), mlog.Field("duration", time.Since(start))) log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("duration", time.Since(start)))
}() }()
// ../rfc/8461:289 // ../rfc/8461:289
// ../rfc/8461:351 // ../rfc/8461:351
// We lookup the txt record, but must follow CNAME records when the TXT does not exist. // We lookup the txt record, but must follow CNAME records when the TXT does not
var cnames []string // exist. LookupTXT follows CNAMEs.
name := "_mta-sts." + domain.ASCII + "." name := "_mta-sts." + domain.ASCII + "."
var txts []string var txts []string
for { txts, _, err := dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name)
var err error
txts, _, err = dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name)
if dns.IsNotFound(err) { if dns.IsNotFound(err) {
// DNS has no specified limit on how many CNAMEs to follow. Chains of 10 CNAMEs return nil, "", ErrNoRecord
// have been seen on the internet.
if len(cnames) > 16 {
return nil, "", cnames, fmt.Errorf("too many cnames")
}
cname, _, err := dns.WithPackage(resolver, "mtasts").LookupCNAME(ctx, name)
if dns.IsNotFound(err) {
return nil, "", cnames, ErrNoRecord
}
if err != nil {
return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err)
}
cnames = append(cnames, cname)
name = cname
continue
} else if err != nil { } else if err != nil {
return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err) return nil, "", fmt.Errorf("%w: %s", ErrDNS, err)
} else {
break
}
} }
var text string var text string
@ -215,18 +196,18 @@ func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain)
continue continue
} }
if err != nil { if err != nil {
return nil, "", cnames, err return nil, "", err
} }
if record != nil { if record != nil {
return nil, "", cnames, ErrMultipleRecords return nil, "", ErrMultipleRecords
} }
record = r record = r
text = txt text = txt
} }
if record == nil { if record == nil {
return nil, "", cnames, ErrNoRecord return nil, "", ErrNoRecord
} }
return record, text, cnames, nil return record, text, nil
} }
// Policy fetch errors. // Policy fetch errors.
@ -330,7 +311,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (record
log.Debugx("mtasts get result", err, mlog.Field("domain", domain), mlog.Field("record", record), mlog.Field("policy", policy), mlog.Field("duration", time.Since(start))) log.Debugx("mtasts get result", err, mlog.Field("domain", domain), mlog.Field("record", record), mlog.Field("policy", policy), mlog.Field("duration", time.Since(start)))
}() }()
record, _, _, err = LookupRecord(ctx, resolver, domain) record, _, err = LookupRecord(ctx, resolver, domain)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -21,9 +21,12 @@ import (
"github.com/mjl-/adns" "github.com/mjl-/adns"
"github.com/mjl-/mox/dns" "github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mlog"
) )
func TestLookup(t *testing.T) { func TestLookup(t *testing.T) {
mlog.SetConfig(map[string]mlog.Level{"": mlog.LevelDebug})
resolver := dns.MockResolver{ resolver := dns.MockResolver{
TXT: map[string][]string{ TXT: map[string][]string{
"_mta-sts.a.example.": {"v=STSv1; id=1"}, "_mta-sts.a.example.": {"v=STSv1; id=1"},
@ -37,39 +40,37 @@ func TestLookup(t *testing.T) {
CNAME: map[string]string{ CNAME: map[string]string{
"_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.", "_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
"_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.", "_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
"_mta-sts.followtemperror.example.": "_mta-sts.cnametemperror.example.", "_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.",
}, },
Fail: []string{ Fail: []string{
"txt _mta-sts.temperror.example.", "txt _mta-sts.temperror.example.",
"cname _mta-sts.cnametemperror.example.",
}, },
} }
test := func(host string, expRecord *Record, expCNAMEs []string, expErr error) { test := func(host string, expRecord *Record, expErr error) {
t.Helper() t.Helper()
record, _, cnames, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host}) record, _, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) { if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("lookup: got err %#v, expected %#v", err, expErr) t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
} }
if err != nil { if err != nil {
return return
} }
if !reflect.DeepEqual(record, expRecord) || !reflect.DeepEqual(cnames, expCNAMEs) { if !reflect.DeepEqual(record, expRecord) {
t.Fatalf("lookup: got record %#v, cnames %#v, expected %#v %#v", record, cnames, expRecord, expCNAMEs) t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord)
} }
} }
test("absent.example", nil, nil, ErrNoRecord) test("absent.example", nil, ErrNoRecord)
test("other.example", nil, nil, ErrNoRecord) test("other.example", nil, ErrNoRecord)
test("a.example", &Record{Version: "STSv1", ID: "1"}, nil, nil) test("a.example", &Record{Version: "STSv1", ID: "1"}, nil)
test("one.example", &Record{Version: "STSv1", ID: "1"}, nil, nil) test("one.example", &Record{Version: "STSv1", ID: "1"}, nil)
test("bad.example", nil, nil, ErrRecordSyntax) test("bad.example", nil, ErrRecordSyntax)
test("multiple.example", nil, nil, ErrMultipleRecords) test("multiple.example", nil, ErrMultipleRecords)
test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, []string{"_mta-sts.b.cnames.example.", "_mta-sts.c.cnames.example."}, nil) test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil)
test("temperror.example", nil, nil, ErrDNS) test("temperror.example", nil, ErrDNS)
test("cnametemperror.example", nil, nil, ErrDNS) test("followtemperror.example", nil, ErrDNS)
test("followtemperror.example", nil, nil, ErrDNS)
} }
func TestMatches(t *testing.T) { func TestMatches(t *testing.T) {

View file

@ -263,7 +263,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
policy = &cachedPolicy.Policy policy = &cachedPolicy.Policy
nctx, cancel := context.WithTimeout(ctx, 30*time.Second) nctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
record, _, _, err := mtasts.LookupRecord(nctx, resolver, domain) record, _, err := mtasts.LookupRecord(nctx, resolver, domain)
if err != nil { if err != nil {
if !errors.Is(err, mtasts.ErrNoRecord) { if !errors.Is(err, mtasts.ErrNoRecord) {
// Could be a temporary DNS or configuration error. // Could be a temporary DNS or configuration error.

View file

@ -125,7 +125,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr
return return
} }
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d)) log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d) record, _, err := mtasts.LookupRecord(ctx, resolver, d)
if err == nil && record.ID == pr.RecordID { if err == nil && record.ID == pr.RecordID {
qup := bstore.QueryDB[PolicyRecord](ctx, db) qup := bstore.QueryDB[PolicyRecord](ctx, db)
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate}) qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})

View file

@ -325,7 +325,6 @@ type MTASTSRecord struct {
mtasts.Record mtasts.Record
} }
type MTASTSCheckResult struct { type MTASTSCheckResult struct {
CNAMEs []string
TXT string TXT string
Record *MTASTSRecord Record *MTASTSRecord
PolicyText string PolicyText string
@ -1180,15 +1179,10 @@ Ensure a DNS TXT record like the following exists:
defer logPanic(ctx) defer logPanic(ctx)
defer wg.Done() defer wg.Done()
record, txt, cnames, err := mtasts.LookupRecord(ctx, resolver, domain) record, txt, err := mtasts.LookupRecord(ctx, resolver, domain)
if err != nil { if err != nil {
addf(&r.MTASTS.Errors, "Looking up MTA-STS record: %s", err) addf(&r.MTASTS.Errors, "Looking up MTA-STS record: %s", err)
} }
if cnames != nil {
r.MTASTS.CNAMEs = cnames
} else {
r.MTASTS.CNAMEs = []string{}
}
r.MTASTS.TXT = txt r.MTASTS.TXT = txt
if record != nil { if record != nil {
r.MTASTS.Record = &MTASTSRecord{*record} r.MTASTS.Record = &MTASTSRecord{*record}

View file

@ -951,8 +951,7 @@ const domainDNSCheck = async (d) => {
const detailsTLSRPT = !checks.TLSRPT.TXT ? [] : [ const detailsTLSRPT = !checks.TLSRPT.TXT ? [] : [
dom.div('TXT record: ' + checks.TLSRPT.TXT), dom.div('TXT record: ' + checks.TLSRPT.TXT),
] ]
const detailsMTASTS = empty(checks.MTASTS.CNAMEs) && !checks.MTASTS.TXT && !checks.MTASTS.PolicyText ? [] : [ const detailsMTASTS = !checks.MTASTS.TXT && !checks.MTASTS.PolicyText ? [] : [
dom.div('CNAMEs followed: ' + (checks.MTASTS.CNAMEs.join(', ') || '(none)')),
!checks.MTASTS.TXT ? [] : dom.div('MTA-STS record: ' + checks.MTASTS.TXT), !checks.MTASTS.TXT ? [] : dom.div('MTA-STS record: ' + checks.MTASTS.TXT),
!checks.MTASTS.PolicyText ? [] : dom.div('MTA-STS policy: ', dom('pre.literal', style({maxWidth: '60em'}), checks.MTASTS.PolicyText)), !checks.MTASTS.PolicyText ? [] : dom.div('MTA-STS policy: ', dom('pre.literal', style({maxWidth: '60em'}), checks.MTASTS.PolicyText)),
] ]

View file

@ -1626,14 +1626,6 @@
"Name": "MTASTSCheckResult", "Name": "MTASTSCheckResult",
"Docs": "", "Docs": "",
"Fields": [ "Fields": [
{
"Name": "CNAMEs",
"Docs": "",
"Typewords": [
"[]",
"string"
]
},
{ {
"Name": "TXT", "Name": "TXT",
"Docs": "", "Docs": "",