From dd4c4d7eb649d4d0ae1b381a501f431c684ab27c Mon Sep 17 00:00:00 2001 From: Benny Ng Date: Tue, 1 Nov 2016 12:34:39 +0800 Subject: [PATCH] proxy: record request Body for retry (fixes #1229) --- caddyhttp/proxy/body.go | 40 ++++++++++++++++++++ caddyhttp/proxy/body_test.go | 69 +++++++++++++++++++++++++++++++++++ caddyhttp/proxy/proxy.go | 14 +++++++ caddyhttp/proxy/proxy_test.go | 58 +++++++++++++++++++++++++++++ 4 files changed, 181 insertions(+) create mode 100644 caddyhttp/proxy/body.go create mode 100644 caddyhttp/proxy/body_test.go diff --git a/caddyhttp/proxy/body.go b/caddyhttp/proxy/body.go new file mode 100644 index 00000000..38d00165 --- /dev/null +++ b/caddyhttp/proxy/body.go @@ -0,0 +1,40 @@ +package proxy + +import ( + "bytes" + "io" + "io/ioutil" +) + +type bufferedBody struct { + *bytes.Reader +} + +func (*bufferedBody) Close() error { + return nil +} + +// rewind allows bufferedBody to be read again. +func (b *bufferedBody) rewind() error { + if b == nil { + return nil + } + _, err := b.Seek(0, io.SeekStart) + return err +} + +// newBufferedBody returns *bufferedBody to use in place of src. Closes src +// and returns Read error on src. All content from src is buffered. +func newBufferedBody(src io.ReadCloser) (*bufferedBody, error) { + if src == nil { + return nil, nil + } + b, err := ioutil.ReadAll(src) + src.Close() + if err != nil { + return nil, err + } + return &bufferedBody{ + Reader: bytes.NewReader(b), + }, nil +} diff --git a/caddyhttp/proxy/body_test.go b/caddyhttp/proxy/body_test.go new file mode 100644 index 00000000..5b72784c --- /dev/null +++ b/caddyhttp/proxy/body_test.go @@ -0,0 +1,69 @@ +package proxy + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +func TestBodyRetry(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, r.Body) + r.Body.Close() + })) + defer ts.Close() + + testcase := "test content" + req, err := http.NewRequest(http.MethodPost, ts.URL, bytes.NewBufferString(testcase)) + if err != nil { + t.Fatal(err) + } + + body, err := newBufferedBody(req.Body) + if err != nil { + t.Fatal(err) + } + if body != nil { + req.Body = body + } + + // simulate fail request + host := req.URL.Host + req.URL.Host = "example.com" + body.rewind() + _, _ = http.DefaultTransport.RoundTrip(req) + + // retry request + req.URL.Host = host + body.rewind() + resp, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + result, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if string(result) != testcase { + t.Fatalf("result = %s, want %s", result, testcase) + } + + // try one more time for body reuse + body.rewind() + resp, err = http.DefaultTransport.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + result, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if string(result) != testcase { + t.Fatalf("result = %s, want %s", result, testcase) + } +} diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index 71c7476b..11f2d5d0 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -94,6 +94,15 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // outreq is the request that makes a roundtrip to the backend outreq := createUpstreamRequest(r) + // record and replace outreq body + body, err := newBufferedBody(outreq.Body) + if err != nil { + return http.StatusBadRequest, errors.New("failed to read downstream request body") + } + if body != nil { + outreq.Body = body + } + // The keepRetrying function will return true if we should // loop and try to select another host, or false if we // should break and stop retrying. @@ -164,6 +173,11 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) } + // rewind request body to its beginning + if err := body.rewind(); err != nil { + return http.StatusInternalServerError, errors.New("unable to rewind downstream request body") + } + // tell the proxy to serve the request atomic.AddInt64(&host.Conns, 1) backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index af02a17c..290cae93 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/mholt/caddy/caddyfile" "github.com/mholt/caddy/caddyhttp/httpserver" "golang.org/x/net/websocket" @@ -836,6 +837,63 @@ func TestProxyDirectorURL(t *testing.T) { } } +func TestReverseProxyRetry(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + // set up proxy + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, r.Body) + r.Body.Close() + })) + defer backend.Close() + + su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(` + proxy / localhost:65535 localhost:65534 `+backend.URL+` { + policy round_robin + fail_timeout 5s + max_fails 1 + try_duration 5s + try_interval 250ms + } + `))) + if err != nil { + t.Fatal(err) + } + + p := &Proxy{ + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: su, + } + + // middle is required to simulate closable downstream request body + middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err = p.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + })) + defer middle.Close() + + testcase := "test content" + r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase)) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultTransport.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(b) != testcase { + t.Fatalf("string(b) = %s, want %s", string(b), testcase) + } +} + func newFakeUpstream(name string, insecure bool) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{