From d25008d2c8e2eb5f96b2b37a1cca5b4e140cfe8d Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Mon, 8 Jul 2019 16:46:38 -0600 Subject: [PATCH] Move listen address functions into caddy package; fix unix bug --- listeners.go | 72 +++++++++++++++++++ .../caddyhttp_test.go => listeners_test.go | 21 ++++-- modules/caddyhttp/caddyhttp.go | 65 ++--------------- modules/caddyhttp/server.go | 2 +- 4 files changed, 94 insertions(+), 66 deletions(-) rename modules/caddyhttp/caddyhttp_test.go => listeners_test.go (90%) diff --git a/listeners.go b/listeners.go index 28adc8b9..d97674d8 100644 --- a/listeners.go +++ b/listeners.go @@ -17,6 +17,8 @@ package caddy import ( "fmt" "net" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -160,3 +162,73 @@ var ( listeners = make(map[string]*listenerUsage) listenersMu sync.Mutex ) + +// ParseListenAddr parses addr, a string of the form "network/host:port" +// (with any part optional) into its component parts. Because a port can +// also be a port range, there may be multiple addresses returned. +func ParseListenAddr(addr string) (network string, addrs []string, err error) { + var host, port string + network, host, port, err = SplitListenAddr(addr) + if network == "" { + network = "tcp" + } + if err != nil { + return + } + if network == "unix" { + addrs = []string{host} + return + } + ports := strings.SplitN(port, "-", 2) + if len(ports) == 1 { + ports = append(ports, ports[0]) + } + var start, end int + start, err = strconv.Atoi(ports[0]) + if err != nil { + return + } + end, err = strconv.Atoi(ports[1]) + if err != nil { + return + } + if end < start { + err = fmt.Errorf("end port must be greater than start port") + return + } + for p := start; p <= end; p++ { + addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p))) + } + return +} + +// SplitListenAddr splits a into its network, host, and port components. +// Note that port may be a port range, or omitted for unix sockets. +func SplitListenAddr(a string) (network, host, port string, err error) { + if idx := strings.Index(a, "/"); idx >= 0 { + network = strings.ToLower(strings.TrimSpace(a[:idx])) + a = a[idx+1:] + } + if network == "unix" { + host = a + return + } + host, port, err = net.SplitHostPort(a) + return +} + +// JoinListenAddr combines network, host, and port into a single +// address string of the form "network/host:port". Port may be a +// port range. For unix sockets, the network should be "unix" and +// the path to the socket should be given in the host argument. +func JoinListenAddr(network, host, port string) string { + var a string + if network != "" { + a = network + "/" + } + a += host + if port != "" { + a += ":" + port + } + return a +} diff --git a/modules/caddyhttp/caddyhttp_test.go b/listeners_test.go similarity index 90% rename from modules/caddyhttp/caddyhttp_test.go rename to listeners_test.go index 3c0c2f46..7c5a2fcf 100644 --- a/modules/caddyhttp/caddyhttp_test.go +++ b/listeners_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package caddyhttp +package caddy import ( "reflect" @@ -62,8 +62,13 @@ func TestSplitListenerAddr(t *testing.T) { expectNetwork: "udp", expectErr: true, }, + { + input: "unix//foo/bar", + expectNetwork: "unix", + expectHost: "/foo/bar", + }, } { - actualNetwork, actualHost, actualPort, err := splitListenAddr(tc.input) + actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got: %v", i, err) } @@ -119,8 +124,12 @@ func TestJoinListenerAddr(t *testing.T) { network: "udp", host: "", port: "1234", expect: "udp/:1234", }, + { + network: "unix", host: "/foo/bar", port: "", + expect: "unix//foo/bar", + }, } { - actual := joinListenAddr(tc.network, tc.host, tc.port) + actual := JoinListenAddr(tc.network, tc.host, tc.port) if actual != tc.expect { t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual) } @@ -165,9 +174,9 @@ func TestParseListenerAddr(t *testing.T) { expectAddrs: []string{"localhost:1234"}, }, { - input: "unix/localhost:1234-1236", + input: "unix//foo/bar", expectNetwork: "unix", - expectAddrs: []string{"localhost:1234", "localhost:1235", "localhost:1236"}, + expectAddrs: []string{"/foo/bar"}, }, { input: "localhost:1234-1234", @@ -185,7 +194,7 @@ func TestParseListenerAddr(t *testing.T) { expectAddrs: []string{"localhost:0"}, }, } { - actualNetwork, actualAddrs, err := parseListenAddr(tc.input) + actualNetwork, actualAddrs, err := ParseListenAddr(tc.input) if tc.expectErr && err == nil { t.Errorf("Test %d: Expected error but got: %v", i, err) } diff --git a/modules/caddyhttp/caddyhttp.go b/modules/caddyhttp/caddyhttp.go index 3277e429..d0c75401 100644 --- a/modules/caddyhttp/caddyhttp.go +++ b/modules/caddyhttp/caddyhttp.go @@ -96,7 +96,7 @@ func (app *App) Validate() error { lnAddrs := make(map[string]string) for srvName, srv := range app.Servers { for _, addr := range srv.Listen { - netw, expanded, err := parseListenAddr(addr) + netw, expanded, err := caddy.ParseListenAddr(addr) if err != nil { return fmt.Errorf("invalid listener address '%s': %v", addr, err) } @@ -137,7 +137,7 @@ func (app *App) Start() error { } for _, lnAddr := range srv.Listen { - network, addrs, err := parseListenAddr(lnAddr) + network, addrs, err := caddy.ParseListenAddr(lnAddr) if err != nil { return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) } @@ -289,7 +289,7 @@ func (app *App) automaticHTTPS() error { // create HTTP->HTTPS redirects for _, addr := range srv.Listen { - netw, host, port, err := splitListenAddr(addr) + netw, host, port, err := caddy.SplitListenAddr(addr) if err != nil { return fmt.Errorf("%s: invalid listener address: %v", srvName, addr) } @@ -298,7 +298,7 @@ func (app *App) automaticHTTPS() error { if httpPort == 0 { httpPort = DefaultHTTPPort } - httpRedirLnAddr := joinListenAddr(netw, host, strconv.Itoa(httpPort)) + httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort)) lnAddrMap[httpRedirLnAddr] = struct{}{} if parts := strings.SplitN(port, "-", 2); len(parts) == 2 { @@ -339,7 +339,7 @@ func (app *App) automaticHTTPS() error { var lnAddrs []string mapLoop: for addr := range lnAddrMap { - netw, addrs, err := parseListenAddr(addr) + netw, addrs, err := caddy.ParseListenAddr(addr) if err != nil { continue } @@ -364,7 +364,7 @@ func (app *App) automaticHTTPS() error { func (app *App) listenerTaken(network, address string) bool { for _, srv := range app.Servers { for _, addr := range srv.Listen { - netw, addrs, err := parseListenAddr(addr) + netw, addrs, err := caddy.ParseListenAddr(addr) if err != nil || netw != network { continue } @@ -425,59 +425,6 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error { // sometimes better than a nil Handler pointer. var emptyHandler HandlerFunc = func(w http.ResponseWriter, r *http.Request) error { return nil } -func parseListenAddr(a string) (network string, addrs []string, err error) { - var host, port string - network, host, port, err = splitListenAddr(a) - if network == "" { - network = "tcp" - } - if err != nil { - return - } - ports := strings.SplitN(port, "-", 2) - if len(ports) == 1 { - ports = append(ports, ports[0]) - } - var start, end int - start, err = strconv.Atoi(ports[0]) - if err != nil { - return - } - end, err = strconv.Atoi(ports[1]) - if err != nil { - return - } - if end < start { - err = fmt.Errorf("end port must be greater than start port") - return - } - for p := start; p <= end; p++ { - addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p))) - } - return -} - -func splitListenAddr(a string) (network, host, port string, err error) { - if idx := strings.Index(a, "/"); idx >= 0 { - network = strings.ToLower(strings.TrimSpace(a[:idx])) - a = a[idx+1:] - } - host, port, err = net.SplitHostPort(a) - return -} - -func joinListenAddr(network, host, port string) string { - var a string - if network != "" { - a = network + "/" - } - a += host - if port != "" { - a += ":" + port - } - return a -} - const ( // DefaultHTTPPort is the default port for HTTP. DefaultHTTPPort = 80 diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index 14601520..d40a01dc 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -156,7 +156,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { for _, lnAddr := range s.Listen { - _, addrs, err := parseListenAddr(lnAddr) + _, addrs, err := caddy.ParseListenAddr(lnAddr) if err == nil { for _, a := range addrs { _, port, err := net.SplitHostPort(a)