fixed req.URL.Path for unix: sockets

This commit is contained in:
eiszfuchs 2016-03-28 20:30:00 +02:00
parent 32dbbfd64c
commit cbd9b814b9
2 changed files with 132 additions and 3 deletions

View file

@ -238,6 +238,116 @@ func TestUnixSocketProxy(t *testing.T) {
} }
} }
func GetHTTPProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, messageFormat, r.URL.String())
}))
return newPrefixedWebSocketTestProxy(ts.URL, prefix), ts
}
func GetSocketProxy(messageFormat string, prefix string) (*Proxy, *httptest.Server, error) {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, messageFormat, r.URL.String())
}))
socketPath, err := filepath.Abs("./test_socket")
if err != nil {
return nil, nil, fmt.Errorf("Unable to get absolute path: %v", err)
}
ln, err := net.Listen("unix", socketPath)
if err != nil {
return nil, nil, fmt.Errorf("Unable to listen: %v", err)
}
ts.Listener = ln
ts.Start()
tsURL := strings.Replace(ts.URL, "http://", "unix:", 1)
return newPrefixedWebSocketTestProxy(tsURL, prefix), ts, nil
}
func GetTestServerMessage(p *Proxy, ts *httptest.Server, path string) (string, error) {
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
p.ServeHTTP(w, r)
}))
// *httptest.Server is passed so it can be `defer`red properly
defer ts.Close()
defer echoProxy.Close()
res, err := http.Get(echoProxy.URL + path)
if err != nil {
return "", fmt.Errorf("Unable to GET: %v", err)
}
greeting, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
return "", fmt.Errorf("Unable to read body: %v", err)
}
return fmt.Sprintf("%s", greeting), nil
}
func TestUnixSocketProxyPaths(t *testing.T) {
greeting := "Hello route %s"
tests := []struct {
url string
prefix string
expected string
}{
{"", "", fmt.Sprintf(greeting, "/")},
{"/hello", "", fmt.Sprintf(greeting, "/hello")},
{"/foo/bar", "", fmt.Sprintf(greeting, "/foo/bar")},
{"/foo?bar", "", fmt.Sprintf(greeting, "/foo?bar")},
{"/greet?name=john", "", fmt.Sprintf(greeting, "/greet?name=john")},
{"/world?wonderful&colorful", "", fmt.Sprintf(greeting, "/world?wonderful&colorful")},
{"/proxy/hello", "/proxy", fmt.Sprintf(greeting, "/hello")},
{"/proxy/foo/bar", "/proxy", fmt.Sprintf(greeting, "/foo/bar")},
{"/proxy/?foo=bar", "/proxy", fmt.Sprintf(greeting, "/?foo=bar")},
}
for _, test := range tests {
p, ts := GetHTTPProxy(greeting, test.prefix)
actualMsg, err := GetTestServerMessage(p, ts, test.url)
if err != nil {
t.Fatalf("Getting server message failed - %v", err)
}
if actualMsg != test.expected {
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
}
}
if runtime.GOOS == "windows" {
return
}
for _, test := range tests {
p, ts, err := GetSocketProxy(greeting, test.prefix)
if err != nil {
t.Fatalf("Getting socket proxy failed - %v", err)
}
actualMsg, err := GetTestServerMessage(p, ts, test.url)
if err != nil {
t.Fatalf("Getting server message failed - %v", err)
}
if actualMsg != test.expected {
t.Errorf("Expected '%s' but got '%s' instead", test.expected, actualMsg)
}
}
}
func newFakeUpstream(name string, insecure bool) *fakeUpstream { func newFakeUpstream(name string, insecure bool) *fakeUpstream {
uri, _ := url.Parse(name) uri, _ := url.Parse(name)
u := &fakeUpstream{ u := &fakeUpstream{
@ -276,12 +386,19 @@ func (u *fakeUpstream) AllowedPath(requestPath string) bool {
// proxy. // proxy.
func newWebSocketTestProxy(backendAddr string) *Proxy { func newWebSocketTestProxy(backendAddr string) *Proxy {
return &Proxy{ return &Proxy{
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}}, Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: ""}},
}
}
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
return &Proxy{
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
} }
} }
type fakeWsUpstream struct { type fakeWsUpstream struct {
name string name string
without string
} }
func (u *fakeWsUpstream) From() string { func (u *fakeWsUpstream) From() string {
@ -292,7 +409,7 @@ func (u *fakeWsUpstream) Select() *UpstreamHost {
uri, _ := url.Parse(u.name) uri, _ := url.Parse(u.name)
return &UpstreamHost{ return &UpstreamHost{
Name: u.name, Name: u.name,
ReverseProxy: NewSingleHostReverseProxy(uri, ""), ReverseProxy: NewSingleHostReverseProxy(uri, u.without),
ExtraHeaders: http.Header{ ExtraHeaders: http.Header{
"Connection": {"{>Connection}"}, "Connection": {"{>Connection}"},
"Upgrade": {"{>Upgrade}"}}, "Upgrade": {"{>Upgrade}"}},

View file

@ -95,6 +95,18 @@ func NewSingleHostReverseProxy(target *url.URL, without string) *ReverseProxy {
} else { } else {
req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
} }
// Trims the path of the socket from the URL path.
// This is done because req.URL passed to your proxied service
// will have the full path of the socket file prefixed to it.
// Calling /test on a server that proxies requests to
// unix:/var/run/www.socket will thus set the requested path
// to /var/run/www.socket/test, rendering paths useless.
if target.Scheme == "unix" {
// See comment on socketDial for the trim
socketPrefix := target.String()[len("unix://"):]
req.URL.Path = strings.TrimPrefix(req.URL.Path, socketPrefix)
}
// We are then safe to remove the `without` prefix.
if without != "" { if without != "" {
req.URL.Path = strings.TrimPrefix(req.URL.Path, without) req.URL.Path = strings.TrimPrefix(req.URL.Path, without)
} }