acmeserver: Configurable resolvers, fix smallstep deprecations (#5500)

* acmeserver: Configurable `resolvers`, fix smallstep deprecations

* Improve default net/port

* Update proxy resolvers parsing to use the new function

* Update listeners.go

Co-authored-by: itsxaos <33079230+itsxaos@users.noreply.github.com>

---------

Co-authored-by: itsxaos <33079230+itsxaos@users.noreply.github.com>
This commit is contained in:
Francis Lavoie 2023-05-03 13:07:22 -04:00 committed by GitHub
parent 1af419e7ec
commit 3f20a7c9f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 284 additions and 53 deletions

View file

@ -303,13 +303,19 @@ func IsUnixNetwork(netw string) bool {
// Network addresses are distinct from URLs and do not // Network addresses are distinct from URLs and do not
// use URL syntax. // use URL syntax.
func ParseNetworkAddress(addr string) (NetworkAddress, error) { func ParseNetworkAddress(addr string) (NetworkAddress, error) {
return ParseNetworkAddressWithDefaults(addr, "tcp", 0)
}
// ParseNetworkAddressWithDefaults is like ParseNetworkAddress but allows
// the default network and port to be specified.
func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort uint) (NetworkAddress, error) {
var host, port string var host, port string
network, host, port, err := SplitNetworkAddress(addr) network, host, port, err := SplitNetworkAddress(addr)
if err != nil { if err != nil {
return NetworkAddress{}, err return NetworkAddress{}, err
} }
if network == "" { if network == "" {
network = "tcp" network = defaultNetwork
} }
if IsUnixNetwork(network) { if IsUnixNetwork(network) {
return NetworkAddress{ return NetworkAddress{
@ -318,7 +324,10 @@ func ParseNetworkAddress(addr string) (NetworkAddress, error) {
}, nil }, nil
} }
var start, end uint64 var start, end uint64
if port != "" { if port == "" {
start = uint64(defaultPort)
end = uint64(defaultPort)
} else {
before, after, found := strings.Cut(port, "-") before, after, found := strings.Cut(port, "-")
if !found { if !found {
after = before after = before

View file

@ -175,47 +175,57 @@ func TestJoinNetworkAddress(t *testing.T) {
func TestParseNetworkAddress(t *testing.T) { func TestParseNetworkAddress(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
input string input string
expectAddr NetworkAddress defaultNetwork string
expectErr bool defaultPort uint
expectAddr NetworkAddress
expectErr bool
}{ }{
{ {
input: "", input: "",
expectErr: true, expectErr: true,
}, },
{ {
input: ":", input: ":",
defaultNetwork: "udp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp", Network: "udp",
}, },
}, },
{ {
input: "[::]", input: "[::]",
defaultNetwork: "udp",
defaultPort: 53,
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp", Network: "udp",
Host: "::", Host: "::",
StartPort: 53,
EndPort: 53,
}, },
}, },
{ {
input: ":1234", input: ":1234",
defaultNetwork: "udp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp", Network: "udp",
Host: "", Host: "",
StartPort: 1234, StartPort: 1234,
EndPort: 1234, EndPort: 1234,
}, },
}, },
{ {
input: "tcp/:1234", input: "udp/:1234",
defaultNetwork: "udp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp", Network: "udp",
Host: "", Host: "",
StartPort: 1234, StartPort: 1234,
EndPort: 1234, EndPort: 1234,
}, },
}, },
{ {
input: "tcp6/:1234", input: "tcp6/:1234",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp6", Network: "tcp6",
Host: "", Host: "",
@ -224,7 +234,8 @@ func TestParseNetworkAddress(t *testing.T) {
}, },
}, },
{ {
input: "tcp4/localhost:1234", input: "tcp4/localhost:1234",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp4", Network: "tcp4",
Host: "localhost", Host: "localhost",
@ -233,14 +244,16 @@ func TestParseNetworkAddress(t *testing.T) {
}, },
}, },
{ {
input: "unix//foo/bar", input: "unix//foo/bar",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "unix", Network: "unix",
Host: "/foo/bar", Host: "/foo/bar",
}, },
}, },
{ {
input: "localhost:1234-1234", input: "localhost:1234-1234",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp", Network: "tcp",
Host: "localhost", Host: "localhost",
@ -249,11 +262,13 @@ func TestParseNetworkAddress(t *testing.T) {
}, },
}, },
{ {
input: "localhost:2-1", input: "localhost:2-1",
expectErr: true, defaultNetwork: "tcp",
expectErr: true,
}, },
{ {
input: "localhost:0", input: "localhost:0",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{ expectAddr: NetworkAddress{
Network: "tcp", Network: "tcp",
Host: "localhost", Host: "localhost",
@ -262,11 +277,138 @@ func TestParseNetworkAddress(t *testing.T) {
}, },
}, },
{ {
input: "localhost:1-999999999999", input: "localhost:1-999999999999",
expectErr: true, defaultNetwork: "tcp",
expectErr: true,
}, },
} { } {
actualAddr, err := ParseNetworkAddress(tc.input) actualAddr, err := ParseNetworkAddressWithDefaults(tc.input, tc.defaultNetwork, tc.defaultPort)
if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err)
}
if !tc.expectErr && err != nil {
t.Errorf("Test %d: Expected no error but got: %v", i, err)
}
if actualAddr.Network != tc.expectAddr.Network {
t.Errorf("Test %d: Expected network '%v' but got '%v'", i, tc.expectAddr, actualAddr)
}
if !reflect.DeepEqual(tc.expectAddr, actualAddr) {
t.Errorf("Test %d: Expected addresses %v but got %v", i, tc.expectAddr, actualAddr)
}
}
}
func TestParseNetworkAddressWithDefaults(t *testing.T) {
for i, tc := range []struct {
input string
defaultNetwork string
defaultPort uint
expectAddr NetworkAddress
expectErr bool
}{
{
input: "",
expectErr: true,
},
{
input: ":",
defaultNetwork: "udp",
expectAddr: NetworkAddress{
Network: "udp",
},
},
{
input: "[::]",
defaultNetwork: "udp",
defaultPort: 53,
expectAddr: NetworkAddress{
Network: "udp",
Host: "::",
StartPort: 53,
EndPort: 53,
},
},
{
input: ":1234",
defaultNetwork: "udp",
expectAddr: NetworkAddress{
Network: "udp",
Host: "",
StartPort: 1234,
EndPort: 1234,
},
},
{
input: "udp/:1234",
defaultNetwork: "udp",
expectAddr: NetworkAddress{
Network: "udp",
Host: "",
StartPort: 1234,
EndPort: 1234,
},
},
{
input: "tcp6/:1234",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{
Network: "tcp6",
Host: "",
StartPort: 1234,
EndPort: 1234,
},
},
{
input: "tcp4/localhost:1234",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{
Network: "tcp4",
Host: "localhost",
StartPort: 1234,
EndPort: 1234,
},
},
{
input: "unix//foo/bar",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{
Network: "unix",
Host: "/foo/bar",
},
},
{
input: "localhost:1234-1234",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 1234,
EndPort: 1234,
},
},
{
input: "localhost:2-1",
defaultNetwork: "tcp",
expectErr: true,
},
{
input: "localhost:0",
defaultNetwork: "tcp",
expectAddr: NetworkAddress{
Network: "tcp",
Host: "localhost",
StartPort: 0,
EndPort: 0,
},
},
{
input: "localhost:1-999999999999",
defaultNetwork: "tcp",
expectErr: true,
},
} {
actualAddr, err := ParseNetworkAddressWithDefaults(tc.input, tc.defaultNetwork, tc.defaultPort)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got: %v", i, err)
} }

