add multi proxy supprot based on urls

This commit is contained in:
Viacheslav Biriukov 2015-06-07 17:07:23 +00:00
parent 995a2ea618
commit 81eb7dbb26
3 changed files with 258 additions and 76 deletions

View file

@ -13,6 +13,16 @@ import (
var errUnreachable = errors.New("Unreachable backend")
// Does path match pattern?
func pathMatch(pattern, path string) bool {
if len(pattern) == 0 {
// should not happen
return false
}
n := len(pattern)
return len(path) >= n && path[0:n] == pattern
}
// Proxy represents a middleware instance that can proxy requests.
type Proxy struct {
Next middleware.Handler
@ -56,72 +66,87 @@ func (uh *UpstreamHost) Down() bool {
return uh.CheckDown(uh)
}
// ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
func (p Proxy) match(path string) Upstream {
var u Upstream
n := 0
for _, upstream := range p.Upstreams {
if middleware.Path(r.URL.Path).Matches(upstream.From()) {
var replacer middleware.Replacer
start := time.Now()
requestHost := r.Host
// Since Select() should give us "up" hosts, keep retrying
// hosts until timeout (or until we get a nil host).
for time.Now().Sub(start) < (60 * time.Second) {
host := upstream.Select()
if host == nil {
return http.StatusBadGateway, errUnreachable
}
proxy := host.ReverseProxy
r.Host = host.Name
if baseURL, err := url.Parse(host.Name); err == nil {
r.Host = baseURL.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix)
}
} else if proxy == nil {
return http.StatusInternalServerError, err
}
var extraHeaders http.Header
if host.ExtraHeaders != nil {
extraHeaders = make(http.Header)
if replacer == nil {
rHost := r.Host
r.Host = requestHost
replacer = middleware.NewReplacer(r, nil)
r.Host = rHost
}
for header, values := range host.ExtraHeaders {
for _, value := range values {
extraHeaders.Add(header,
replacer.Replace(value))
if header == "Host" {
r.Host = replacer.Replace(value)
}
}
}
}
atomic.AddInt64(&host.Conns, 1)
backendErr := proxy.ServeHTTP(w, r, extraHeaders)
atomic.AddInt64(&host.Conns, -1)
if backendErr == nil {
return 0, nil
}
timeout := host.FailTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
atomic.AddInt32(&host.Fails, 1)
go func(host *UpstreamHost, timeout time.Duration) {
time.Sleep(timeout)
atomic.AddInt32(&host.Fails, -1)
}(host, timeout)
}
return http.StatusBadGateway, errUnreachable
pattern := upstream.From()
if !pathMatch(pattern, path) {
continue
}
if len(pattern) > n {
n = len(pattern)
u = upstream
}
}
return p.Next.ServeHTTP(w, r)
return u
}
// ServeHTTP satisfies the middleware.Handler interface.
func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
// Select best match upstream
upstream := p.match(r.URL.Path)
if upstream == nil {
return p.Next.ServeHTTP(w, r)
}
var replacer middleware.Replacer
start := time.Now()
requestHost := r.Host
// Since Select() should give us "up" hosts, keep retrying
// hosts until timeout (or until we get a nil host).
for time.Now().Sub(start) < (60 * time.Second) {
host := upstream.Select()
if host == nil {
return http.StatusBadGateway, errUnreachable
}
proxy := host.ReverseProxy
r.Host = host.Name
if baseURL, err := url.Parse(host.Name); err == nil {
r.Host = baseURL.Host
if proxy == nil {
proxy = NewSingleHostReverseProxy(baseURL, host.WithoutPathPrefix)
}
} else if proxy == nil {
return http.StatusInternalServerError, err
}
var extraHeaders http.Header
if host.ExtraHeaders != nil {
extraHeaders = make(http.Header)
if replacer == nil {
rHost := r.Host
r.Host = requestHost
replacer = middleware.NewReplacer(r, nil)
r.Host = rHost
}
for header, values := range host.ExtraHeaders {
for _, value := range values {
extraHeaders.Add(header,
replacer.Replace(value))
if header == "Host" {
r.Host = replacer.Replace(value)
}
}
}
}
atomic.AddInt64(&host.Conns, 1)
backendErr := proxy.ServeHTTP(w, r, extraHeaders)
atomic.AddInt64(&host.Conns, -1)
if backendErr == nil {
return 0, nil
}
timeout := host.FailTimeout
if timeout == 0 {
timeout = 10 * time.Second
}
atomic.AddInt32(&host.Fails, 1)
go func(host *UpstreamHost, timeout time.Duration) {
time.Sleep(timeout)
atomic.AddInt32(&host.Fails, -1)
}(host, timeout)
}
return http.StatusBadGateway, errUnreachable
}

