caddy/caddyhttp/proxy/proxy_test.go
Angel Santiago 59bf71c293 proxy: Cleanly shutdown health checks on restart (#1524)
* Add a shutdown function and context to staticUpstream so that running goroutines can be cancelled. Add a GetShutdownFunc to Upstream interface to expose the shutdown function to the caddy Controller for performing it on restarts.

* Make fakeUpstream implement new Upstream methods.

Implement new Upstream method for fakeWSUpstream as well.

* Rename GetShutdownFunc to Stop(). Add a waitgroup to the staticUpstream for controlling individual object's goroutines. Add the Stop function to OnRestart and OnShutdown. Add tests for checking to see if healthchecks continue hitting a backend server after stop has been called.

* Go back to using a stop channel since the context adds no additional benefit.
Only register stop function for onShutdown since it's called as part of restart.

* Remove assignment to atomic value

* Incrementing WaitGroup outside of goroutine to avoid race condition. Loading atomic values in test.

* Linting: change counter to just use the default zero value instead of setting it

* Clarify Stop method comments, add comments to stop channel and waitgroup and remove out of date comment about handling stopping the proxy. Stop the ticker when the stop signal is sent
2017-04-02 14:58:15 -06:00

1331 lines
36 KiB
Go

package proxy
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/mholt/caddy/caddyfile"
"github.com/mholt/caddy/caddyhttp/httpserver"
"golang.org/x/net/websocket"
)
// This is a simple wrapper around httptest.NewTLSServer()
// which forcefully enables (among others) HTTP/2 support.
// The httptest package only supports HTTP/1.1 by default.
func newTLSServer(handler http.Handler) *httptest.Server {
ts := httptest.NewUnstartedServer(handler)
ts.TLS = new(tls.Config)
ts.TLS.NextProtos = []string{"h2"}
ts.StartTLS()
return ts
}
func TestReverseProxy(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
verifyHeaders := func(headers http.Header, trailers http.Header) {
if headers.Get("X-Header") != "header-value" {
t.Error("Expected header 'X-Header' to be proxied properly")
}
if trailers == nil {
t.Error("Expected to receive trailers")
}
if trailers.Get("X-Trailer") != "trailer-value" {
t.Error("Expected header 'X-Trailer' to be proxied properly")
}
}
var requestReceived bool
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// read the body (even if it's empty) to make Go parse trailers
io.Copy(ioutil.Discard, r.Body)
verifyHeaders(r.Header, r.Trailer)
requestReceived = true
w.Header().Set("Trailer", "X-Trailer")
w.Header().Set("X-Header", "header-value")
w.WriteHeader(http.StatusOK)
w.Write([]byte("Hello, client"))
w.Header().Set("X-Trailer", "trailer-value")
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", strings.NewReader("test"))
w := httptest.NewRecorder()
r.ContentLength = -1 // force chunked encoding (required for trailers)
r.Header.Set("X-Header", "header-value")
r.Trailer = map[string][]string{
"X-Trailer": {"trailer-value"},
}
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Expected backend to receive request, but it didn't")
}
res := w.Result()
verifyHeaders(res.Header, res.Trailer)
// Make sure {upstream} placeholder is set
r.Body = ioutil.NopCloser(strings.NewReader("test"))
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
p.ServeHTTP(rr, r)
if got, want := rr.Replacer.Replace("{upstream}"), backend.URL; got != want {
t.Errorf("Expected custom placeholder {upstream} to be set (%s), but it wasn't; got: %s", want, got)
}
}
func TestReverseProxyInsecureSkipVerify(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var requestReceived bool
var requestWasHTTP2 bool
backend := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestReceived = true
requestWasHTTP2 = r.ProtoAtLeast(2, 0)
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if !requestReceived {
t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
}
if !requestWasHTTP2 {
t.Error("Even with insecure HTTPS, expected proxy to use HTTP/2")
}
}
// This test will fail when using the race detector without atomic reads &
// writes of UpstreamHost.Conns and UpstreamHost.Unhealthy.
func TestReverseProxyMaxConnLimit(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
const MaxTestConns = 2
connReceived := make(chan bool, MaxTestConns)
connContinue := make(chan bool)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
connReceived <- true
<-connContinue
}))
defer backend.Close()
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
proxy / `+backend.URL+` {
max_conns `+fmt.Sprint(MaxTestConns)+`
}
`)))
if err != nil {
t.Fatal(err)
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: su,
}
var jobs sync.WaitGroup
for i := 0; i < MaxTestConns; i++ {
jobs.Add(1)
go func(i int) {
defer jobs.Done()
w := httptest.NewRecorder()
code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
if err != nil {
t.Errorf("Request %d failed: %v", i, err)
} else if code != 0 {
t.Errorf("Bad return code for request %d: %d", i, code)
} else if w.Code != 200 {
t.Errorf("Bad statuc code for request %d: %d", i, w.Code)
}
}(i)
}
// Wait for all the requests to hit the backend.
for i := 0; i < MaxTestConns; i++ {
<-connReceived
}
// Now we should have MaxTestConns requests connected and sitting on the backend
// server. Verify that the next request is rejected.
w := httptest.NewRecorder()
code, err := p.ServeHTTP(w, httptest.NewRequest("GET", "/", nil))
if code != http.StatusBadGateway {
t.Errorf("Expected request to be rejected, but got: %d [%v]\nStatus code: %d",
code, err, w.Code)
}
// Now let all the requests complete and verify the status codes for those:
close(connContinue)
// Wait for the initial requests to finish and check their results.
jobs.Wait()
}
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
// Capture the expected panic
defer func() {
r := recover()
if _, ok := r.(httpserver.NonHijackerError); !ok {
t.Error("not get the expected panic")
}
}()
var connCount int32
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) }))
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
r.Header = http.Header{
"Connection": {"Upgrade"},
"Upgrade": {"websocket"},
"Origin": {wsNop.URL},
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
"Sec-WebSocket-Version": {"13"},
}
nonHijacker := httptest.NewRecorder()
p.ServeHTTP(nonHijacker, r)
}
func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
// No-op websocket backend simply allows the WS connection to be
// accepted then it will be immediately closed. Perfect for testing.
var connCount int32
wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) }))
defer wsNop.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsNop.URL, false)
// Create client request
r := httptest.NewRequest("GET", "/", nil)
r.Header = http.Header{
"Connection": {"Upgrade"},
"Upgrade": {"websocket"},
"Origin": {wsNop.URL},
"Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="},
"Sec-WebSocket-Version": {"13"},
}
// Capture the request
w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
// Booya! Do the test.
p.ServeHTTP(w, r)
// Make sure the backend accepted the WS connection.
// Mostly interested in the Upgrade and Connection response headers
// and the 101 status code.
expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
actual := w.fakeConn.writeBuf.Bytes()
if !bytes.Equal(actual, expected) {
t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
}
if got, want := atomic.LoadInt32(&connCount), int32(1); got != want {
t.Errorf("Expected %d websocket connection, got %d", want, got)
}
}
func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
// Echo server allows us to test that socket bytes are properly
// being proxied.
wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
io.Copy(ws, ws)
}))
defer wsEcho.Close()
// Get proxy to use for the test
p := newWebSocketTestProxy(wsEcho.URL, false)
// This is a full end-end test, so the proxy handler
// has to be part of a server listening on a port. Our
// WS client will connect to this test server, not
// the echo client directly.
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
// Set up WebSocket client
url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
ws, err := websocket.Dial(url, "", echoProxy.URL)
if err != nil {
t.Fatal(err)
}
defer ws.Close()
// Send test message
trialMsg := "Is it working?"
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
t.Fatal(sendErr)
}
// It should be echoed back to us
var actualMsg string
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
t.Fatal(rcvErr)
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
io.Copy(ws, ws)
}))
defer wsEcho.Close()
p := newWebSocketTestProxy(wsEcho.URL, true)
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
// Set up WebSocket client
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
if err != nil {
t.Fatal(err)
}
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
ws, err := websocket.DialConfig(wsCfg)
if err != nil {
t.Fatal(err)
}
defer ws.Close()
// Send test message
trialMsg := "Is it working?"
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
t.Fatal(sendErr)
}
// It should be echoed back to us
var actualMsg string
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
t.Fatal(rcvErr)
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func TestUnixSocketProxy(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
trialMsg := "Is it working?"
var proxySuccess bool
// This is our fake "application" we want to proxy to
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Request was proxied when this is called
proxySuccess = true
fmt.Fprint(w, trialMsg)
}))
// Get absolute path for unix: socket
dir, err := ioutil.TempDir("", "caddy_proxytest")
if err != nil {
t.Fatalf("Failed to make temp dir to contain unix socket. %v", err)
}
defer os.RemoveAll(dir)
socketPath := filepath.Join(dir, "test_socket")
// Change httptest.Server listener to listen to unix: socket
ln, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatalf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
defer ts.Close()
url := strings.Replace(ts.URL, "http://", "unix:", 1)
p := newWebSocketTestProxy(url, false)
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL)
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
t.Fatalf("Unable to GET: %v", err)
}
actualMsg := fmt.Sprintf("%s", greeting)
if !proxySuccess {
t.Errorf("Expected request to be proxied, but it wasn't")
}
if actualMsg != trialMsg {
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
}
}
func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, messageFormat, r.URL.String())
}))
return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts
}
func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, string, error) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, messageFormat, r.URL.String())
}))
dir, err := ioutil.TempDir("", "caddy_proxytest")
if err != nil {
return nil, nil, dir, fmt.Errorf("Failed to make temp dir to contain unix socket. %v", err)
}
socketPath := filepath.Join(dir, "test_socket")
ln, err := net.Listen("unix", socketPath)
if err != nil {
os.RemoveAll(dir)
return nil, nil, dir, fmt.Errorf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
tsURL := strings.Replace(ts.URL, "http://", "unix:", 1)
return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, dir, nil
}
func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) {
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
// *httptest.Server is passed so it can be `defer`red properly
defer ts.Close()
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL + path)
if err != nil {
return "", fmt.Errorf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return "", fmt.Errorf("Unable to read body: %v", err)
}
return fmt.Sprintf("%s", greeting), nil
}
func TestUnixSocketProxyPaths(t *testing.T) {
greeting := "Hello route %s"
tests := []struct {
url string
prefix string
expected string
}{
{"", "", fmt.Sprintf(greeting, "/")},
{"/hello", "", fmt.Sprintf(greeting, "/hello")},
{"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")},
{"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")},
{"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")},
{"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")},
{"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")},
{"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")},
{"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")},
{"/queues/%2F/fetchtasks", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks")},
{"/queues/%2F/fetchtasks?foo=bar", "", fmt.Sprintf(greeting, "/queues/%2F/fetchtasks?foo=bar")},
}
for _, test := range tests {
p, ts := GetHTTPProxy(greeting, test.prefix)
actualMsg, err := GetTestServerMessage(p, ts, test.url)
if err != nil {
t.Fatalf("Getting server message failed - %v", err)
}
if actualMsg != test.expected {
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
}
}
if runtime.GOOS == "windows" {
return
}
for _, test := range tests {
p, ts, tmpdir, err := GetSocketProxy(greeting, test.prefix)
if err != nil {
t.Fatalf("Getting socket proxy failed - %v", err)
}
actualMsg, err := GetTestServerMessage(p, ts, test.url)
if err != nil {
os.RemoveAll(tmpdir)
t.Fatalf("Getting server message failed - %v", err)
}
if actualMsg != test.expected {
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
}
os.RemoveAll(tmpdir)
}
}
func TestUpstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
var actualHeaders http.Header
var actualHost string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
actualHeaders = r.Header
actualHost = r.Host
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.UpstreamHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"+Add-Empty": {"{}"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
"Clear-Me": {""},
"Host": {"{>Host}"},
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
const expectHost = "example.com"
//add initial headers
r.Header.Add("Merge-Me", "Initial")
r.Header.Add("Remove-Me", "Remove-Value")
r.Header.Add("Replace-Me", "Replace-Value")
r.Header.Add("Host", expectHost)
p.ServeHTTP(w, r)
replacer := httpserver.NewReplacer(r, nil, "")
for headerKey, expect := range map[string][]string{
"Merge-Me": {"Initial", "Merge-Value"},
"Add-Me": {"Add-Value"},
"Add-Empty": nil,
"Remove-Me": nil,
"Replace-Me": {replacer.Replace("{hostname}")},
"Clear-Me": nil,
} {
if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) {
t.Errorf("Upstream request does not contain expected %v header: expect %v, but got %v",
headerKey, expect, got)
}
}
if actualHost != expectHost {
t.Errorf("Request sent to upstream backend should have value of Host with %s, but got %s", expectHost, actualHost)
}
}
func TestDownstreamHeadersUpdate(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Merge-Me", "Initial")
w.Header().Add("Remove-Me", "Remove-Value")
w.Header().Add("Replace-Me", "Replace-Value")
w.Header().Add("Content-Type", "text/html")
w.Header().Add("Overwrite-Me", "Overwrite-Value")
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.DownstreamHeaders = http.Header{
"+Merge-Me": {"Merge-Value"},
"+Add-Me": {"Add-Value"},
"-Remove-Me": {""},
"Replace-Me": {"{hostname}"},
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{upstream},
}
// create request and response recorder
r := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
// set a predefined skip header
w.Header().Set("Content-Type", "text/css")
// set a predefined overwritten header
w.Header().Set("Overwrite-Me", "Initial")
p.ServeHTTP(w, r)
replacer := httpserver.NewReplacer(r, nil, "")
actualHeaders := w.Header()
for headerKey, expect := range map[string][]string{
"Merge-Me": {"Initial", "Merge-Value"},
"Add-Me": {"Add-Value"},
"Remove-Me": nil,
"Replace-Me": {replacer.Replace("{hostname}")},
"Content-Type": {"text/css"},
"Overwrite-Me": {"Overwrite-Value"},
} {
if got := actualHeaders[headerKey]; !reflect.DeepEqual(got, expect) {
t.Errorf("Downstream response does not contain expected %s header: expect %v, but got %v",
headerKey, expect, got)
}
}
}
var (
upstreamResp1 = []byte("Hello, /")
upstreamResp2 = []byte("Hello, /api/")
)
func newMultiHostTestProxy() *Proxy {
// No-op backends.
upstreamServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", upstreamResp1)
}))
upstreamServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", upstreamResp2)
}))
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{
// The order is important; the short path should go first to ensure
// we choose the most specific route, not the first one.
&fakeUpstream{
name: upstreamServer1.URL,
from: "/",
},
&fakeUpstream{
name: upstreamServer2.URL,
from: "/api",
},
},
}
return p
}
func TestMultiReverseProxyFromClient(t *testing.T) {
p := newMultiHostTestProxy()
// This is a full end-end test, so the proxy handler.
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer proxy.Close()
// Table tests.
var multiProxy = []struct {
url string
body []byte
}{
{
"/",
upstreamResp1,
},
{
"/api/",
upstreamResp2,
},
{
"/messages/",
upstreamResp1,
},
{
"/api/messages/?text=cat",
upstreamResp2,
},
}
for _, tt := range multiProxy {
// Create client request
reqURL := proxy.URL + tt.url
req, err := http.NewRequest("GET", reqURL, nil)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Fatalf("Failed to read response: %v", err)
}
if !bytes.Equal(body, tt.body) {
t.Errorf("Expected '%s' but got '%s' instead", tt.body, body)
}
}
}
func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
var requestHost string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestHost = r.Host
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
}
r := httptest.NewRequest("GET", "/", nil)
r.Host = "test.com"
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if !strings.Contains(backend.URL, "//") {
t.Fatalf("The URL of the backend server doesn't contains //: %s", backend.URL)
}
expectedHost := strings.Split(backend.URL, "//")
if expectedHost[1] != requestHost {
t.Fatalf("Expected %s as a Host header got %s\n", expectedHost[1], requestHost)
}
}
func TestHostHeaderReplacedUsingForward(t *testing.T) {
var requestHost string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestHost = r.Host
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
proxyHostHeader := "test2.com"
upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{upstream},
}
r := httptest.NewRequest("GET", "/", nil)
r.Host = "test.com"
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if proxyHostHeader != requestHost {
t.Fatalf("Expected %s as a Host header got %s\n", proxyHostHeader, requestHost)
}
}
func TestBasicAuth(t *testing.T) {
basicAuthTestcase(t, nil, nil)
basicAuthTestcase(t, nil, url.UserPassword("username", "password"))
basicAuthTestcase(t, url.UserPassword("usename", "password"), nil)
basicAuthTestcase(t, url.UserPassword("unused", "unused"),
url.UserPassword("username", "password"))
}
func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
u, p, ok := r.BasicAuth()
if ok {
w.Write([]byte(u))
}
if ok && p != "" {
w.Write([]byte(":"))
w.Write([]byte(p))
}
}))
defer backend.Close()
backURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("Failed to parse URL: %v", err)
}
backURL.User = upstreamUser
p := &Proxy{
Next: httpserver.EmptyNext,
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)},
}
r, err := http.NewRequest("GET", "/foo", nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
if clientUser != nil {
u := clientUser.Username()
p, _ := clientUser.Password()
r.SetBasicAuth(u, p)
}
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
if w.Code != 200 {
t.Fatalf("Invalid response code: %d", w.Code)
}
body, _ := ioutil.ReadAll(w.Body)
if clientUser != nil {
if string(body) != clientUser.String() {
t.Fatalf("Invalid auth info: %s", string(body))
}
} else {
if upstreamUser != nil {
if string(body) != upstreamUser.String() {
t.Fatalf("Invalid auth info: %s", string(body))
}
} else {
if string(body) != "" {
t.Fatalf("Invalid auth info: %s", string(body))
}
}
}
}
func TestProxyDirectorURL(t *testing.T) {
for i, c := range []struct {
requestURL string
targetURL string
without string
expectURL string
}{
{
requestURL: `http://localhost:2020/test`,
targetURL: `https://localhost:2021`,
expectURL: `https://localhost:2021/test`,
},
{
requestURL: `http://localhost:2020/test`,
targetURL: `https://localhost:2021/t`,
expectURL: `https://localhost:2021/t/test`,
},
{
requestURL: `http://localhost:2020/test?t=w`,
targetURL: `https://localhost:2021/t`,
expectURL: `https://localhost:2021/t/test?t=w`,
},
{
requestURL: `http://localhost:2020/test`,
targetURL: `https://localhost:2021/t?foo=bar`,
expectURL: `https://localhost:2021/t/test?foo=bar`,
},
{
requestURL: `http://localhost:2020/test?t=w`,
targetURL: `https://localhost:2021/t?foo=bar`,
expectURL: `https://localhost:2021/t/test?foo=bar&t=w`,
},
{
requestURL: `http://localhost:2020/test?t=w`,
targetURL: `https://localhost:2021/t?foo=bar`,
expectURL: `https://localhost:2021/t?foo=bar&t=w`,
without: "/test",
},
{
requestURL: `http://localhost:2020/test?t%3dw`,
targetURL: `https://localhost:2021/t?foo%3dbar`,
expectURL: `https://localhost:2021/t?foo%3dbar&t%3dw`,
without: "/test",
},
{
requestURL: `http://localhost:2020/test/`,
targetURL: `https://localhost:2021/t/`,
expectURL: `https://localhost:2021/t/test/`,
},
{
requestURL: `http://localhost:2020/test/mypath`,
targetURL: `https://localhost:2021/t/`,
expectURL: `https://localhost:2021/t/mypath`,
without: "/test",
},
{
requestURL: `http://localhost:2020/%2C`,
targetURL: `https://localhost:2021/t/`,
expectURL: `https://localhost:2021/t/%2C`,
},
{
requestURL: `http://localhost:2020/%2C/`,
targetURL: `https://localhost:2021/t/`,
expectURL: `https://localhost:2021/t/%2C/`,
},
{
requestURL: `http://localhost:2020/test`,
targetURL: `https://localhost:2021/%2C`,
expectURL: `https://localhost:2021/%2C/test`,
},
{
requestURL: `http://localhost:2020/%2C`,
targetURL: `https://localhost:2021/%2C`,
expectURL: `https://localhost:2021/%2C/%2C`,
},
} {
targetURL, err := url.Parse(c.targetURL)
if err != nil {
t.Errorf("case %d failed to parse target URL: %s", i, err)
continue
}
req, err := http.NewRequest("GET", c.requestURL, nil)
if err != nil {
t.Errorf("case %d failed to create request: %s", i, err)
continue
}
NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req)
if expect, got := c.expectURL, req.URL.String(); expect != got {
t.Errorf("case %d url not equal: expect %q, but got %q",
i, expect, got)
}
}
}
func TestReverseProxyRetry(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
// set up proxy
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(w, r.Body)
r.Body.Close()
}))
defer backend.Close()
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`
proxy / localhost:65535 localhost:65534 `+backend.URL+` {
policy round_robin
fail_timeout 5s
max_fails 1
try_duration 5s
try_interval 250ms
}
`)))
if err != nil {
t.Fatal(err)
}
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: su,
}
// middle is required to simulate closable downstream request body
middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err = p.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}))
defer middle.Close()
testcase := "test content"
r, err := http.NewRequest("POST", middle.URL, bytes.NewBufferString(testcase))
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultTransport.RoundTrip(r)
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if string(b) != testcase {
t.Fatalf("string(b) = %s, want %s", string(b), testcase)
}
}
func TestReverseProxyLargeBody(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stderr)
// set up proxy
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
}))
defer backend.Close()
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`proxy / `+backend.URL)))
if err != nil {
t.Fatal(err)
}
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: su,
}
// middle is required to simulate closable downstream request body
middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err = p.ServeHTTP(w, r)
if err != nil {
t.Error(err)
}
}))
defer middle.Close()
// Our request body will be 100MB
bodySize := uint64(100 * 1000 * 1000)
// We want to see how much memory the proxy module requires for this request.
// So lets record the mem stats before we start it.
begMemstats := &runtime.MemStats{}
runtime.ReadMemStats(begMemstats)
r, err := http.NewRequest("POST", middle.URL, &noopReader{len: bodySize})
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultTransport.RoundTrip(r)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
// Finally we need the mem stats after the request is done...
endMemstats := &runtime.MemStats{}
runtime.ReadMemStats(endMemstats)
// ...to calculate the total amount of allocated memory during the request.
totalAlloc := endMemstats.TotalAlloc - begMemstats.TotalAlloc
// If that's as much as the size of the body itself it's a serious sign that the
// request was not "streamed" to the upstream without buffering it first.
if totalAlloc >= bodySize {
t.Fatalf("proxy allocated too much memory: %d bytes", totalAlloc)
}
}
func TestCancelRequest(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
}
// setup request with cancel ctx
req := httptest.NewRequest("GET", "/", nil)
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
req = req.WithContext(ctx)
// add GotConn hook to cancel the request
gotC := make(chan struct{})
defer close(gotC)
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
gotC <- struct{}{}
},
}
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
// wait for canceling the request
go func() {
<-gotC
cancel()
}()
status, err := p.ServeHTTP(httptest.NewRecorder(), req)
if status != 0 || err != nil {
t.Errorf("expect proxy handle normally, but not, status:%d, err:%q",
status, err)
}
}
type noopReader struct {
len uint64
pos uint64
}
var _ io.Reader = &noopReader{}
func (r *noopReader) Read(b []byte) (int, error) {
if r.pos >= r.len {
return 0, io.EOF
}
n := int(r.len - r.pos)
if n > len(b) {
n = len(b)
}
for i := range b[:n] {
b[i] = 0
}
r.pos += uint64(n)
return n, nil
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name)
u := &fakeUpstream{
name: name,
from: "/",
host: &UpstreamHost{
Name: name,
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost),
},
}
if insecure {
u.host.ReverseProxy.UseInsecureTransport()
}
return u
}
type fakeUpstream struct {
name string
host *UpstreamHost
from string
without string
}
func (u *fakeUpstream) From() string {
return u.from
}
func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
if u.host == nil {
uri, err := url.Parse(u.name)
if err != nil {
log.Fatalf("Unable to url.Parse %s: %v", u.name, err)
}
u.host = &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
}
}
return u.host
}
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeUpstream) GetHostCount() int { return 1 }
func (u *fakeUpstream) Stop() error { return nil }
// newWebSocketTestProxy returns a test proxy that will
// redirect to the specified backendAddr. The function
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{
name: backendAddr,
without: "",
insecure: insecure,
}},
}
}
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
return &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
}
}
type fakeWsUpstream struct {
name string
without string
insecure bool
}
func (u *fakeWsUpstream) From() string {
return "/"
}
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
uri, _ := url.Parse(u.name)
host := &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
UpstreamHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}},
}
if u.insecure {
host.ReverseProxy.UseInsecureTransport()
}
return host
}
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
func (u *fakeWsUpstream) GetHostCount() int { return 1 }
func (u *fakeWsUpstream) Stop() error { return nil }
// recorderHijacker is a ResponseRecorder that can
// be hijacked.
type recorderHijacker struct {
*httptest.ResponseRecorder
fakeConn *fakeConn
}
func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return rh.fakeConn, nil, nil
}
type fakeConn struct {
readBuf bytes.Buffer
writeBuf bytes.Buffer
}
func (c *fakeConn) LocalAddr() net.Addr { return nil }
func (c *fakeConn) RemoteAddr() net.Addr { return nil }
func (c *fakeConn) SetDeadline(t time.Time) error { return nil }
func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil }
func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
func BenchmarkProxy(b *testing.B) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hello, client"))
}))
defer backend.Close()
upstream := newFakeUpstream(backend.URL, false)
upstream.host.UpstreamHeaders = http.Header{
"Hostname": {"{hostname}"},
"Host": {"{host}"},
"X-Real-IP": {"{remote}"},
"X-Forwarded-Proto": {"{scheme}"},
}
// set up proxy
p := &Proxy{
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
Upstreams: []Upstream{upstream},
}
w := httptest.NewRecorder()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
// create request and response recorder
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
b.Fatalf("Failed to create request: %v", err)
}
b.StartTimer()
p.ServeHTTP(w, r)
}
}