View file

@ -8,7 +8,6 @@ import (
"net" "net"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -471,16 +470,9 @@ type UpstreamResolver struct {
// and ensures they're ready to be used. // and ensures they're ready to be used.
func (u *UpstreamResolver) ParseAddresses() error { func (u *UpstreamResolver) ParseAddresses() error {
for _, v := range u.Addresses { for _, v := range u.Addresses {
addr, err := caddy.ParseNetworkAddress(v) addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53)
if err != nil { if err != nil {
// If a port wasn't specified for the resolver, return err
// try defaulting to 53 and parse again
if strings.Contains(err.Error(), "missing port in address") {
addr, err = caddy.ParseNetworkAddress(v + ":53")
}
if err != nil {
return err
}
} }
if addr.PortRangeSize() != 1 { if addr.PortRangeSize() != 1 {
return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr) return fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)

View file

@ -15,7 +15,10 @@
package acmeserver package acmeserver
import ( import (
"context"
"fmt" "fmt"
weakrand "math/rand"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -28,7 +31,7 @@ import (
"github.com/caddyserver/caddy/v2/modules/caddypki" "github.com/caddyserver/caddy/v2/modules/caddypki"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
acmeAPI "github.com/smallstep/certificates/acme/api" "github.com/smallstep/certificates/acme/api"
acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql" acmeNoSQL "github.com/smallstep/certificates/acme/db/nosql"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
@ -76,8 +79,26 @@ type Handler struct {
// changed or removed in the future. // changed or removed in the future.
SignWithRoot bool `json:"sign_with_root,omitempty"` SignWithRoot bool `json:"sign_with_root,omitempty"`
// The addresses of DNS resolvers to use when looking up
// the TXT records for solving DNS challenges.
// It accepts [network addresses](/docs/conventions#network-addresses)
// with port range of only 1. If the host is an IP address,
// it will be dialed directly to resolve the upstream server.
// If the host is not an IP address, the addresses are resolved
// using the [name resolution convention](https://golang.org/pkg/net/#hdr-Name_Resolution)
// of the Go standard library. If the array contains more
// than 1 resolver address, one is chosen at random.
Resolvers []string `json:"resolvers,omitempty"`
logger *zap.Logger
resolvers []caddy.NetworkAddress
ctx caddy.Context
acmeDB acme.DB
acmeAuth *authority.Authority
acmeClient acme.Client
acmeLinker acme.Linker
acmeEndpoints http.Handler acmeEndpoints http.Handler
logger *zap.Logger
} }
// CaddyModule returns the Caddy module information. // CaddyModule returns the Caddy module information.
@ -90,7 +111,9 @@ func (Handler) CaddyModule() caddy.ModuleInfo {
// Provision sets up the ACME server handler. // Provision sets up the ACME server handler.
func (ash *Handler) Provision(ctx caddy.Context) error { func (ash *Handler) Provision(ctx caddy.Context) error {
ash.ctx = ctx
ash.logger = ctx.Logger() ash.logger = ctx.Logger()
// set some defaults // set some defaults
if ash.CA == "" { if ash.CA == "" {
ash.CA = caddypki.DefaultCAID ash.CA = caddypki.DefaultCAID
@ -142,31 +165,30 @@ func (ash *Handler) Provision(ctx caddy.Context) error {
DB: database, DB: database,
} }
auth, err := ca.NewAuthority(authorityConfig) ash.acmeAuth, err = ca.NewAuthority(authorityConfig)
if err != nil { if err != nil {
return err return err
} }
var acmeDB acme.DB ash.acmeDB, err = acmeNoSQL.New(ash.acmeAuth.GetDatabase().(nosql.DB))
if authorityConfig.DB != nil { if err != nil {
acmeDB, err = acmeNoSQL.New(auth.GetDatabase().(nosql.DB)) return fmt.Errorf("configuring ACME DB: %v", err)
if err != nil {
return fmt.Errorf("configuring ACME DB: %v", err)
}
} }
// create the router for the ACME endpoints ash.acmeClient, err = ash.makeClient()
acmeRouterHandler := acmeAPI.NewHandler(acmeAPI.HandlerOptions{ if err != nil {
CA: auth, return err
DB: acmeDB, // stores all the server state }
DNS: ash.Host, // used for directory links
Prefix: strings.Trim(ash.PathPrefix, "/"), // used for directory links ash.acmeLinker = acme.NewLinker(
}) ash.Host,
strings.Trim(ash.PathPrefix, "/"),
)
// extract its http.Handler so we can use it directly // extract its http.Handler so we can use it directly
r := chi.NewRouter() r := chi.NewRouter()
r.Route(ash.PathPrefix, func(r chi.Router) { r.Route(ash.PathPrefix, func(r chi.Router) {
acmeRouterHandler.Route(r) api.Route(r)
}) })
ash.acmeEndpoints = r ash.acmeEndpoints = r
@ -175,6 +197,16 @@ func (ash *Handler) Provision(ctx caddy.Context) error {
func (ash Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { func (ash Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
if strings.HasPrefix(r.URL.Path, ash.PathPrefix) { if strings.HasPrefix(r.URL.Path, ash.PathPrefix) {
acmeCtx := acme.NewContext(
r.Context(),
ash.acmeDB,
ash.acmeClient,
ash.acmeLinker,
nil,
)
acmeCtx = authority.NewContext(acmeCtx, ash.acmeAuth)
r = r.WithContext(acmeCtx)
ash.acmeEndpoints.ServeHTTP(w, r) ash.acmeEndpoints.ServeHTTP(w, r)
return nil return nil
} }
@ -227,6 +259,55 @@ func (ash Handler) openDatabase() (*db.AuthDB, error) {
return database.(databaseCloser).DB, err return database.(databaseCloser).DB, err
} }
// makeClient creates an ACME client which will use a custom
// resolver instead of net.DefaultResolver.
func (ash Handler) makeClient() (acme.Client, error) {
for _, v := range ash.Resolvers {
addr, err := caddy.ParseNetworkAddressWithDefaults(v, "udp", 53)
if err != nil {
return nil, err
}
if addr.PortRangeSize() != 1 {
return nil, fmt.Errorf("resolver address must have exactly one address; cannot call %v", addr)
}
ash.resolvers = append(ash.resolvers, addr)
}
var resolver *net.Resolver
if len(ash.resolvers) != 0 {
dialer := &net.Dialer{
Timeout: 2 * time.Second,
}
resolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
//nolint:gosec
addr := ash.resolvers[weakrand.Intn(len(ash.resolvers))]
return dialer.DialContext(ctx, addr.Network, addr.JoinHostPort(0))
},
}
} else {
resolver = net.DefaultResolver
}
return resolverClient{
Client: acme.NewClient(),
resolver: resolver,
ctx: ash.ctx,
}, nil
}
type resolverClient struct {
acme.Client
resolver *net.Resolver
ctx context.Context
}
func (c resolverClient) LookupTxt(name string) ([]string, error) {
return c.resolver.LookupTXT(c.ctx, name)
}
const defaultPathPrefix = "/acme/" const defaultPathPrefix = "/acme/"
var keyCleaner = regexp.MustCompile(`[^\w.-_]`) var keyCleaner = regexp.MustCompile(`[^\w.-_]`)

View file

@ -29,8 +29,9 @@ func init() {
// parseACMEServer sets up an ACME server handler from Caddyfile tokens. // parseACMEServer sets up an ACME server handler from Caddyfile tokens.
// //
// acme_server [<matcher>] { // acme_server [<matcher>] {
// ca <id> // ca <id>
// lifetime <duration> // lifetime <duration>
// resolvers <addresses...>
// } // }
func parseACMEServer(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) { func parseACMEServer(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error) {
if !h.Next() { if !h.Next() {
@ -74,6 +75,12 @@ func parseACMEServer(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
} }
acmeServer.Lifetime = caddy.Duration(dur) acmeServer.Lifetime = caddy.Duration(dur)
case "resolvers":
acmeServer.Resolvers = h.RemainingArgs()
if len(acmeServer.Resolvers) == 0 {
return nil, h.Errf("must specify at least one resolver address")
}
} }
} }
} }