mirror of
https://github.com/mjl-/mox.git
synced 2025-01-15 01:46:26 +03:00
268 lines
7.4 KiB
Go
268 lines
7.4 KiB
Go
|
package mtasts
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/ed25519"
|
||
|
cryptorand "crypto/rand"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"log"
|
||
|
"math/big"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/mjl-/mox/dns"
|
||
|
)
|
||
|
|
||
|
func TestLookup(t *testing.T) {
|
||
|
resolver := dns.MockResolver{
|
||
|
TXT: map[string][]string{
|
||
|
"_mta-sts.a.example.": {"v=STSv1; id=1"},
|
||
|
"_mta-sts.one.example.": {"v=STSv1; id=1", "bogus"},
|
||
|
"_mta-sts.bad.example.": {"v=STSv1; bogus"},
|
||
|
"_mta-sts.multiple.example.": {"v=STSv1; id=1", "v=STSv1; id=2"},
|
||
|
"_mta-sts.c.cnames.example.": {"v=STSv1; id=1"},
|
||
|
"_mta-sts.temperror.example.": {"v=STSv1; id=1"},
|
||
|
"_mta-sts.other.example.": {"bogus", "more"},
|
||
|
},
|
||
|
CNAME: map[string]string{
|
||
|
"_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
|
||
|
"_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
|
||
|
"_mta-sts.followtemperror.example.": "_mta-sts.cnametemperror.example.",
|
||
|
},
|
||
|
Fail: map[dns.Mockreq]struct{}{
|
||
|
{Type: "txt", Name: "_mta-sts.temperror.example."}: {},
|
||
|
{Type: "cname", Name: "_mta-sts.cnametemperror.example."}: {},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
test := func(host string, expRecord *Record, expCNAMEs []string, expErr error) {
|
||
|
t.Helper()
|
||
|
|
||
|
record, _, cnames, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host})
|
||
|
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
||
|
t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
if !reflect.DeepEqual(record, expRecord) || !reflect.DeepEqual(cnames, expCNAMEs) {
|
||
|
t.Fatalf("lookup: got record %#v, cnames %#v, expected %#v %#v", record, cnames, expRecord, expCNAMEs)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
test("absent.example", nil, nil, ErrNoRecord)
|
||
|
test("other.example", nil, nil, ErrNoRecord)
|
||
|
test("a.example", &Record{Version: "STSv1", ID: "1"}, nil, nil)
|
||
|
test("one.example", &Record{Version: "STSv1", ID: "1"}, nil, nil)
|
||
|
test("bad.example", nil, nil, ErrRecordSyntax)
|
||
|
test("multiple.example", nil, nil, ErrMultipleRecords)
|
||
|
test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, []string{"_mta-sts.b.cnames.example.", "_mta-sts.c.cnames.example."}, nil)
|
||
|
test("temperror.example", nil, nil, ErrDNS)
|
||
|
test("cnametemperror.example", nil, nil, ErrDNS)
|
||
|
test("followtemperror.example", nil, nil, ErrDNS)
|
||
|
}
|
||
|
|
||
|
func TestMatches(t *testing.T) {
|
||
|
p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
|
||
|
if err != nil {
|
||
|
t.Fatalf("parsing policy: %s", err)
|
||
|
}
|
||
|
|
||
|
mustParseDomain := func(s string) dns.Domain {
|
||
|
t.Helper()
|
||
|
d, err := dns.ParseDomain(s)
|
||
|
if err != nil {
|
||
|
t.Fatalf("parsing domain %q: %s", s, err)
|
||
|
}
|
||
|
return d
|
||
|
}
|
||
|
|
||
|
match := func(s string) {
|
||
|
t.Helper()
|
||
|
if !p.Matches(mustParseDomain(s)) {
|
||
|
t.Fatalf("unexpected mismatch for %q", s)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
not := func(s string) {
|
||
|
t.Helper()
|
||
|
if p.Matches(mustParseDomain(s)) {
|
||
|
t.Fatalf("unexpected match for %q", s)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
match("a.example")
|
||
|
match("sub.b.example")
|
||
|
not("b.example")
|
||
|
not("sub.sub.b.example")
|
||
|
not("other")
|
||
|
}
|
||
|
|
||
|
// pipeListener is an in-memory net.Listener for tests. Each Dial creates a
// net.Pipe, hands the server end to Accept over the unbuffered channel C and
// returns the client end to the caller.
type pipeListener struct {
	// Mutex protects closed.
	sync.Mutex
	// closed is set by the first call to Close.
	closed bool
	// C carries the server ends of pipes from Dial to Accept.
	C chan net.Conn
	// done is closed by Close and releases Dial and Accept calls that are
	// blocked on C.
	done chan struct{}
}

var _ net.Listener = (*pipeListener)(nil)

// newPipeListener returns a listener that is ready for Dial and Accept.
func newPipeListener() *pipeListener {
	return &pipeListener{
		C:    make(chan net.Conn),
		done: make(chan struct{}),
	}
}

// Dial creates a new pipe, passes its server end to a concurrent Accept and
// returns the client end. It blocks until the connection is accepted or the
// listener is closed; in the latter case an error is returned.
//
// The mutex is deliberately not held while waiting for Accept: doing so would
// make Close block behind a Dial that nobody accepts.
func (l *pipeListener) Dial() (net.Conn, error) {
	select {
	case <-l.done:
		return nil, errors.New("closed")
	default:
	}

	c, s := net.Pipe()
	select {
	case l.C <- s:
		return c, nil
	case <-l.done:
		// Nobody will ever see these ends, release them.
		_ = c.Close()
		_ = s.Close()
		return nil, errors.New("closed")
	}
}

// Accept waits for the next connection created by Dial. After Close it returns
// io.EOF.
func (l *pipeListener) Accept() (net.Conn, error) {
	select {
	case conn := <-l.C:
		return conn, nil
	case <-l.done:
		return nil, io.EOF
	}
}

// Close marks the listener as closed and wakes up blocked Dial and Accept
// calls. Calling Close more than once is allowed. It always returns nil.
func (l *pipeListener) Close() error {
	l.Lock()
	defer l.Unlock()
	if !l.closed {
		l.closed = true
		close(l.done)
	}
	return nil
}

// Addr returns a placeholder address, there is no real network endpoint.
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }

// pipeAddr is the net.Addr reported by pipeListener.
type pipeAddr struct{}

// Network returns the fixed network name "pipe".
func (a pipeAddr) Network() string { return "pipe" }

// String returns the fixed address "pipe".
func (a pipeAddr) String() string { return "pipe" }
|
||
|
|
||
|
// fakeCert returns a self-signed certificate for the DNS name
// "mta-sts.mox.example", with Leaf set to the parsed certificate. If expired
// is true, the validity period ended an hour ago, otherwise it ends an hour
// from now. The private key is derived from an all-zero seed and must never be
// used outside of tests. Any failure is reported through t.Fatalf.
func fakeCert(t *testing.T, expired bool) tls.Certificate {
	t.Helper()

	// Derive both validity bounds from one timestamp, and start the validity
	// period well before its end, so that an expired certificate still has a
	// well-formed (non-empty) validity window.
	now := time.Now()
	notBefore := now.Add(-2 * time.Hour)
	notAfter := now.Add(time.Hour)
	if expired {
		notAfter = now.Add(-time.Hour)
	}

	privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!

	template := &x509.Certificate{
		SerialNumber: big.NewInt(1), // Required field...
		DNSNames:     []string{"mta-sts.mox.example"},
		NotBefore:    notBefore,
		NotAfter:     notAfter,
	}
	// Self-signed: the template serves as both subject and issuer.
	localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
	if err != nil {
		t.Fatalf("making certificate: %s", err)
	}
	cert, err := x509.ParseCertificate(localCertBuf)
	if err != nil {
		t.Fatalf("parsing generated certificate: %s", err)
	}
	return tls.Certificate{
		Certificate: [][]byte{localCertBuf},
		PrivateKey:  privKey,
		Leaf:        cert,
	}
}
|
||
|
|
||
|
func TestFetch(t *testing.T) {
|
||
|
certok := fakeCert(t, false)
|
||
|
certbad := fakeCert(t, true)
|
||
|
|
||
|
defer func() {
|
||
|
HTTPClient.Transport = nil
|
||
|
}()
|
||
|
|
||
|
resolver := dns.MockResolver{
|
||
|
TXT: map[string][]string{
|
||
|
"_mta-sts.mox.example.": {"v=STSv1; id=1"},
|
||
|
"_mta-sts.other.example.": {"v=STSv1; id=1"},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
|
||
|
t.Helper()
|
||
|
|
||
|
pool := x509.NewCertPool()
|
||
|
pool.AddCert(cert.Leaf)
|
||
|
|
||
|
l := newPipeListener()
|
||
|
defer l.Close()
|
||
|
go func() {
|
||
|
mux := &http.ServeMux{}
|
||
|
mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
|
||
|
w.Header().Add("Location", "/other") // Ignored except for redirect.
|
||
|
w.WriteHeader(status)
|
||
|
w.Write([]byte(policyText))
|
||
|
})
|
||
|
s := &http.Server{
|
||
|
Handler: mux,
|
||
|
TLSConfig: &tls.Config{
|
||
|
Certificates: []tls.Certificate{cert},
|
||
|
},
|
||
|
ErrorLog: log.New(io.Discard, "", 0),
|
||
|
}
|
||
|
s.ServeTLS(l, "", "")
|
||
|
}()
|
||
|
|
||
|
HTTPClient.Transport = &http.Transport{
|
||
|
Dial: func(network, addr string) (net.Conn, error) {
|
||
|
if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
|
||
|
return nil, &net.DNSError{IsNotFound: true}
|
||
|
}
|
||
|
return l.Dial()
|
||
|
},
|
||
|
TLSClientConfig: &tls.Config{
|
||
|
RootCAs: pool,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
p, _, err := FetchPolicy(context.Background(), dns.Domain{ASCII: domain})
|
||
|
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
||
|
t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
|
||
|
}
|
||
|
if err == nil && !reflect.DeepEqual(p, expPolicy) {
|
||
|
t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
|
||
|
}
|
||
|
|
||
|
if domain == "doesnotexist.example" {
|
||
|
expErr = ErrNoRecord
|
||
|
}
|
||
|
|
||
|
_, p, err = Get(context.Background(), resolver, dns.Domain{ASCII: domain})
|
||
|
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
||
|
t.Fatalf("get: got err %#v, expected %#v", err, expErr)
|
||
|
}
|
||
|
if err == nil && !reflect.DeepEqual(p, expPolicy) {
|
||
|
t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
|
||
|
test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
|
||
|
test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
|
||
|
test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
|
||
|
test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
|
||
|
test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
|
||
|
test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
|
||
|
large := make([]byte, 64*1024+2)
|
||
|
test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
|
||
|
validPolicy := "version:STSv1\nmode:none\nmax_age:1"
|
||
|
test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)
|
||
|
}
|