diff --git a/dns/mock.go b/dns/mock.go index 2a8e84b..b0d420c 100644 --- a/dns/mock.go +++ b/dns/mock.go @@ -44,6 +44,9 @@ func (r MockResolver) servfail(s string) *net.DNSError { } func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } if _, ok := r.Fail[Mockreq{"cname", name}]; ok { return "", r.servfail(name) } @@ -54,6 +57,9 @@ func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, err } func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } if _, ok := r.Fail[Mockreq{"ptr", ip}]; ok { return nil, r.servfail(ip) } @@ -65,18 +71,30 @@ func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, erro } func (r MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) { + if err := ctx.Err(); err != nil { + return nil, err + } return nil, r.servfail("ns not implemented") } func (r MockResolver) LookupPort(ctx context.Context, network, service string) (port int, err error) { + if err := ctx.Err(); err != nil { + return 0, err + } return 0, r.servfail("port not implemented") } func (r MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) { + if err := ctx.Err(); err != nil { + return "", nil, err + } return "", nil, r.servfail("srv not implemented") } func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { + if err := ctx.Err(); err != nil { + return nil, err + } if _, ok := r.Fail[Mockreq{"ipaddr", host}]; ok { return nil, r.servfail(host) } @@ -96,6 +114,9 @@ func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAd } func (r MockResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { + if err := ctx.Err(); err != nil { + return nil, err + } if _, ok := r.Fail[Mockreq{"host", host}]; ok { return nil, r.servfail(host) } @@ -111,6 +132,9 @@ func (r MockResolver) LookupHost(ctx context.Context, host string) (addrs []stri } func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, error) { + if err := ctx.Err(); err != nil { + return nil, err + } if _, ok := r.Fail[Mockreq{"ip", host}]; ok { return nil, r.servfail(host) } @@ -134,6 +158,9 @@ func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net } func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) { + if err := ctx.Err(); err != nil { + return nil, err + } if _, ok := r.Fail[Mockreq{"mx", name}]; ok { return nil, r.servfail(name) } @@ -145,6 +172,9 @@ func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, err } func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } if _, ok := r.Fail[Mockreq{"txt", name}]; ok { return nil, r.servfail(name) } diff --git a/queue/queue.go b/queue/queue.go index 0776f65..f345196 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -658,7 +658,10 @@ func gatherHosts(resolver dns.Resolver, m Msg, cid int64, qlog *mlog.Log) (hosts } // No MX record. First attempt CNAME lookup. ../rfc/5321:3838 ../rfc/3974:197 + ctx, cancel = context.WithTimeout(cidctx, 30*time.Second) + defer cancel() cname, err := resolver.LookupCNAME(ctx, effectiveDomain.ASCII+".") + cancel() if err != nil && !dns.IsNotFound(err) { return nil, effectiveDomain, false, fmt.Errorf("%w: cname lookup for %s: %v", errDNS, effectiveDomain, err) }