Add parseUpstream method

This commit is contained in:
Marc Guasch 2016-06-03 21:34:31 +02:00
parent 2536ea74d9
commit 1bdbf9d6ba
No known key found for this signature in database
GPG key ID: 3AC32D728BC3DDDF

View file

@ -1,6 +1,7 @@
package proxy package proxy
import ( import (
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
@ -56,17 +57,38 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
if !c.Args(&upstream.from) { if !c.Args(&upstream.from) {
return upstreams, c.ArgErr() return upstreams, c.ArgErr()
} }
to := c.RemainingArgs()
if len(to) == 0 { var to []string
return upstreams, c.ArgErr() for _, t := range c.RemainingArgs() {
parsed, err := parseUpstream(t)
if err != nil {
return upstreams, err
}
to = append(to, parsed...)
} }
for c.NextBlock() { for c.NextBlock() {
if err := parseBlock(&c, upstream); err != nil { switch c.Val() {
return upstreams, err case "upstream":
if !c.NextArg() {
return upstreams, c.ArgErr()
}
parsed, err := parseUpstream(c.Val())
if err != nil {
return upstreams, err
}
to = append(to, parsed...)
default:
if err := parseBlock(&c, upstream); err != nil {
return upstreams, err
}
} }
} }
if len(to) == 0 {
return upstreams, c.ArgErr()
}
upstream.Hosts = make([]*UpstreamHost, len(to)) upstream.Hosts = make([]*UpstreamHost, len(to))
for i, host := range to { for i, host := range to {
uh, err := upstream.NewHost(host) uh, err := upstream.NewHost(host)
@ -134,6 +156,45 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
return uh, nil return uh, nil
} }
func parseUpstream(u string) ([]string, error) {
if !strings.HasPrefix(u, "unix:") {
colonIdx := strings.LastIndex(u, ":")
protoIdx := strings.Index(u, "://")
if colonIdx != -1 && colonIdx != protoIdx {
us := u[:colonIdx]
ports := u[len(us)+1:]
if separators := strings.Count(ports, "-"); separators > 1 {
return nil, fmt.Errorf("port range [%s] is invalid", ports)
} else if separators == 1 {
portsStr := strings.Split(ports, "-")
pIni, err := strconv.Atoi(portsStr[0])
if err != nil {
return nil, err
}
pEnd, err := strconv.Atoi(portsStr[1])
if err != nil {
return nil, err
}
if pEnd <= pIni {
return nil, fmt.Errorf("port range [%s] is invalid", ports)
}
hosts := []string{}
for p := pIni; p <= pEnd; p++ {
hosts = append(hosts, fmt.Sprintf("%s:%d", us, p))
}
return hosts, nil
}
}
}
return []string{u}, nil
}
func parseBlock(c *parse.Dispenser, u *staticUpstream) error { func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
switch c.Val() { switch c.Val() {
case "policy": case "policy":