proxy: added 'health_check_port' to upstream (#1666)

* proxy: added 'health_check_port' to upstream

* proxy: `net.JoinHostPort` instead of `fmt.Printf` for upstream checks

* proxy: changing health_check_port type (int->string)

adding tests for invalid port config
This commit is contained in:
Lucas Fontes 2017-05-13 18:49:06 -04:00 committed by Matt Holt
parent 5f860d3a9f
commit 73494ce63a
2 changed files with 111 additions and 1 deletions

View file

@ -4,6 +4,7 @@ import (
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"path"
@ -42,6 +43,7 @@ type staticUpstream struct {
Interval time.Duration
Timeout time.Duration
Host string
Port string
}
WithoutPathPrefix string
IgnoredSubPaths []string
@ -321,6 +323,20 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
return err
}
u.HealthCheck.Timeout = dur
case "health_check_port":
if !c.NextArg() {
return c.ArgErr()
}
port := c.Val()
n, err := strconv.Atoi(port)
if err != nil {
return err
}
if n < 0 {
return c.Errf("invalid health_check_port '%s'", port)
}
u.HealthCheck.Port = c.Val()
case "header_upstream":
var header, value string
if !c.Args(&header, &value) {
@ -380,7 +396,12 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
func (u *staticUpstream) healthCheck() {
for _, host := range u.Hosts {
hostURL := host.Name + u.HealthCheck.Path
hostURL := host.Name
if u.HealthCheck.Port != "" {
hostURL = replacePort(host.Name, u.HealthCheck.Port)
}
hostURL += u.HealthCheck.Path
var unhealthy bool
// set up request, needed to be able to modify headers
@ -483,3 +504,19 @@ func (u *staticUpstream) Stop() error {
func RegisterPolicy(name string, policy func() Policy) {
supportedPolicies[name] = policy
}
func replacePort(originalURL string, newPort string) string {
parsedURL, err := url.Parse(originalURL)
if err != nil {
return originalURL
}
// handles 'localhost' and 'localhost:8080'
parsedHost, _, err := net.SplitHostPort(parsedURL.Host)
if err != nil {
parsedHost = parsedURL.Host
}
parsedURL.Host = net.JoinHostPort(parsedHost, newPort)
return parsedURL.String()
}

View file

@ -2,6 +2,7 @@ package proxy
import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"strings"
@ -375,3 +376,75 @@ func TestHealthCheckHost(t *testing.T) {
}
}
}
func TestHealthCheckPort(t *testing.T) {
var counter int64
healthCounter := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.Body.Close()
atomic.AddInt64(&counter, 1)
}))
_, healthPort, err := net.SplitHostPort(healthCounter.Listener.Addr().String())
if err != nil {
t.Fatal(err)
}
defer healthCounter.Close()
tests := []struct {
config string
}{
// Test #1: upstream with port
{"proxy / localhost:8080 {\n health_check / health_check_port " + healthPort + "\n}"},
// Test #2: upstream without port (default to 80)
{"proxy / localhost {\n health_check / health_check_port " + healthPort + "\n}"},
}
for i, test := range tests {
counterValueAtStart := atomic.LoadInt64(&counter)
upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
if err != nil {
t.Error("Expected no error. Got:", err.Error())
}
// Give some time for healthchecks to hit the server.
time.Sleep(500 * time.Millisecond)
for _, upstream := range upstreams {
if err := upstream.Stop(); err != nil {
t.Errorf("Test %d: Expected no error stopping upstream. Got: %v", i, err.Error())
}
}
counterValueAfterShutdown := atomic.LoadInt64(&counter)
if counterValueAfterShutdown == counterValueAtStart {
t.Errorf("Test %d: Expected healthchecks to hit test server. Got no healthchecks.", i)
}
}
t.Run("valid_port", func(t *testing.T) {
tests := []struct {
config string
}{
// Test #1: invalid port (nil)
{"proxy / localhost {\n health_check / health_check_port\n}"},
// Test #2: invalid port (string)
{"proxy / localhost {\n health_check / health_check_port abc\n}"},
// Test #3: invalid port (negative)
{"proxy / localhost {\n health_check / health_check_port -1\n}"},
}
for i, test := range tests {
_, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
if err == nil {
t.Errorf("Test %d accepted invalid config", i)
}
}
})
}