From ae645ef2e996824b457990e17ad3c87731c29ebe Mon Sep 17 00:00:00 2001
From: Tw <tw19881113@gmail.com>
Date: Mon, 8 May 2017 10:36:58 +0800
Subject: [PATCH] Introduce `limits` middleware

1. Replace original `maxrequestbody` directive.
2. Add request header limit.

fix issue #1587

Signed-off-by: Tw <tw19881113@gmail.com>
---
 caddyhttp/caddyhttp.go                        |   2 +-
 caddyhttp/httpserver/plugin.go                |   2 +-
 caddyhttp/httpserver/replacer.go              |   2 +-
 caddyhttp/httpserver/server.go                | 111 +++++-----------
 caddyhttp/httpserver/server_test.go           |  35 +++++-
 caddyhttp/httpserver/siteconfig.go            |  10 +-
 caddyhttp/limits/handler.go                   |  90 +++++++++++++
 caddyhttp/limits/handler_test.go              |  35 ++++++
 .../maxrequestbody.go => limits/setup.go}     |  99 ++++++++++-----
 .../setup_test.go}                            | 118 ++++++++++++++----
 caddyhttp/proxy/proxy.go                      |   2 +-
 11 files changed, 363 insertions(+), 143 deletions(-)
 create mode 100644 caddyhttp/limits/handler.go
 create mode 100644 caddyhttp/limits/handler_test.go
 rename caddyhttp/{maxrequestbody/maxrequestbody.go => limits/setup.go} (67%)
 rename caddyhttp/{maxrequestbody/maxrequestbody_test.go => limits/setup_test.go} (64%)

diff --git a/caddyhttp/caddyhttp.go b/caddyhttp/caddyhttp.go
index 2f1d8dc7b..429702a54 100644
--- a/caddyhttp/caddyhttp.go
+++ b/caddyhttp/caddyhttp.go
@@ -16,9 +16,9 @@ import (
 	_ "github.com/mholt/caddy/caddyhttp/header"
 	_ "github.com/mholt/caddy/caddyhttp/index"
 	_ "github.com/mholt/caddy/caddyhttp/internalsrv"
+	_ "github.com/mholt/caddy/caddyhttp/limits"
 	_ "github.com/mholt/caddy/caddyhttp/log"
 	_ "github.com/mholt/caddy/caddyhttp/markdown"
-	_ "github.com/mholt/caddy/caddyhttp/maxrequestbody"
 	_ "github.com/mholt/caddy/caddyhttp/mime"
 	_ "github.com/mholt/caddy/caddyhttp/pprof"
 	_ "github.com/mholt/caddy/caddyhttp/proxy"
diff --git a/caddyhttp/httpserver/plugin.go b/caddyhttp/httpserver/plugin.go
index 6f632f6d6..7c652be85 100644
--- a/caddyhttp/httpserver/plugin.go
+++ b/caddyhttp/httpserver/plugin.go
@@ -436,7 +436,7 @@ var directives = []string{
 	"root",
 	"index",
 	"bind",
-	"maxrequestbody", // TODO: 'limits'
+	"limits",
 	"timeouts",
 	"tls",
 
diff --git a/caddyhttp/httpserver/replacer.go b/caddyhttp/httpserver/replacer.go
index e9b51a033..ad5b12c12 100644
--- a/caddyhttp/httpserver/replacer.go
+++ b/caddyhttp/httpserver/replacer.go
@@ -302,7 +302,7 @@ func (r *replacer) getSubstitution(key string) string {
 		}
 		_, err := ioutil.ReadAll(r.request.Body)
 		if err != nil {
-			if _, ok := err.(MaxBytesExceeded); ok {
+			if err == MaxBytesExceededErr {
 				return r.emptyValue
 			}
 		}
diff --git a/caddyhttp/httpserver/server.go b/caddyhttp/httpserver/server.go
index 6db0c0bd6..e7e29ace5 100644
--- a/caddyhttp/httpserver/server.go
+++ b/caddyhttp/httpserver/server.go
@@ -4,8 +4,8 @@ package httpserver
 import (
 	"context"
 	"crypto/tls"
+	"errors"
 	"fmt"
-	"io"
 	"log"
 	"net"
 	"net/http"
@@ -66,6 +66,7 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
 		sites:       group,
 		connTimeout: GracefulTimeout,
 	}
+	s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
 	s.Server.Handler = s // this is weird, but whatever
 
 	// extract TLS settings from each site config to build
@@ -127,6 +128,32 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
 	return s, nil
 }
 
+// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
+func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
+	var min int64
+	for _, cfg := range group {
+		limit := cfg.Limits.MaxRequestHeaderSize
+		if limit == 0 {
+			continue
+		}
+
+		// not set yet
+		if min == 0 {
+			min = limit
+		}
+
+		// find a better one
+		if limit < min {
+			min = limit
+		}
+	}
+
+	if min > 0 {
+		s.MaxHeaderBytes = int(min)
+	}
+	return s
+}
+
 // makeHTTPServerWithTimeouts makes an http.Server from the group of
 // configs in a way that configures timeouts (or, if not set, it uses
 // the default timeouts) by combining the configuration of each
@@ -359,20 +386,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
 		}
 	}
 
