admin: Enforce and refactor origin checking

Using URLs seems a little cleaner and more correct

cf: https://caddy.community/t/protect-admin-endpoint/15114

(This used to work. Something must have changed recently.)
This commit is contained in:
Matthew Holt 2022-02-15 12:08:12 -07:00
parent 1d0425b26f
commit 40b54434f3
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5

View file

@ -42,6 +42,7 @@ import (
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap" "go.uber.org/zap"
"go.uber.org/zap/zapcore"
) )
// AdminConfig configures Caddy's API endpoint, which is used // AdminConfig configures Caddy's API endpoint, which is used
@ -192,6 +193,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
} else { } else {
muxWrap.enforceHost = !addr.isWildcardInterface() muxWrap.enforceHost = !addr.isWildcardInterface()
muxWrap.allowedOrigins = admin.allowedOrigins(addr) muxWrap.allowedOrigins = admin.allowedOrigins(addr)
muxWrap.enforceOrigin = admin.EnforceOrigin
} }
addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) { addRouteWithMetrics := func(pattern string, handlerLabel string, h http.Handler) {
@ -252,7 +254,7 @@ func (admin AdminConfig) newAdminHandler(addr NetworkAddress, remote bool) admin
// will be used as the default origin. If admin.Origins is // will be used as the default origin. If admin.Origins is
// empty, no origins will be allowed, effectively bricking the // empty, no origins will be allowed, effectively bricking the
// endpoint for non-unix-socket endpoints, but whatever. // endpoint for non-unix-socket endpoints, but whatever.
func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string { func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []*url.URL {
uniqueOrigins := make(map[string]struct{}) uniqueOrigins := make(map[string]struct{})
for _, o := range admin.Origins { for _, o := range admin.Origins {
uniqueOrigins[o] = struct{}{} uniqueOrigins[o] = struct{}{}
@ -276,8 +278,23 @@ func (admin AdminConfig) allowedOrigins(addr NetworkAddress) []string {
uniqueOrigins[addr.JoinHostPort(0)] = struct{}{} uniqueOrigins[addr.JoinHostPort(0)] = struct{}{}
} }
} }
allowed := make([]string, 0, len(uniqueOrigins)) allowed := make([]*url.URL, 0, len(uniqueOrigins))
for origin := range uniqueOrigins { for originStr := range uniqueOrigins {
var origin *url.URL
if strings.Contains(originStr, "://") {
var err error
origin, err = url.Parse(originStr)
if err != nil {
continue
}
origin.Path = ""
origin.RawPath = ""
origin.Fragment = ""
origin.RawFragment = ""
origin.RawQuery = ""
} else {
origin = &url.URL{Host: originStr}
}
allowed = append(allowed, origin) allowed = append(allowed, origin)
} }
return allowed return allowed
@ -358,7 +375,7 @@ func replaceLocalAdminServer(cfg *Config) error {
adminLogger.Info("admin endpoint started", adminLogger.Info("admin endpoint started",
zap.String("address", addr.String()), zap.String("address", addr.String()),
zap.Bool("enforce_origin", adminConfig.EnforceOrigin), zap.Bool("enforce_origin", adminConfig.EnforceOrigin),
zap.Strings("origins", handler.allowedOrigins)) zap.Array("origins", loggableURLArray(handler.allowedOrigins)))
if !handler.enforceHost { if !handler.enforceHost {
adminLogger.Warn("admin endpoint on open interface; host checking disabled", adminLogger.Warn("admin endpoint on open interface; host checking disabled",
@ -650,10 +667,10 @@ type AdminRoute struct {
type adminHandler struct { type adminHandler struct {
mux *http.ServeMux mux *http.ServeMux
// security for local/plaintext) endpoint, on by default // security for local/plaintext endpoint
enforceOrigin bool enforceOrigin bool
enforceHost bool enforceHost bool
allowedOrigins []string allowedOrigins []*url.URL
// security for remote/encrypted endpoint // security for remote/encrypted endpoint
remoteControl *RemoteAdmin remoteControl *RemoteAdmin
@ -779,8 +796,8 @@ func (h adminHandler) handleError(w http.ResponseWriter, r *http.Request, err er
// rebinding attacks. // rebinding attacks.
func (h adminHandler) checkHost(r *http.Request) error { func (h adminHandler) checkHost(r *http.Request) error {
var allowed bool var allowed bool
for _, allowedHost := range h.allowedOrigins { for _, allowedOrigin := range h.allowedOrigins {
if r.Host == allowedHost { if r.Host == allowedOrigin.Host {
allowed = true allowed = true
break break
} }
@ -799,43 +816,45 @@ func (h adminHandler) checkHost(r *http.Request) error {
// sites from issuing requests to our listener. It // sites from issuing requests to our listener. It
// returns the origin that was obtained from r. // returns the origin that was obtained from r.
func (h adminHandler) checkOrigin(r *http.Request) (string, error) { func (h adminHandler) checkOrigin(r *http.Request) (string, error) {
origin := h.getOriginHost(r) originStr, origin := h.getOrigin(r)
if origin == "" { if origin == nil {
return origin, APIError{ return "", APIError{
HTTPStatus: http.StatusForbidden, HTTPStatus: http.StatusForbidden,
Err: fmt.Errorf("missing required Origin header"), Err: fmt.Errorf("required Origin header is missing or invalid"),
} }
} }
if !h.originAllowed(origin) { if !h.originAllowed(origin) {
return origin, APIError{ return "", APIError{
HTTPStatus: http.StatusForbidden, HTTPStatus: http.StatusForbidden,
Err: fmt.Errorf("client is not allowed to access from origin %s", origin), Err: fmt.Errorf("client is not allowed to access from origin '%s'", originStr),
} }
} }
return origin, nil return origin.String(), nil
} }
func (h adminHandler) getOriginHost(r *http.Request) string { func (h adminHandler) getOrigin(r *http.Request) (string, *url.URL) {
origin := r.Header.Get("Origin") origin := r.Header.Get("Origin")
if origin == "" { if origin == "" {
origin = r.Header.Get("Referer") origin = r.Header.Get("Referer")
} }
originURL, err := url.Parse(origin) originURL, err := url.Parse(origin)
if err == nil && originURL.Host != "" { if err != nil {
origin = originURL.Host return origin, nil
} }
return origin originURL.Path = ""
originURL.RawPath = ""
originURL.Fragment = ""
originURL.RawFragment = ""
originURL.RawQuery = ""
return origin, originURL
} }
func (h adminHandler) originAllowed(origin string) bool { func (h adminHandler) originAllowed(origin *url.URL) bool {
for _, allowedOrigin := range h.allowedOrigins { for _, allowedOrigin := range h.allowedOrigins {
originCopy := origin if allowedOrigin.Scheme != "" && origin.Scheme != allowedOrigin.Scheme {
if !strings.Contains(allowedOrigin, "://") { continue
// no scheme specified, so allow both
originCopy = strings.TrimPrefix(originCopy, "http://")
originCopy = strings.TrimPrefix(originCopy, "https://")
} }
if originCopy == allowedOrigin { if origin.Host == allowedOrigin.Host {
return true return true
} }
} }
@ -1189,6 +1208,18 @@ func decodeBase64DERCert(certStr string) (*x509.Certificate, error) {
return x509.ParseCertificate(derBytes) return x509.ParseCertificate(derBytes)
} }
type loggableURLArray []*url.URL
func (ua loggableURLArray) MarshalLogArray(enc zapcore.ArrayEncoder) error {
if ua == nil {
return nil
}
for _, u := range ua {
enc.AppendString(u.String())
}
return nil
}
var ( var (
// DefaultAdminListen is the address for the local admin // DefaultAdminListen is the address for the local admin
// listener, if none is specified at startup. // listener, if none is specified at startup.