mirror of
https://github.com/mjl-/mox.git
synced 2025-01-15 18:06:27 +03:00
d1b87cdb0d
since we are now at go1.21 as minimum.
252 lines
6.5 KiB
Go
252 lines
6.5 KiB
Go
package dns
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"slices"
|
|
|
|
"github.com/mjl-/adns"
|
|
)
|
|
|
|
// MockResolver is a Resolver used for testing.
|
|
// Set DNS records in the fields, which map FQDNs (with trailing dot) to values.
|
|
type MockResolver struct {
|
|
PTR map[string][]string
|
|
A map[string][]string
|
|
AAAA map[string][]string
|
|
TXT map[string][]string
|
|
MX map[string][]*net.MX
|
|
TLSA map[string][]adns.TLSA // Keys are e.g. _25._tcp.<host>.
|
|
CNAME map[string]string
|
|
Fail []string // Records of the form "type name", e.g. "cname localhost." that will return a servfail.
|
|
AllAuthentic bool // Default value for authentic in responses. Overridden with Authentic and Inauthentic
|
|
Authentic []string // Like Fail, but records that cause the response to be authentic.
|
|
Inauthentic []string // Like Authentic, but making response inauthentic.
|
|
}
|
|
|
|
type mockReq struct {
|
|
Type string // E.g. "cname", "txt", "mx", "ptr", etc.
|
|
Name string // Name of request. For TLSA, the full requested DNS name, e.g. _25._tcp.<host>.
|
|
}
|
|
|
|
func (mr mockReq) String() string {
|
|
return mr.Type + " " + mr.Name
|
|
}
|
|
|
|
var _ Resolver = MockResolver{}
|
|
|
|
func (r MockResolver) result(ctx context.Context, mr mockReq) (string, adns.Result, error) {
|
|
result := adns.Result{Authentic: r.AllAuthentic}
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return "", result, err
|
|
}
|
|
|
|
updateAuthentic := func(mock string) {
|
|
if slices.Contains(r.Authentic, mock) {
|
|
result.Authentic = true
|
|
}
|
|
if slices.Contains(r.Inauthentic, mock) {
|
|
result.Authentic = false
|
|
}
|
|
}
|
|
|
|
for {
|
|
if slices.Contains(r.Fail, mr.String()) {
|
|
updateAuthentic(mr.String())
|
|
return mr.Name, adns.Result{}, r.servfail(mr.Name)
|
|
}
|
|
|
|
cname, ok := r.CNAME[mr.Name]
|
|
if !ok {
|
|
updateAuthentic(mr.String())
|
|
break
|
|
}
|
|
updateAuthentic("cname " + mr.Name)
|
|
if mr.Type == "cname" {
|
|
return mr.Name, result, nil
|
|
}
|
|
mr.Name = cname
|
|
}
|
|
return mr.Name, result, nil
|
|
}
|
|
|
|
func (r MockResolver) nxdomain(s string) error {
|
|
return &adns.DNSError{
|
|
Err: "no record",
|
|
Name: s,
|
|
Server: "mock",
|
|
IsNotFound: true,
|
|
}
|
|
}
|
|
|
|
func (r MockResolver) servfail(s string) error {
|
|
return &adns.DNSError{
|
|
Err: "temp error",
|
|
Name: s,
|
|
Server: "mock",
|
|
IsTemporary: true,
|
|
}
|
|
}
|
|
|
|
func (r MockResolver) LookupPort(ctx context.Context, network, service string) (port int, err error) {
|
|
if err := ctx.Err(); err != nil {
|
|
return 0, err
|
|
}
|
|
return net.LookupPort(network, service)
|
|
}
|
|
|
|
func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, adns.Result, error) {
|
|
mr := mockReq{"cname", name}
|
|
name, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return name, result, err
|
|
}
|
|
cname, ok := r.CNAME[name]
|
|
if !ok {
|
|
return cname, result, r.nxdomain(name)
|
|
}
|
|
return cname, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, adns.Result, error) {
|
|
mr := mockReq{"ptr", ip}
|
|
_, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
l, ok := r.PTR[ip]
|
|
if !ok {
|
|
return nil, result, r.nxdomain(ip)
|
|
}
|
|
return l, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, adns.Result, error) {
|
|
mr := mockReq{"ns", name}
|
|
_, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
return nil, result, r.servfail("ns not implemented")
|
|
}
|
|
|
|
func (r MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, adns.Result, error) {
|
|
xname := fmt.Sprintf("_%s._%s.%s", service, proto, name)
|
|
mr := mockReq{"srv", xname}
|
|
name, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return name, nil, result, err
|
|
}
|
|
return name, nil, result, r.servfail("srv not implemented")
|
|
}
|
|
|
|
func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, adns.Result, error) {
|
|
// todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
|
|
mr := mockReq{"ipaddr", host}
|
|
_, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
addrs, result1, err := r.LookupHost(ctx, host)
|
|
result.Authentic = result.Authentic && result1.Authentic
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
ips := make([]net.IPAddr, len(addrs))
|
|
for i, a := range addrs {
|
|
ip := net.ParseIP(a)
|
|
if ip == nil {
|
|
return nil, result, fmt.Errorf("malformed ip %q", a)
|
|
}
|
|
ips[i] = net.IPAddr{IP: ip}
|
|
}
|
|
return ips, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, adns.Result, error) {
|
|
// todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
|
|
mr := mockReq{"host", host}
|
|
_, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
var addrs []string
|
|
addrs = append(addrs, r.A[host]...)
|
|
addrs = append(addrs, r.AAAA[host]...)
|
|
if len(addrs) == 0 {
|
|
return nil, result, r.nxdomain(host)
|
|
}
|
|
return addrs, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) {
|
|
mr := mockReq{"ip", host}
|
|
name, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
var ips []net.IP
|
|
switch network {
|
|
case "ip", "ip4":
|
|
for _, ip := range r.A[name] {
|
|
ips = append(ips, net.ParseIP(ip))
|
|
}
|
|
}
|
|
switch network {
|
|
case "ip", "ip6":
|
|
for _, ip := range r.AAAA[name] {
|
|
ips = append(ips, net.ParseIP(ip))
|
|
}
|
|
}
|
|
if len(ips) == 0 {
|
|
return nil, result, r.nxdomain(host)
|
|
}
|
|
return ips, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) {
|
|
mr := mockReq{"mx", name}
|
|
name, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
l, ok := r.MX[name]
|
|
if !ok {
|
|
return nil, result, r.nxdomain(name)
|
|
}
|
|
return l, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) {
|
|
mr := mockReq{"txt", name}
|
|
name, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
l, ok := r.TXT[name]
|
|
if !ok {
|
|
return nil, result, r.nxdomain(name)
|
|
}
|
|
return l, result, nil
|
|
}
|
|
|
|
func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string, host string) ([]adns.TLSA, adns.Result, error) {
|
|
var name string
|
|
if port == 0 && protocol == "" {
|
|
name = host
|
|
} else {
|
|
name = fmt.Sprintf("_%d._%s.%s", port, protocol, host)
|
|
}
|
|
mr := mockReq{"tlsa", name}
|
|
name, result, err := r.result(ctx, mr)
|
|
if err != nil {
|
|
return nil, result, err
|
|
}
|
|
l, ok := r.TLSA[name]
|
|
if !ok {
|
|
return nil, result, r.nxdomain(name)
|
|
}
|
|
return l, result, nil
|
|
}
|