View file

@ -3,7 +3,9 @@ package proxy
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
@ -98,30 +100,37 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
// also sets up the rules/environment for testing WebSocket
// proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy {
proxyHeaders = http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
}
return &Proxy{
Upstreams: []Upstream{&fakeUpstream{name: backendAddr}},
Upstreams: []Upstream{
&fakeUpstream{
name: backendAddr,
from: "/",
extraHeaders: http.Header{
"Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"},
},
},
},
}
}
type fakeUpstream struct {
name string
name string
from string
without string
extraHeaders http.Header
}
func (u *fakeUpstream) From() string {
return "/"
return u.from
}
func (u *fakeUpstream) Select() *UpstreamHost {
uri, _ := url.Parse(u.name)
return &UpstreamHost{
Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, ""),
ExtraHeaders: proxyHeaders,
ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
ExtraHeaders: u.extraHeaders,
}
}
@ -149,3 +158,151 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
var (
upstreamResp1 = []byte("Hello, /")
upstreamResp2 = []byte("Hello, /api/")
)
func newMultiHostTestProxy() *Proxy {
// No-op backends.
upstreamServer1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", upstreamResp1)
}))
upstreamServer2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "%s", upstreamResp2)
}))
// Test proxy.
p := &Proxy{
Upstreams: []Upstream{
&fakeUpstream{
name: upstreamServer1.URL,
from: "/",
extraHeaders: http.Header{
"Host": {"example.com"},
},
},
&fakeUpstream{
name: upstreamServer2.URL,
from: "/api",
extraHeaders: http.Header{
"Host": {"example.net"},
},
},
},
}
return p
}
func TestMultiReverseProxyFromClient(t *testing.T) {
p := newMultiHostTestProxy()
// This is a full end-end test, so the proxy handler.
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
defer proxy.Close()
// Table tests.
var multiProxy = []struct {
url string
body []byte
}{
{
"/",
upstreamResp1,
},
{
"/api/",
upstreamResp2,
},
{
"/messages/",
upstreamResp1,
},
{
"/api/messages/?text=cat",
upstreamResp2,
},
}
for _, tt := range multiProxy {
// Create client request
reqURL := singleJoiningSlash(proxy.URL, tt.url)
req, err := http.NewRequest("GET", reqURL, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
client := &http.Client{}
resp, err := client.Do(req)
defer resp.Body.Close()
if err != nil {
t.Fatalf("Failed to make request: %v", err)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatalf("Failed to read responce: %v", err)
}
if !bytes.Equal(body, tt.body) {
t.Errorf("Expected '%s' but got '%s' instead", tt.body, body)
}
}
}
func T1estMultiReverseProxyServeHostHeader(t *testing.T) {
p := newMultiHostTestProxy()
// Table tests.
var multiHostHeader = []struct {
url string
extraHeaders http.Header
}{
{
"/",
http.Header{
"Host": {"example.com"},
},
},
{
"/api/",
http.Header{
"Host": {"example.net"},
},
},
{
"/messages/",
http.Header{
"Host": {"example.com"},
},
},
{
"/api/messages/?text=cat",
http.Header{
"Host": {"example.net"},
},
},
}
for _, tt := range multiHostHeader {
// Create client request
r, err := http.NewRequest("GET", tt.url, nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
// Capture the request
w := httptest.NewRecorder()
// Booya! Do the test.
p.ServeHTTP(w, r)
host := r.Header.Get("Host")
if host != tt.extraHeaders.Get("Host") {
t.Errorf("Expected Host Header '%s' but got '%s' instead", tt.extraHeaders.Get("Host"), host)
}
}
}

View file

@ -14,7 +14,6 @@ import (
var (
supportedPolicies map[string]func() Policy = make(map[string]func() Policy)
proxyHeaders http.Header = make(http.Header)
)
type staticUpstream struct {
@ -36,6 +35,7 @@ type staticUpstream struct {
func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
var upstreams []Upstream
for c.Next() {
var proxyHeaders http.Header = make(http.Header)
upstream := &staticUpstream{
from: "",
Hosts: nil,