mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-07 19:38:49 +03:00
Merge pull request #1309 from lhecker/master
Fixed #1292 and resulting issues from #1300
This commit is contained in:
commit
7cbbb01f94
3 changed files with 348 additions and 83 deletions
|
@ -247,12 +247,28 @@ func createUpstreamRequest(r *http.Request) *http.Request {
|
||||||
outreq.URL.Opaque = outreq.URL.RawPath
|
outreq.URL.Opaque = outreq.URL.RawPath
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We are modifying the same underlying map from req (shallow
|
||||||
|
// copied above) so we only copy it if necessary.
|
||||||
|
copiedHeaders := false
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers listed in the "Connection" header.
|
||||||
|
// See RFC 2616, section 14.10.
|
||||||
|
if c := outreq.Header.Get("Connection"); c != "" {
|
||||||
|
for _, f := range strings.Split(c, ",") {
|
||||||
|
if f = strings.TrimSpace(f); f != "" {
|
||||||
|
if !copiedHeaders {
|
||||||
|
outreq.Header = make(http.Header)
|
||||||
|
copyHeader(outreq.Header, r.Header)
|
||||||
|
copiedHeaders = true
|
||||||
|
}
|
||||||
|
outreq.Header.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Remove hop-by-hop headers to the backend. Especially
|
// Remove hop-by-hop headers to the backend. Especially
|
||||||
// important is "Connection" because we want a persistent
|
// important is "Connection" because we want a persistent
|
||||||
// connection, regardless of what the client sent to us. This
|
// connection, regardless of what the client sent to us.
|
||||||
// is modifying the same underlying map from r (shallow
|
|
||||||
// copied above) so we only copy it if necessary.
|
|
||||||
var copiedHeaders bool
|
|
||||||
for _, h := range hopHeaders {
|
for _, h := range hopHeaders {
|
||||||
if outreq.Header.Get(h) != "" {
|
if outreq.Header.Get(h) != "" {
|
||||||
if !copiedHeaders {
|
if !copiedHeaders {
|
||||||
|
|
|
@ -42,10 +42,32 @@ func TestReverseProxy(t *testing.T) {
|
||||||
log.SetOutput(ioutil.Discard)
|
log.SetOutput(ioutil.Discard)
|
||||||
defer log.SetOutput(os.Stderr)
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
verifyHeaders := func(headers http.Header, trailers http.Header) {
|
||||||
|
if headers.Get("X-Header") != "header-value" {
|
||||||
|
t.Error("Expected header 'X-Header' to be proxied properly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if trailers == nil {
|
||||||
|
t.Error("Expected to receive trailers")
|
||||||
|
}
|
||||||
|
if trailers.Get("X-Trailer") != "trailer-value" {
|
||||||
|
t.Error("Expected header 'X-Trailer' to be proxied properly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var requestReceived bool
|
var requestReceived bool
|
||||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// read the body (even if it's empty) to make Go parse trailers
|
||||||
|
io.Copy(ioutil.Discard, r.Body)
|
||||||
|
verifyHeaders(r.Header, r.Trailer)
|
||||||
|
|
||||||
requestReceived = true
|
requestReceived = true
|
||||||
|
|
||||||
|
w.Header().Set("Trailer", "X-Trailer")
|
||||||
|
w.Header().Set("X-Header", "header-value")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte("Hello, client"))
|
w.Write([]byte("Hello, client"))
|
||||||
|
w.Header().Set("X-Trailer", "trailer-value")
|
||||||
}))
|
}))
|
||||||
defer backend.Close()
|
defer backend.Close()
|
||||||
|
|
||||||
|
@ -59,12 +81,21 @@ func TestReverseProxy(t *testing.T) {
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
r.ContentLength = -1 // force chunked encoding (required for trailers)
|
||||||
|
r.Header.Set("X-Header", "header-value")
|
||||||
|
r.Trailer = map[string][]string{
|
||||||
|
"X-Trailer": {"trailer-value"},
|
||||||
|
}
|
||||||
|
|
||||||
p.ServeHTTP(w, r)
|
p.ServeHTTP(w, r)
|
||||||
|
|
||||||
if !requestReceived {
|
if !requestReceived {
|
||||||
t.Error("Expected backend to receive request, but it didn't")
|
t.Error("Expected backend to receive request, but it didn't")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
res := w.Result()
|
||||||
|
verifyHeaders(res.Header, res.Trailer)
|
||||||
|
|
||||||
// Make sure {upstream} placeholder is set
|
// Make sure {upstream} placeholder is set
|
||||||
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
rr := httpserver.NewResponseRecorder(httptest.NewRecorder())
|
||||||
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
rr.Replacer = httpserver.NewReplacer(r, rr, "-")
|
||||||
|
@ -123,7 +154,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
||||||
defer wsNop.Close()
|
defer wsNop.Close()
|
||||||
|
|
||||||
// Get proxy to use for the test
|
// Get proxy to use for the test
|
||||||
p := newWebSocketTestProxy(wsNop.URL)
|
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||||
|
|
||||||
// Create client request
|
// Create client request
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
@ -148,7 +179,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||||
defer wsNop.Close()
|
defer wsNop.Close()
|
||||||
|
|
||||||
// Get proxy to use for the test
|
// Get proxy to use for the test
|
||||||
p := newWebSocketTestProxy(wsNop.URL)
|
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||||
|
|
||||||
// Create client request
|
// Create client request
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
|
@ -189,7 +220,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||||
defer wsEcho.Close()
|
defer wsEcho.Close()
|
||||||
|
|
||||||
// Get proxy to use for the test
|
// Get proxy to use for the test
|
||||||
p := newWebSocketTestProxy(wsEcho.URL)
|
p := newWebSocketTestProxy(wsEcho.URL, false)
|
||||||
|
|
||||||
// This is a full end-end test, so the proxy handler
|
// This is a full end-end test, so the proxy handler
|
||||||
// has to be part of a server listening on a port. Our
|
// has to be part of a server listening on a port. Our
|
||||||
|
@ -228,6 +259,52 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
|
||||||
|
wsEcho := newTLSServer(websocket.Handler(func(ws *websocket.Conn) {
|
||||||
|
io.Copy(ws, ws)
|
||||||
|
}))
|
||||||
|
defer wsEcho.Close()
|
||||||
|
|
||||||
|
p := newWebSocketTestProxy(wsEcho.URL, true)
|
||||||
|
|
||||||
|
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
p.ServeHTTP(w, r)
|
||||||
|
}))
|
||||||
|
defer echoProxy.Close()
|
||||||
|
|
||||||
|
// Set up WebSocket client
|
||||||
|
url := strings.Replace(echoProxy.URL, "https://", "wss://", 1)
|
||||||
|
wsCfg, err := websocket.NewConfig(url, echoProxy.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
wsCfg.TlsConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
ws, err := websocket.DialConfig(wsCfg)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer ws.Close()
|
||||||
|
|
||||||
|
// Send test message
|
||||||
|
trialMsg := "Is it working?"
|
||||||
|
|
||||||
|
if sendErr := websocket.Message.Send(ws, trialMsg); sendErr != nil {
|
||||||
|
t.Fatal(sendErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// It should be echoed back to us
|
||||||
|
var actualMsg string
|
||||||
|
|
||||||
|
if rcvErr := websocket.Message.Receive(ws, &actualMsg); rcvErr != nil {
|
||||||
|
t.Fatal(rcvErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if actualMsg != trialMsg {
|
||||||
|
t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestUnixSocketProxy(t *testing.T) {
|
func TestUnixSocketProxy(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
return
|
return
|
||||||
|
@ -264,7 +341,7 @@ func TestUnixSocketProxy(t *testing.T) {
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
|
|
||||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||||
p := newWebSocketTestProxy(url)
|
p := newWebSocketTestProxy(url, false)
|
||||||
|
|
||||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
p.ServeHTTP(w, r)
|
p.ServeHTTP(w, r)
|
||||||
|
@ -982,10 +1059,14 @@ func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.
|
||||||
// redirect to the specified backendAddr. The function
|
// redirect to the specified backendAddr. The function
|
||||||
// also sets up the rules/environment for testing WebSocket
|
// also sets up the rules/environment for testing WebSocket
|
||||||
// proxy.
|
// proxy.
|
||||||
func newWebSocketTestProxy(backendAddr string) *Proxy {
|
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||||
return &Proxy{
|
return &Proxy{
|
||||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
|
Upstreams: []Upstream{&fakeWsUpstream{
|
||||||
|
name: backendAddr,
|
||||||
|
without: "",
|
||||||
|
insecure: insecure,
|
||||||
|
}},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -997,8 +1078,9 @@ func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
||||||
}
|
}
|
||||||
|
|
||||||
type fakeWsUpstream struct {
|
type fakeWsUpstream struct {
|
||||||
name string
|
name string
|
||||||
without string
|
without string
|
||||||
|
insecure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *fakeWsUpstream) From() string {
|
func (u *fakeWsUpstream) From() string {
|
||||||
|
@ -1007,13 +1089,17 @@ func (u *fakeWsUpstream) From() string {
|
||||||
|
|
||||||
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||||
uri, _ := url.Parse(u.name)
|
uri, _ := url.Parse(u.name)
|
||||||
return &UpstreamHost{
|
host := &UpstreamHost{
|
||||||
Name: u.name,
|
Name: u.name,
|
||||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||||
UpstreamHeaders: http.Header{
|
UpstreamHeaders: http.Header{
|
||||||
"Connection": {"{>Connection}"},
|
"Connection": {"{>Connection}"},
|
||||||
"Upgrade": {"{>Upgrade}"}},
|
"Upgrade": {"{>Upgrade}"}},
|
||||||
}
|
}
|
||||||
|
if u.insecure {
|
||||||
|
host.ReverseProxy.UseInsecureTransport()
|
||||||
|
}
|
||||||
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||||
|
|
|
@ -27,10 +27,28 @@ import (
|
||||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bufferPool = sync.Pool{New: createBuffer}
|
var (
|
||||||
|
defaultDialer = &net.Dialer{
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
KeepAlive: 30 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
bufferPool = sync.Pool{New: createBuffer}
|
||||||
|
)
|
||||||
|
|
||||||
func createBuffer() interface{} {
|
func createBuffer() interface{} {
|
||||||
return make([]byte, 32*1024)
|
return make([]byte, 0, 32*1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pooledIoCopy(dst io.Writer, src io.Reader) {
|
||||||
|
buf := bufferPool.Get().([]byte)
|
||||||
|
defer bufferPool.Put(buf)
|
||||||
|
|
||||||
|
// CopyBuffer only uses buf up to its length and panics if it's 0.
|
||||||
|
// Due to that we extend buf's length to its capacity here and
|
||||||
|
// ensure it's always non-zero.
|
||||||
|
bufCap := cap(buf)
|
||||||
|
io.CopyBuffer(dst, src, buf[0:bufCap:bufCap])
|
||||||
}
|
}
|
||||||
|
|
||||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||||
|
@ -135,11 +153,8 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
// just use default transport, to avoid creating
|
// just use default transport, to avoid creating
|
||||||
// a brand new transport
|
// a brand new transport
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: defaultDialer.Dial,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
@ -148,7 +163,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
} else {
|
} else {
|
||||||
transport.MaxIdleConnsPerHost = keepalive
|
transport.MaxIdleConnsPerHost = keepalive
|
||||||
}
|
}
|
||||||
http2.ConfigureTransport(transport)
|
if httpserver.HTTP2 {
|
||||||
|
http2.ConfigureTransport(transport)
|
||||||
|
}
|
||||||
rp.Transport = transport
|
rp.Transport = transport
|
||||||
}
|
}
|
||||||
return rp
|
return rp
|
||||||
|
@ -160,18 +177,20 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
func (rp *ReverseProxy) UseInsecureTransport() {
|
func (rp *ReverseProxy) UseInsecureTransport() {
|
||||||
if rp.Transport == nil {
|
if rp.Transport == nil {
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: (&net.Dialer{
|
Dial: defaultDialer.Dial,
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
}
|
}
|
||||||
http2.ConfigureTransport(transport)
|
if httpserver.HTTP2 {
|
||||||
|
http2.ConfigureTransport(transport)
|
||||||
|
}
|
||||||
rp.Transport = transport
|
rp.Transport = transport
|
||||||
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
||||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
// No http2.ConfigureTransport() here.
|
||||||
|
// For now this is only added in places where
|
||||||
|
// an http.Transport is actually created.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,20 +205,33 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||||
}
|
}
|
||||||
|
|
||||||
rp.Director(outreq)
|
rp.Director(outreq)
|
||||||
outreq.Proto = "HTTP/1.1"
|
|
||||||
outreq.ProtoMajor = 1
|
|
||||||
outreq.ProtoMinor = 1
|
|
||||||
outreq.Close = false
|
|
||||||
|
|
||||||
res, err := transport.RoundTrip(outreq)
|
res, err := transport.RoundTrip(outreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
isWebsocket := res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket"
|
||||||
|
|
||||||
|
// Remove hop-by-hop headers listed in the
|
||||||
|
// "Connection" header of the response.
|
||||||
|
if c := res.Header.Get("Connection"); c != "" {
|
||||||
|
for _, f := range strings.Split(c, ",") {
|
||||||
|
if f = strings.TrimSpace(f); f != "" {
|
||||||
|
res.Header.Del(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, h := range hopHeaders {
|
||||||
|
res.Header.Del(h)
|
||||||
|
}
|
||||||
|
|
||||||
if respUpdateFn != nil {
|
if respUpdateFn != nil {
|
||||||
respUpdateFn(res)
|
respUpdateFn(res)
|
||||||
}
|
}
|
||||||
if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" {
|
|
||||||
|
if isWebsocket {
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
hj, ok := rw.(http.Hijacker)
|
hj, ok := rw.(http.Hijacker)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -228,27 +260,39 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||||
}
|
}
|
||||||
defer backendConn.Close()
|
defer backendConn.Close()
|
||||||
|
|
||||||
go func() {
|
go pooledIoCopy(backendConn, conn) // write tcp stream to backend
|
||||||
io.Copy(backendConn, conn) // write tcp stream to backend.
|
pooledIoCopy(conn, backendConn) // read tcp stream from backend
|
||||||
}()
|
|
||||||
io.Copy(conn, backendConn) // read tcp stream from backend.
|
|
||||||
} else {
|
} else {
|
||||||
defer res.Body.Close()
|
|
||||||
for _, h := range hopHeaders {
|
|
||||||
res.Header.Del(h)
|
|
||||||
}
|
|
||||||
copyHeader(rw.Header(), res.Header)
|
copyHeader(rw.Header(), res.Header)
|
||||||
|
|
||||||
|
// The "Trailer" header isn't included in the Transport's response,
|
||||||
|
// at least for *http.Transport. Build it up from Trailer.
|
||||||
|
if len(res.Trailer) > 0 {
|
||||||
|
trailerKeys := make([]string, 0, len(res.Trailer))
|
||||||
|
for k := range res.Trailer {
|
||||||
|
trailerKeys = append(trailerKeys, k)
|
||||||
|
}
|
||||||
|
rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
rw.WriteHeader(res.StatusCode)
|
rw.WriteHeader(res.StatusCode)
|
||||||
|
if len(res.Trailer) > 0 {
|
||||||
|
// Force chunking if we saw a response trailer.
|
||||||
|
// This prevents net/http from calculating the length for short
|
||||||
|
// bodies and adding a Content-Length.
|
||||||
|
if fl, ok := rw.(http.Flusher); ok {
|
||||||
|
fl.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
rp.copyResponse(rw, res.Body)
|
rp.copyResponse(rw, res.Body)
|
||||||
|
res.Body.Close() // close now, instead of defer, to populate res.Trailer
|
||||||
|
copyHeader(rw.Header(), res.Trailer)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||||
buf := bufferPool.Get()
|
|
||||||
defer bufferPool.Put(buf)
|
|
||||||
|
|
||||||
if rp.FlushInterval != 0 {
|
if rp.FlushInterval != 0 {
|
||||||
if wf, ok := dst.(writeFlusher); ok {
|
if wf, ok := dst.(writeFlusher); ok {
|
||||||
mlw := &maxLatencyWriter{
|
mlw := &maxLatencyWriter{
|
||||||
|
@ -261,7 +305,7 @@ func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
|
||||||
dst = mlw
|
dst = mlw
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
io.CopyBuffer(dst, src, buf.([]byte))
|
pooledIoCopy(dst, src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// skip these headers if they already exist.
|
// skip these headers if they already exist.
|
||||||
|
@ -295,16 +339,17 @@ func copyHeader(dst, src http.Header) {
|
||||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||||
var hopHeaders = []string{
|
var hopHeaders = []string{
|
||||||
|
"Alt-Svc",
|
||||||
|
"Alternate-Protocol",
|
||||||
"Connection",
|
"Connection",
|
||||||
"Keep-Alive",
|
"Keep-Alive",
|
||||||
"Proxy-Authenticate",
|
"Proxy-Authenticate",
|
||||||
"Proxy-Authorization",
|
"Proxy-Authorization",
|
||||||
"Te", // canonicalized version of "TE"
|
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||||
"Trailers",
|
"Te", // canonicalized version of "TE"
|
||||||
|
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
|
||||||
"Transfer-Encoding",
|
"Transfer-Encoding",
|
||||||
"Upgrade",
|
"Upgrade",
|
||||||
"Alternate-Protocol",
|
|
||||||
"Alt-Svc",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type respUpdateFn func(resp *http.Response)
|
type respUpdateFn func(resp *http.Response)
|
||||||
|
@ -331,51 +376,169 @@ type connHijackerTransport struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
|
||||||
transport := &http.Transport{
|
t := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
Dial: (&net.Dialer{
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
KeepAlive: 30 * time.Second,
|
|
||||||
}).Dial,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
MaxIdleConnsPerHost: -1,
|
MaxIdleConnsPerHost: -1,
|
||||||
}
|
}
|
||||||
if base != nil {
|
if b, _ := base.(*http.Transport); b != nil {
|
||||||
if baseTransport, ok := base.(*http.Transport); ok {
|
tlsClientConfig := b.TLSClientConfig
|
||||||
transport.Proxy = baseTransport.Proxy
|
if tlsClientConfig.NextProtos != nil {
|
||||||
transport.TLSClientConfig = baseTransport.TLSClientConfig
|
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
|
||||||
transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout
|
tlsClientConfig.NextProtos = nil
|
||||||
transport.Dial = baseTransport.Dial
|
|
||||||
transport.DialTLS = baseTransport.DialTLS
|
|
||||||
transport.MaxIdleConnsPerHost = -1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
t.Proxy = b.Proxy
|
||||||
|
t.TLSClientConfig = tlsClientConfig
|
||||||
|
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
|
||||||
|
t.Dial = b.Dial
|
||||||
|
t.DialTLS = b.DialTLS
|
||||||
|
} else {
|
||||||
|
t.Proxy = http.ProxyFromEnvironment
|
||||||
|
t.TLSHandshakeTimeout = 10 * time.Second
|
||||||
}
|
}
|
||||||
hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]}
|
hj := &connHijackerTransport{t, nil, bufferPool.Get().([]byte)[:0]}
|
||||||
oldDial := transport.Dial
|
|
||||||
oldDialTLS := transport.DialTLS
|
dial := getTransportDial(t)
|
||||||
if oldDial == nil {
|
dialTLS := getTransportDialTLS(t)
|
||||||
oldDial = (&net.Dialer{
|
t.Dial = func(network, addr string) (net.Conn, error) {
|
||||||
Timeout: 30 * time.Second,
|
c, err := dial(network, addr)
|
||||||
KeepAlive: 30 * time.Second,
|
hj.Conn = c
|
||||||
}).Dial
|
return &hijackedConn{c, hj}, err
|
||||||
}
|
}
|
||||||
hjTransport.Dial = func(network, addr string) (net.Conn, error) {
|
t.DialTLS = func(network, addr string) (net.Conn, error) {
|
||||||
c, err := oldDial(network, addr)
|
c, err := dialTLS(network, addr)
|
||||||
hjTransport.Conn = c
|
hj.Conn = c
|
||||||
return &hijackedConn{c, hjTransport}, err
|
return &hijackedConn{c, hj}, err
|
||||||
}
|
}
|
||||||
if oldDialTLS != nil {
|
|
||||||
hjTransport.DialTLS = func(network, addr string) (net.Conn, error) {
|
return hj
|
||||||
c, err := oldDialTLS(network, addr)
|
}
|
||||||
hjTransport.Conn = c
|
|
||||||
return &hijackedConn{c, hjTransport}, err
|
// getTransportDial always returns a plain Dialer
|
||||||
|
// and defaults to the existing t.Dial.
|
||||||
|
func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
|
if t.Dial != nil {
|
||||||
|
return t.Dial
|
||||||
|
}
|
||||||
|
return defaultDialer.Dial
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTransportDial always returns a TLS Dialer
|
||||||
|
// and defaults to the existing t.DialTLS.
|
||||||
|
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
|
||||||
|
if t.DialTLS != nil {
|
||||||
|
return t.DialTLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// newConnHijackerTransport will modify t.Dial after calling this method
|
||||||
|
// => Create a backup reference.
|
||||||
|
plainDial := getTransportDial(t)
|
||||||
|
|
||||||
|
// The following DialTLS implementation stems from the Go stdlib and
|
||||||
|
// is identical to what happens if DialTLS is not provided.
|
||||||
|
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
|
||||||
|
return func(network, addr string) (net.Conn, error) {
|
||||||
|
plainConn, err := plainDial(network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlsClientConfig := t.TLSClientConfig
|
||||||
|
if tlsClientConfig == nil {
|
||||||
|
tlsClientConfig = &tls.Config{}
|
||||||
|
}
|
||||||
|
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
|
||||||
|
tlsClientConfig.ServerName = stripPort(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Client(plainConn, tlsClientConfig)
|
||||||
|
errc := make(chan error, 2)
|
||||||
|
var timer *time.Timer
|
||||||
|
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||||
|
timer = time.AfterFunc(d, func() {
|
||||||
|
errc <- tlsHandshakeTimeoutError{}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
err := tlsConn.Handshake()
|
||||||
|
if timer != nil {
|
||||||
|
timer.Stop()
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
if err := <-errc; err != nil {
|
||||||
|
plainConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !tlsClientConfig.InsecureSkipVerify {
|
||||||
|
hostname := tlsClientConfig.ServerName
|
||||||
|
if hostname == "" {
|
||||||
|
hostname = stripPort(addr)
|
||||||
|
}
|
||||||
|
if err := tlsConn.VerifyHostname(hostname); err != nil {
|
||||||
|
plainConn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tlsConn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripPort returns address without its port if it has one and
|
||||||
|
// works with IP addresses as well as hostnames formatted as host:port.
|
||||||
|
//
|
||||||
|
// IPv6 addresses (excluding the port) must be enclosed in
|
||||||
|
// square brackets similar to the requirements of Go's stdlib.
|
||||||
|
func stripPort(address string) string {
|
||||||
|
// Keep in mind that the address might be a IPv6 address
|
||||||
|
// and thus contain a colon, but not have a port.
|
||||||
|
portIdx := strings.LastIndex(address, ":")
|
||||||
|
ipv6Idx := strings.LastIndex(address, "]")
|
||||||
|
if portIdx > ipv6Idx {
|
||||||
|
address = address[:portIdx]
|
||||||
|
}
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
|
||||||
|
type tlsHandshakeTimeoutError struct{}
|
||||||
|
|
||||||
|
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||||
|
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||||
|
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||||
|
|
||||||
|
// cloneTLSClientConfig is like cloneTLSConfig but omits
|
||||||
|
// the fields SessionTicketsDisabled and SessionTicketKey.
|
||||||
|
// This makes it safe to call cloneTLSClientConfig on a config
|
||||||
|
// in active use by a server.
|
||||||
|
func cloneTLSClientConfig(cfg *tls.Config) *tls.Config {
|
||||||
|
if cfg == nil {
|
||||||
|
return &tls.Config{}
|
||||||
|
}
|
||||||
|
return &tls.Config{
|
||||||
|
Rand: cfg.Rand,
|
||||||
|
Time: cfg.Time,
|
||||||
|
Certificates: cfg.Certificates,
|
||||||
|
NameToCertificate: cfg.NameToCertificate,
|
||||||
|
GetCertificate: cfg.GetCertificate,
|
||||||
|
RootCAs: cfg.RootCAs,
|
||||||
|
NextProtos: cfg.NextProtos,
|
||||||
|
ServerName: cfg.ServerName,
|
||||||
|
ClientAuth: cfg.ClientAuth,
|
||||||
|
ClientCAs: cfg.ClientCAs,
|
||||||
|
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||||
|
CipherSuites: cfg.CipherSuites,
|
||||||
|
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
|
||||||
|
ClientSessionCache: cfg.ClientSessionCache,
|
||||||
|
MinVersion: cfg.MinVersion,
|
||||||
|
MaxVersion: cfg.MaxVersion,
|
||||||
|
CurvePreferences: cfg.CurvePreferences,
|
||||||
|
DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled,
|
||||||
|
Renegotiation: cfg.Renegotiation,
|
||||||
}
|
}
|
||||||
return hjTransport
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestIsWebsocket(req *http.Request) bool {
|
func requestIsWebsocket(req *http.Request) bool {
|
||||||
return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade"))
|
return strings.ToLower(req.Header.Get("Upgrade")) == "websocket" && strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")
|
||||||
}
|
}
|
||||||
|
|
||||||
type writeFlusher interface {
|
type writeFlusher interface {
|
||||||
|
|
Loading…
Reference in a new issue