diff --git a/dns/mock.go b/dns/mock.go index fa0f456..5a09bda 100644 --- a/dns/mock.go +++ b/dns/mock.go @@ -184,20 +184,20 @@ func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, ad func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) { mr := mockReq{"ip", host} - _, result, err := r.result(ctx, mr) + name, result, err := r.result(ctx, mr) if err != nil { return nil, result, err } var ips []net.IP switch network { case "ip", "ip4": - for _, ip := range r.A[host] { + for _, ip := range r.A[name] { ips = append(ips, net.ParseIP(ip)) } } switch network { case "ip", "ip6": - for _, ip := range r.AAAA[host] { + for _, ip := range r.AAAA[name] { ips = append(ips, net.ParseIP(ip)) } } @@ -209,7 +209,7 @@ func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) { mr := mockReq{"mx", name} - _, result, err := r.result(ctx, mr) + name, result, err := r.result(ctx, mr) if err != nil { return nil, result, err } @@ -222,7 +222,7 @@ func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adn func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) { mr := mockReq{"txt", name} - _, result, err := r.result(ctx, mr) + name, result, err := r.result(ctx, mr) if err != nil { return nil, result, err } @@ -241,7 +241,7 @@ func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string, name = fmt.Sprintf("_%d._%s.%s", port, protocol, host) } mr := mockReq{"tlsa", name} - _, result, err := r.result(ctx, mr) + name, result, err := r.result(ctx, mr) if err != nil { return nil, result, err } diff --git a/mtasts/mtasts.go b/mtasts/mtasts.go index 4ce0ca0..3f91fe5 100644 --- a/mtasts/mtasts.go +++ b/mtasts/mtasts.go @@ -162,45 +162,26 @@ var ( ) // LookupRecord looks up the MTA-STS TXT DNS record at "_mta-sts.", -// following CNAME records, and returns the parsed MTA-STS record, the DNS TXT -// record and any CNAMEs that were followed. -func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rcnames []string, rerr error) { +// following CNAME records, and returns the parsed MTA-STS record and the DNS TXT +// record. +func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rerr error) { log := xlog.WithContext(ctx) start := time.Now() defer func() { - log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("cnames", rcnames), mlog.Field("duration", time.Since(start))) + log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("duration", time.Since(start))) }() // ../rfc/8461:289 // ../rfc/8461:351 - // We lookup the txt record, but must follow CNAME records when the TXT does not exist. - var cnames []string + // We lookup the txt record, but must follow CNAME records when the TXT does not + // exist. LookupTXT follows CNAMEs. name := "_mta-sts." + domain.ASCII + "." var txts []string - for { - var err error - txts, _, err = dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name) - if dns.IsNotFound(err) { - // DNS has no specified limit on how many CNAMEs to follow. Chains of 10 CNAMEs - // have been seen on the internet. - if len(cnames) > 16 { - return nil, "", cnames, fmt.Errorf("too many cnames") - } - cname, _, err := dns.WithPackage(resolver, "mtasts").LookupCNAME(ctx, name) - if dns.IsNotFound(err) { - return nil, "", cnames, ErrNoRecord - } - if err != nil { - return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err) - } - cnames = append(cnames, cname) - name = cname - continue - } else if err != nil { - return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err) - } else { - break - } + txts, _, err := dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name) + if dns.IsNotFound(err) { + return nil, "", ErrNoRecord + } else if err != nil { + return nil, "", fmt.Errorf("%w: %s", ErrDNS, err) } var text string @@ -215,18 +196,18 @@ func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) continue } if err != nil { - return nil, "", cnames, err + return nil, "", err } if record != nil { - return nil, "", cnames, ErrMultipleRecords + return nil, "", ErrMultipleRecords } record = r text = txt } if record == nil { - return nil, "", cnames, ErrNoRecord + return nil, "", ErrNoRecord } - return record, text, cnames, nil + return record, text, nil } // Policy fetch errors. @@ -330,7 +311,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (record log.Debugx("mtasts get result", err, mlog.Field("domain", domain), mlog.Field("record", record), mlog.Field("policy", policy), mlog.Field("duration", time.Since(start))) }() - record, _, _, err = LookupRecord(ctx, resolver, domain) + record, _, err = LookupRecord(ctx, resolver, domain) if err != nil { return nil, nil, err } diff --git a/mtasts/mtasts_test.go b/mtasts/mtasts_test.go index cca1b1f..6018d7b 100644 --- a/mtasts/mtasts_test.go +++ b/mtasts/mtasts_test.go @@ -21,9 +21,12 @@ import ( "github.com/mjl-/adns" "github.com/mjl-/mox/dns" + "github.com/mjl-/mox/mlog" ) func TestLookup(t *testing.T) { + mlog.SetConfig(map[string]mlog.Level{"": mlog.LevelDebug}) + resolver := dns.MockResolver{ TXT: map[string][]string{ "_mta-sts.a.example.": {"v=STSv1; id=1"}, @@ -37,39 +40,37 @@ func TestLookup(t *testing.T) { CNAME: map[string]string{ "_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.", "_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.", - "_mta-sts.followtemperror.example.": "_mta-sts.cnametemperror.example.", + "_mta-sts.followtemperror.example.": "_mta-sts.temperror.example.", }, Fail: []string{ "txt _mta-sts.temperror.example.", - "cname _mta-sts.cnametemperror.example.", }, } - test := func(host string, expRecord *Record, expCNAMEs []string, expErr error) { + test := func(host string, expRecord *Record, expErr error) { t.Helper() - record, _, cnames, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host}) + record, _, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host}) if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) { t.Fatalf("lookup: got err %#v, expected %#v", err, expErr) } if err != nil { return } - if !reflect.DeepEqual(record, expRecord) || !reflect.DeepEqual(cnames, expCNAMEs) { - t.Fatalf("lookup: got record %#v, cnames %#v, expected %#v %#v", record, cnames, expRecord, expCNAMEs) + if !reflect.DeepEqual(record, expRecord) { + t.Fatalf("lookup: got record %#v, expected %#v", record, expRecord) } } - test("absent.example", nil, nil, ErrNoRecord) - test("other.example", nil, nil, ErrNoRecord) - test("a.example", &Record{Version: "STSv1", ID: "1"}, nil, nil) - test("one.example", &Record{Version: "STSv1", ID: "1"}, nil, nil) - test("bad.example", nil, nil, ErrRecordSyntax) - test("multiple.example", nil, nil, ErrMultipleRecords) - test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, []string{"_mta-sts.b.cnames.example.", "_mta-sts.c.cnames.example."}, nil) - test("temperror.example", nil, nil, ErrDNS) - test("cnametemperror.example", nil, nil, ErrDNS) - test("followtemperror.example", nil, nil, ErrDNS) + test("absent.example", nil, ErrNoRecord) + test("other.example", nil, ErrNoRecord) + test("a.example", &Record{Version: "STSv1", ID: "1"}, nil) + test("one.example", &Record{Version: "STSv1", ID: "1"}, nil) + test("bad.example", nil, ErrRecordSyntax) + test("multiple.example", nil, ErrMultipleRecords) + test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, nil) + test("temperror.example", nil, ErrDNS) + test("followtemperror.example", nil, ErrDNS) } func TestMatches(t *testing.T) { diff --git a/mtastsdb/db.go b/mtastsdb/db.go index efd9835..7b5f262 100644 --- a/mtastsdb/db.go +++ b/mtastsdb/db.go @@ -263,7 +263,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy policy = &cachedPolicy.Policy nctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - record, _, _, err := mtasts.LookupRecord(nctx, resolver, domain) + record, _, err := mtasts.LookupRecord(nctx, resolver, domain) if err != nil { if !errors.Is(err, mtasts.ErrNoRecord) { // Could be a temporary DNS or configuration error. diff --git a/mtastsdb/refresh.go b/mtastsdb/refresh.go index cb7a91d..219b82e 100644 --- a/mtastsdb/refresh.go +++ b/mtastsdb/refresh.go @@ -125,7 +125,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr return } log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d)) - record, _, _, err := mtasts.LookupRecord(ctx, resolver, d) + record, _, err := mtasts.LookupRecord(ctx, resolver, d) if err == nil && record.ID == pr.RecordID { qup := bstore.QueryDB[PolicyRecord](ctx, db) qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate}) diff --git a/webadmin/admin.go b/webadmin/admin.go index 5a4bc74..7d2baa8 100644 --- a/webadmin/admin.go +++ b/webadmin/admin.go @@ -325,7 +325,6 @@ type MTASTSRecord struct { mtasts.Record } type MTASTSCheckResult struct { - CNAMEs []string TXT string Record *MTASTSRecord PolicyText string @@ -1180,15 +1179,10 @@ Ensure a DNS TXT record like the following exists: defer logPanic(ctx) defer wg.Done() - record, txt, cnames, err := mtasts.LookupRecord(ctx, resolver, domain) + record, txt, err := mtasts.LookupRecord(ctx, resolver, domain) if err != nil { addf(&r.MTASTS.Errors, "Looking up MTA-STS record: %s", err) } - if cnames != nil { - r.MTASTS.CNAMEs = cnames - } else { - r.MTASTS.CNAMEs = []string{} - } r.MTASTS.TXT = txt if record != nil { r.MTASTS.Record = &MTASTSRecord{*record} diff --git a/webadmin/admin.html b/webadmin/admin.html index 2caa04a..2a9b5be 100644 --- a/webadmin/admin.html +++ b/webadmin/admin.html @@ -951,8 +951,7 @@ const domainDNSCheck = async (d) => { const detailsTLSRPT = !checks.TLSRPT.TXT ? [] : [ dom.div('TXT record: ' + checks.TLSRPT.TXT), ] - const detailsMTASTS = empty(checks.MTASTS.CNAMEs) && !checks.MTASTS.TXT && !checks.MTASTS.PolicyText ? [] : [ - dom.div('CNAMEs followed: ' + (checks.MTASTS.CNAMEs.join(', ') || '(none)')), + const detailsMTASTS = !checks.MTASTS.TXT && !checks.MTASTS.PolicyText ? [] : [ !checks.MTASTS.TXT ? [] : dom.div('MTA-STS record: ' + checks.MTASTS.TXT), !checks.MTASTS.PolicyText ? [] : dom.div('MTA-STS policy: ', dom('pre.literal', style({maxWidth: '60em'}), checks.MTASTS.PolicyText)), ] diff --git a/webadmin/adminapi.json b/webadmin/adminapi.json index b1b4459..233649b 100644 --- a/webadmin/adminapi.json +++ b/webadmin/adminapi.json @@ -1626,14 +1626,6 @@ "Name": "MTASTSCheckResult", "Docs": "", "Fields": [ - { - "Name": "CNAMEs", - "Docs": "", - "Typewords": [ - "[]", - "string" - ] - }, { "Name": "TXT", "Docs": "",