-	// Apply the path-based request body size limit
-	// The error returned by MaxBytesReader is meant to be handled
-	// by whichever middleware/plugin that receives it when calling
-	// .Read() or a similar method on the request body
-	// TODO: Make this middleware instead?
-	if r.Body != nil {
-		for _, pathlimit := range vhost.MaxRequestBodySizes {
-			if Path(r.URL.Path).Matches(pathlimit.Path) {
-				r.Body = MaxBytesReader(w, r.Body, pathlimit.Limit)
-				break
-			}
-		}
-	}
-
 	return vhost.middlewareChain.ServeHTTP(w, r)
 }
 
@@ -465,73 +478,9 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
 	return ln.TCPListener.File()
 }
 
-// MaxBytesExceeded is the error type returned by MaxBytesReader
+// MaxBytesExceeded is the error returned by MaxBytesReader
 // when the request body exceeds the limit imposed
-type MaxBytesExceeded struct{}
-
-func (err MaxBytesExceeded) Error() string {
-	return "http: request body too large"
-}
-
-// MaxBytesReader and its associated methods are borrowed from the
-// Go Standard library (comments intact). The only difference is that
-// it returns a MaxBytesExceeded error instead of a generic error message
-// when the request body has exceeded the requested limit
-func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
-	return &maxBytesReader{w: w, r: r, n: n}
-}
-
-type maxBytesReader struct {
-	w   http.ResponseWriter
-	r   io.ReadCloser // underlying reader
-	n   int64         // max bytes remaining
-	err error         // sticky error
-}
-
-func (l *maxBytesReader) Read(p []byte) (n int, err error) {
-	if l.err != nil {
-		return 0, l.err
-	}
-	if len(p) == 0 {
-		return 0, nil
-	}
-	// If they asked for a 32KB byte read but only 5 bytes are
-	// remaining, no need to read 32KB. 6 bytes will answer the
-	// question of the whether we hit the limit or go past it.
-	if int64(len(p)) > l.n+1 {
-		p = p[:l.n+1]
-	}
-	n, err = l.r.Read(p)
-
-	if int64(n) <= l.n {
-		l.n -= int64(n)
-		l.err = err
-		return n, err
-	}
-
-	n = int(l.n)
-	l.n = 0
-
-	// The server code and client code both use
-	// maxBytesReader. This "requestTooLarge" check is
-	// only used by the server code. To prevent binaries
-	// which only using the HTTP Client code (such as
-	// cmd/go) from also linking in the HTTP server, don't
-	// use a static type assertion to the server
-	// "*response" type. Check this interface instead:
-	type requestTooLarger interface {
-		requestTooLarge()
-	}
-	if res, ok := l.w.(requestTooLarger); ok {
-		res.requestTooLarge()
-	}
-	l.err = MaxBytesExceeded{}
-	return n, l.err
-}
-
-func (l *maxBytesReader) Close() error {
-	return l.r.Close()
-}
+var MaxBytesExceededErr = errors.New("http: request body too large")
 
 // DefaultErrorFunc responds to an HTTP request with a simple description
 // of the specified HTTP status code.
