diff --git a/caddy/setup/proxy_test.go b/caddy/setup/proxy_test.go new file mode 100644 index 00000000..3d6d04a0 --- /dev/null +++ b/caddy/setup/proxy_test.go @@ -0,0 +1,136 @@ +package setup + +import ( + "reflect" + "testing" + + "github.com/mholt/caddy/middleware/proxy" +) + +func TestUpstream(t *testing.T) { + for i, test := range []struct { + input string + shouldErr bool + expectedHosts map[string]struct{} + }{ + // test #0 test usual to destination still works normally + { + "proxy / localhost:80", + false, + map[string]struct{}{ + "http://localhost:80": {}, + }, + }, + + // test #1 test usual to destination with port range + { + "proxy / localhost:8080-8082", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + "http://localhost:8082": {}, + }, + }, + + // test #2 test upstream directive + { + "proxy / {\n upstream localhost:8080\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + }, + }, + + // test #3 test upstream directive with port range + { + "proxy / {\n upstream localhost:8080-8081\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + }, + }, + + // test #4 test to destination with upstream directive + { + "proxy / localhost:8080 {\n upstream localhost:8081-8082\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + "http://localhost:8082": {}, + }, + }, + + // test #5 test with unix sockets + { + "proxy / localhost:8080 {\n upstream unix:/var/foo\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "unix:/var/foo": {}, + }, + }, + + // test #6 test fail on malformed port range + { + "proxy / localhost:8090-8080", + true, + nil, + }, + + // test #7 test fail on malformed port range 2 + { + "proxy / {\n upstream localhost:80-A\n}", + true, + nil, + }, + + // test #8 test upstreams without ports work correctly + { + "proxy / http://localhost {\n upstream testendpoint\n}", + false, + map[string]struct{}{ + "http://localhost": {}, + "http://testendpoint": {}, + }, + }, + + // test #9 test several upstream directives + { + "proxy / localhost:8080 {\n upstream localhost:8081-8082\n upstream localhost:8083-8085\n}", + false, + map[string]struct{}{ + "http://localhost:8080": {}, + "http://localhost:8081": {}, + "http://localhost:8082": {}, + "http://localhost:8083": {}, + "http://localhost:8084": {}, + "http://localhost:8085": {}, + }, + }, + } { + receivedFunc, err := Proxy(NewTestController(test.input)) + if err != nil && !test.shouldErr { + t.Errorf("Test case #%d received an error of %v", i, err) + } else if test.shouldErr { + continue + } + + upstreams := receivedFunc(nil).(proxy.Proxy).Upstreams + for _, upstream := range upstreams { + val := reflect.ValueOf(upstream).Elem() + hosts := val.FieldByName("Hosts").Interface().(proxy.HostPool) + if len(hosts) != len(test.expectedHosts) { + t.Errorf("Test case #%d expected %d hosts but received %d", i, len(test.expectedHosts), len(hosts)) + } else { + for _, host := range hosts { + if _, found := test.expectedHosts[host.Name]; !found { + t.Errorf("Test case #%d has an unexpected host %s", i, host.Name) + } + } + } + } + } +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index e28db643..a1d9fcfc 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -1,6 +1,7 @@ package proxy import ( + "fmt" "io" "io/ioutil" "net/http" @@ -56,17 +57,38 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.Args(&upstream.from) { return upstreams, c.ArgErr() } - to := c.RemainingArgs() - if len(to) == 0 { - return upstreams, c.ArgErr() + + var to []string + for _, t := range c.RemainingArgs() { + parsed, err := parseUpstream(t) + if err != nil { + return upstreams, err + } + to = append(to, parsed...) } for c.NextBlock() { - if err := parseBlock(&c, upstream); err != nil { - return upstreams, err + switch c.Val() { + case "upstream": + if !c.NextArg() { + return upstreams, c.ArgErr() + } + parsed, err := parseUpstream(c.Val()) + if err != nil { + return upstreams, err + } + to = append(to, parsed...) + default: + if err := parseBlock(&c, upstream); err != nil { + return upstreams, err + } } } + if len(to) == 0 { + return upstreams, c.ArgErr() + } + upstream.Hosts = make([]*UpstreamHost, len(to)) for i, host := range to { uh, err := upstream.NewHost(host) @@ -134,6 +156,45 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) { return uh, nil } +func parseUpstream(u string) ([]string, error) { + if !strings.HasPrefix(u, "unix:") { + colonIdx := strings.LastIndex(u, ":") + protoIdx := strings.Index(u, "://") + + if colonIdx != -1 && colonIdx != protoIdx { + us := u[:colonIdx] + ports := u[len(us)+1:] + if separators := strings.Count(ports, "-"); separators > 1 { + return nil, fmt.Errorf("port range [%s] is invalid", ports) + } else if separators == 1 { + portsStr := strings.Split(ports, "-") + pIni, err := strconv.Atoi(portsStr[0]) + if err != nil { + return nil, err + } + + pEnd, err := strconv.Atoi(portsStr[1]) + if err != nil { + return nil, err + } + + if pEnd <= pIni { + return nil, fmt.Errorf("port range [%s] is invalid", ports) + } + + hosts := []string{} + for p := pIni; p <= pEnd; p++ { + hosts = append(hosts, fmt.Sprintf("%s:%d", us, p)) + } + return hosts, nil + } + } + } + + return []string{u}, nil + +} + func parseBlock(c *parse.Dispenser, u *staticUpstream) error { switch c.Val() { case "policy":