diff --git a/caddyhttp/httpserver/server_test.go b/caddyhttp/httpserver/server_test.go
index 69d2f7453..036ee3dd8 100644
--- a/caddyhttp/httpserver/server_test.go
+++ b/caddyhttp/httpserver/server_test.go
@@ -15,7 +15,7 @@ func TestAddress(t *testing.T) {
 	}
 }
 
-func TestMakeHTTPServer(t *testing.T) {
+func TestMakeHTTPServerWithTimeouts(t *testing.T) {
 	for i, tc := range []struct {
 		group    []*SiteConfig
 		expected Timeouts
@@ -111,3 +111,36 @@ func TestMakeHTTPServer(t *testing.T) {
 		}
 	}
 }
+
+func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
+	for name, c := range map[string]struct {
+		group  []*SiteConfig
+		expect int
+	}{
+		"disable": {
+			group:  []*SiteConfig{{}},
+			expect: 0,
+		},
+		"oneSite": {
+			group: []*SiteConfig{{Limits: Limits{
+				MaxRequestHeaderSize: 100,
+			}}},
+			expect: 100,
+		},
+		"multiSites": {
+			group: []*SiteConfig{
+				{Limits: Limits{MaxRequestHeaderSize: 100}},
+				{Limits: Limits{MaxRequestHeaderSize: 50}},
+			},
+			expect: 50,
+		},
+	} {
+		c := c
+		t.Run(name, func(t *testing.T) {
+			actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
+			if got := actual.MaxHeaderBytes; got != c.expect {
+				t.Errorf("Expect %d, but got %d", c.expect, got)
+			}
+		})
+	}
+}
diff --git a/caddyhttp/httpserver/siteconfig.go b/caddyhttp/httpserver/siteconfig.go
index de18bd477..a091868f7 100644
--- a/caddyhttp/httpserver/siteconfig.go
+++ b/caddyhttp/httpserver/siteconfig.go
@@ -38,8 +38,8 @@ type SiteConfig struct {
 	// for a request.
 	HiddenFiles []string
 
-	// Max amount of bytes a request can send on a given path
-	MaxRequestBodySizes []PathLimit
+	// Max request's header/body size
+	Limits Limits
 
 	// The path to the Caddyfile used to generate this site config
 	originCaddyfile string
@@ -71,6 +71,12 @@ type Timeouts struct {
 	IdleTimeoutSet       bool
 }
 
+// Limits specify size limit of request's header and body.
+type Limits struct {
+	MaxRequestHeaderSize int64
+	MaxRequestBodySizes  []PathLimit
+}
+
 // PathLimit is a mapping from a site's path to its corresponding
 // maximum request body size (in bytes)
 type PathLimit struct {
diff --git a/caddyhttp/limits/handler.go b/caddyhttp/limits/handler.go
new file mode 100644
index 000000000..52fe60ab1
--- /dev/null
+++ b/caddyhttp/limits/handler.go
@@ -0,0 +1,90 @@
+package limits
+
+import (
+	"io"
+	"net/http"
+
+	"github.com/mholt/caddy/caddyhttp/httpserver"
+)
+
+// Limit is a middleware to control request body size
+type Limit struct {
+	Next       httpserver.Handler
+	BodyLimits []httpserver.PathLimit
+}
+
+func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
+	if r.Body == nil {
+		return l.Next.ServeHTTP(w, r)
+	}
+
+	// apply the path-based request body size limit.
+	for _, bl := range l.BodyLimits {
+		if httpserver.Path(r.URL.Path).Matches(bl.Path) {
+			r.Body = MaxBytesReader(w, r.Body, bl.Limit)
+			break
+		}
+	}
+
+	return l.Next.ServeHTTP(w, r)
+}
+
+// MaxBytesReader and its associated methods are borrowed from the
+// Go Standard library (comments intact). The only difference is that
+// it returns a MaxBytesExceeded error instead of a generic error message
+// when the request body has exceeded the requested limit
+func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
+	return &maxBytesReader{w: w, r: r, n: n}
+}
+
+type maxBytesReader struct {
+	w   http.ResponseWriter
+	r   io.ReadCloser // underlying reader
+	n   int64         // max bytes remaining
+	err error         // sticky error
+}
+
+func (l *maxBytesReader) Read(p []byte) (n int, err error) {
+	if l.err != nil {
+		return 0, l.err
+	}
+	if len(p) == 0 {
+		return 0, nil
+	}
+	// If they asked for a 32KB byte read but only 5 bytes are
+	// remaining, no need to read 32KB. 6 bytes will answer the
+	// question of the whether we hit the limit or go past it.
+	if int64(len(p)) > l.n+1 {
+		p = p[:l.n+1]
+	}
+	n, err = l.r.Read(p)
+
+	if int64(n) <= l.n {
+		l.n -= int64(n)
+		l.err = err
+		return n, err
+	}
+
+	n = int(l.n)
+	l.n = 0
+
+	// The server code and client code both use
+	// maxBytesReader. This "requestTooLarge" check is
+	// only used by the server code. To prevent binaries
+	// which only using the HTTP Client code (such as
+	// cmd/go) from also linking in the HTTP server, don't
+	// use a static type assertion to the server
+	// "*response" type. Check this interface instead:
+	type requestTooLarger interface {
+		requestTooLarge()
+	}
+	if res, ok := l.w.(requestTooLarger); ok {
+		res.requestTooLarge()
+	}
+	l.err = httpserver.MaxBytesExceededErr
+	return n, l.err
+}
+
+func (l *maxBytesReader) Close() error {
+	return l.r.Close()
+}
diff --git a/caddyhttp/limits/handler_test.go b/caddyhttp/limits/handler_test.go
new file mode 100644
index 000000000..be5144ede
--- /dev/null
+++ b/caddyhttp/limits/handler_test.go
@@ -0,0 +1,35 @@
+package limits
+
+import (
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+
+	"github.com/mholt/caddy/caddyhttp/httpserver"
+)
+
+func TestBodySizeLimit(t *testing.T) {
+	var (
+		gotContent    []byte
+		gotError      error
+		expectContent = "hello"
+	)
+	l := Limit{
+		Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
+			gotContent, gotError = ioutil.ReadAll(r.Body)
+			return 0, nil
+		}),
+		BodyLimits: []httpserver.PathLimit{{Path: "/", Limit: int64(len(expectContent))}},
+	}
+
+	r := httptest.NewRequest("GET", "/", strings.NewReader(expectContent+expectContent))
+	l.ServeHTTP(httptest.NewRecorder(), r)
+	if got := string(gotContent); got != expectContent {
+		t.Errorf("expected content[%s], got[%s]", expectContent, got)
+	}
+	if gotError != httpserver.MaxBytesExceededErr {
+		t.Errorf("expect error %v, got %v", httpserver.MaxBytesExceededErr, gotError)
+	}
+}
diff --git a/caddyhttp/maxrequestbody/maxrequestbody.go b/caddyhttp/limits/setup.go
similarity index 67%
rename from caddyhttp/maxrequestbody/maxrequestbody.go
rename to caddyhttp/limits/setup.go
index 8b5fef464..1c75153f2 100644
--- a/caddyhttp/maxrequestbody/maxrequestbody.go
+++ b/caddyhttp/limits/setup.go
@@ -1,4 +1,4 @@
-package maxrequestbody
+package limits
 
 import (
 	"errors"
@@ -12,13 +12,13 @@ import (
 
 const (
 	serverType = "http"
-	pluginName = "maxrequestbody"
+	pluginName = "limits"
 )
 
 func init() {
 	caddy.RegisterPlugin(pluginName, caddy.Plugin{
 		ServerType: serverType,
-		Action:     setupMaxRequestBody,
+		Action:     setupLimits,
 	})
 }
 
@@ -28,56 +28,97 @@ type pathLimitUnparsed struct {
 	Limit string
 }
 
-func setupMaxRequestBody(c *caddy.Controller) error {
+func setupLimits(c *caddy.Controller) error {
+	bls, err := parseLimits(c)
+	if err != nil {
+		return err
+	}
+
+	httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
+		return Limit{Next: next, BodyLimits: bls}
+	})
+	return nil
+}
+
+func parseLimits(c *caddy.Controller) ([]httpserver.PathLimit, error) {
 	config := httpserver.GetConfig(c)
 
 	if !c.Next() {
-		return c.ArgErr()
+		return nil, c.ArgErr()
 	}
 
 	args := c.RemainingArgs()
 	argList := []pathLimitUnparsed{}
+	headerLimit := ""
 
 	switch len(args) {
 	case 0:
-		// Format: { <path> <limit> ... }
+		// Format: limits {
+		//	header <limit>
+		//	body <path> <limit>
+		//	body <limit>
+		//	...
+		// }
 		for c.NextBlock() {
-			path := c.Val()
-			if !c.NextArg() {
-				// Uneven pairing of path/limit
-				return c.ArgErr()
+			kind := c.Val()
+			pathOrLimit := c.RemainingArgs()
+			switch kind {
+			case "header":
+				if len(pathOrLimit) != 1 {
+					return nil, c.ArgErr()
+				}
+				headerLimit = pathOrLimit[0]
+			case "body":
+				if len(pathOrLimit) == 1 {
+					argList = append(argList, pathLimitUnparsed{
+						Path:  "/",
+						Limit: pathOrLimit[0],
+					})
+					break
+				}
+
+				if len(pathOrLimit) == 2 {
+					argList = append(argList, pathLimitUnparsed{
+						Path:  pathOrLimit[0],
+						Limit: pathOrLimit[1],
+					})
+					break
+				}
+
+				fallthrough
+			default:
+				return nil, c.ArgErr()
 			}
-			argList = append(argList, pathLimitUnparsed{
-				Path:  path,
-				Limit: c.Val(),
-			})
 		}
 	case 1:
-		// Format: <limit>
+		// Format: limits <limit>
+		headerLimit = args[0]
 		argList = []pathLimitUnparsed{{
 			Path:  "/",
 			Limit: args[0],
 		}}
-	case 2:
-		// Format: <path> <limit>
-		argList = []pathLimitUnparsed{{
-			Path:  args[0],
-			Limit: args[1],
-		}}
 	default:
-		return c.ArgErr()
+		return nil, c.ArgErr()
 	}
 
-	pathLimit, err := parseArguments(argList)
-	if err != nil {
-		return c.ArgErr()
+	if headerLimit != "" {
+		size := parseSize(headerLimit)
+		if size < 1 { // also disallow size = 0
+			return nil, c.ArgErr()
+		}
+		config.Limits.MaxRequestHeaderSize = size
 	}
 
-	SortPathLimits(pathLimit)
+	if len(argList) > 0 {
+		pathLimit, err := parseArguments(argList)
+		if err != nil {
+			return nil, c.ArgErr()
+		}
+		SortPathLimits(pathLimit)
+		config.Limits.MaxRequestBodySizes = pathLimit
+	}
 
-	config.MaxRequestBodySizes = pathLimit
-
-	return nil
+	return config.Limits.MaxRequestBodySizes, nil
 }
 
 func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
diff --git a/caddyhttp/maxrequestbody/maxrequestbody_test.go b/caddyhttp/limits/setup_test.go
similarity index 64%
rename from caddyhttp/maxrequestbody/maxrequestbody_test.go
rename to caddyhttp/limits/setup_test.go
index 7d2c31e78..08a36f901 100644
--- a/caddyhttp/maxrequestbody/maxrequestbody_test.go
+++ b/caddyhttp/limits/setup_test.go
@@ -1,4 +1,4 @@
-package maxrequestbody
+package limits
 
 import (
 	"reflect"
@@ -14,32 +14,98 @@ const (
 	GB = 1024 * 1024 * 1024
 )
 
-func TestSetupMaxRequestBody(t *testing.T) {
-	cases := []struct {
-		input    string
-		hasError bool
+func TestParseLimits(t *testing.T) {
+	for name, c := range map[string]struct {
+		input     string
+		shouldErr bool
+		expect    httpserver.Limits
 	}{
-		// Format: { <path> <limit> ... }
-		{input: "maxrequestbody / 20MB", hasError: false},
-		// Format: <limit>
-		{input: "maxrequestbody 999KB", hasError: false},
-		// Format: { <path> <limit> ... }
-		{input: "maxrequestbody { /images 50MB /upload 10MB\n/test 10KB }", hasError: false},
-
-		// Wrong formats
-		{input: "maxrequestbody typo { /images 50MB }", hasError: true},
-		{input: "maxrequestbody 999MB /home 20KB", hasError: true},
-	}
-	for caseNum, c := range cases {
-		controller := caddy.NewTestController("", c.input)
-		err := setupMaxRequestBody(controller)
-
-		if c.hasError && (err == nil) {
-			t.Errorf("Expecting error for case %v but none encountered", caseNum)
-		}
-		if !c.hasError && (err != nil) {
-			t.Errorf("Expecting no error for case %v but encountered %v", caseNum, err)
-		}
+		"catchAll": {
+			input: `limits 2kb`,
+			expect: httpserver.Limits{
+				MaxRequestHeaderSize: 2 * KB,
+				MaxRequestBodySizes:  []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
+			},
+		},
+		"onlyHeader": {
+			input: `limits {
+				header 2kb
+			}`,
+			expect: httpserver.Limits{
+				MaxRequestHeaderSize: 2 * KB,
+			},
+		},
+		"onlyBody": {
+			input: `limits {
+				body 2kb
+			}`,
+			expect: httpserver.Limits{
+				MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
+			},
+		},
+		"onlyBodyWithPath": {
+			input: `limits {
+				body /test 2kb
+			}`,
+			expect: httpserver.Limits{
+				MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/test", Limit: 2 * KB}},
+			},
+		},
+		"mixture": {
+			input: `limits {
+				header 1kb
+				body 2kb
+				body /bar 3kb
+			}`,
+			expect: httpserver.Limits{
+				MaxRequestHeaderSize: 1 * KB,
+				MaxRequestBodySizes: []httpserver.PathLimit{
+					{Path: "/bar", Limit: 3 * KB},
+					{Path: "/", Limit: 2 * KB},
+				},
+			},
+		},
+		"invalidFormat": {
+			input:     `limits a b`,
+			shouldErr: true,
+		},
+		"invalidHeaderFormat": {
+			input: `limits {
+				header / 100
+			}`,
+			shouldErr: true,
+		},
+		"invalidBodyFormat": {
+			input: `limits {
+				body / 100 200
+			}`,
+			shouldErr: true,
+		},
+		"invalidKind": {
+			input: `limits {
+				head 100
+			}`,
+			shouldErr: true,
+		},
+		"invalidLimitSize": {
+			input:     `limits 10bk`,
+			shouldErr: true,
+		},
+	} {
+		c := c
+		t.Run(name, func(t *testing.T) {
+			controller := caddy.NewTestController("", c.input)
+			_, err := parseLimits(controller)
+			if c.shouldErr && err == nil {
+				t.Error("failed to get expected error")
+			}
+			if !c.shouldErr && err != nil {
+				t.Errorf("got unexpected error: %v", err)
+			}
+			if got := httpserver.GetConfig(controller).Limits; !reflect.DeepEqual(got, c.expect) {
+				t.Errorf("expect %#v, but got %#v", c.expect, got)
+			}
+		})
 	}
 }
 
diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go
index 044a3127c..27293deb1 100644
--- a/caddyhttp/proxy/proxy.go
+++ b/caddyhttp/proxy/proxy.go
@@ -228,7 +228,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
 			return 0, nil
 		}
 
-		if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok {
+		if backendErr == httpserver.MaxBytesExceededErr {
 			return http.StatusRequestEntityTooLarge, backendErr
 		}