Move listen address functions into caddy package; fix unix bug

This commit is contained in:
Matthew Holt 2019-07-08 16:46:38 -06:00
parent 4eb5fc541b
commit d25008d2c8
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5
4 changed files with 94 additions and 66 deletions

View file

@ -17,6 +17,8 @@ package caddy
import ( import (
"fmt" "fmt"
"net" "net"
"strconv"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -160,3 +162,73 @@ var (
listeners = make(map[string]*listenerUsage) listeners = make(map[string]*listenerUsage)
listenersMu sync.Mutex listenersMu sync.Mutex
) )
// ParseListenAddr parses addr, a string of the form "network/host:port"
// (with any part optional) into its component parts. Because a port can
// also be a port range, there may be multiple addresses returned.
func ParseListenAddr(addr string) (network string, addrs []string, err error) {
var host, port string
network, host, port, err = SplitListenAddr(addr)
if network == "" {
network = "tcp"
}
if err != nil {
return
}
if network == "unix" {
addrs = []string{host}
return
}
ports := strings.SplitN(port, "-", 2)
if len(ports) == 1 {
ports = append(ports, ports[0])
}
var start, end int
start, err = strconv.Atoi(ports[0])
if err != nil {
return
}
end, err = strconv.Atoi(ports[1])
if err != nil {
return
}
if end < start {
err = fmt.Errorf("end port must be greater than start port")
return
}
for p := start; p <= end; p++ {
addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p)))
}
return
}
// SplitListenAddr splits a into its network, host, and port components.
// Note that port may be a port range, or omitted for unix sockets.
func SplitListenAddr(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx]))
a = a[idx+1:]
}
if network == "unix" {
host = a
return
}
host, port, err = net.SplitHostPort(a)
return
}
// JoinListenAddr combines network, host, and port into a single
// address string of the form "network/host:port". Port may be a
// port range. For unix sockets, the network should be "unix" and
// the path to the socket should be given in the host argument.
func JoinListenAddr(network, host, port string) string {
var a string
if network != "" {
a = network + "/"
}
a += host
if port != "" {
a += ":" + port
}
return a
}

View file

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package caddyhttp package caddy
import ( import (
"reflect" "reflect"
@ -62,8 +62,13 @@ func TestSplitListenerAddr(t *testing.T) {
expectNetwork: "udp", expectNetwork: "udp",
expectErr: true, expectErr: true,
}, },
{
input: "unix//foo/bar",
expectNetwork: "unix",
expectHost: "/foo/bar",
},
} { } {
actualNetwork, actualHost, actualPort, err := splitListenAddr(tc.input) actualNetwork, actualHost, actualPort, err := SplitListenAddr(tc.input)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got: %v", i, err)
} }
@ -119,8 +124,12 @@ func TestJoinListenerAddr(t *testing.T) {
network: "udp", host: "", port: "1234", network: "udp", host: "", port: "1234",
expect: "udp/:1234", expect: "udp/:1234",
}, },
{
network: "unix", host: "/foo/bar", port: "",
expect: "unix//foo/bar",
},
} { } {
actual := joinListenAddr(tc.network, tc.host, tc.port) actual := JoinListenAddr(tc.network, tc.host, tc.port)
if actual != tc.expect { if actual != tc.expect {
t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual) t.Errorf("Test %d: Expected '%s' but got '%s'", i, tc.expect, actual)
} }
@ -165,9 +174,9 @@ func TestParseListenerAddr(t *testing.T) {
expectAddrs: []string{"localhost:1234"}, expectAddrs: []string{"localhost:1234"},
}, },
{ {
input: "unix/localhost:1234-1236", input: "unix//foo/bar",
expectNetwork: "unix", expectNetwork: "unix",
expectAddrs: []string{"localhost:1234", "localhost:1235", "localhost:1236"}, expectAddrs: []string{"/foo/bar"},
}, },
{ {
input: "localhost:1234-1234", input: "localhost:1234-1234",
@ -185,7 +194,7 @@ func TestParseListenerAddr(t *testing.T) {
expectAddrs: []string{"localhost:0"}, expectAddrs: []string{"localhost:0"},
}, },
} { } {
actualNetwork, actualAddrs, err := parseListenAddr(tc.input) actualNetwork, actualAddrs, err := ParseListenAddr(tc.input)
if tc.expectErr && err == nil { if tc.expectErr && err == nil {
t.Errorf("Test %d: Expected error but got: %v", i, err) t.Errorf("Test %d: Expected error but got: %v", i, err)
} }

View file

@ -96,7 +96,7 @@ func (app *App) Validate() error {
lnAddrs := make(map[string]string) lnAddrs := make(map[string]string)
for srvName, srv := range app.Servers { for srvName, srv := range app.Servers {
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, expanded, err := parseListenAddr(addr) netw, expanded, err := caddy.ParseListenAddr(addr)
if err != nil { if err != nil {
return fmt.Errorf("invalid listener address '%s': %v", addr, err) return fmt.Errorf("invalid listener address '%s': %v", addr, err)
} }
@ -137,7 +137,7 @@ func (app *App) Start() error {
} }
for _, lnAddr := range srv.Listen { for _, lnAddr := range srv.Listen {
network, addrs, err := parseListenAddr(lnAddr) network, addrs, err := caddy.ParseListenAddr(lnAddr)
if err != nil { if err != nil {
return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err) return fmt.Errorf("%s: parsing listen address '%s': %v", srvName, lnAddr, err)
} }
@ -289,7 +289,7 @@ func (app *App) automaticHTTPS() error {
// create HTTP->HTTPS redirects // create HTTP->HTTPS redirects
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, host, port, err := splitListenAddr(addr) netw, host, port, err := caddy.SplitListenAddr(addr)
if err != nil { if err != nil {
return fmt.Errorf("%s: invalid listener address: %v", srvName, addr) return fmt.Errorf("%s: invalid listener address: %v", srvName, addr)
} }
@ -298,7 +298,7 @@ func (app *App) automaticHTTPS() error {
if httpPort == 0 { if httpPort == 0 {
httpPort = DefaultHTTPPort httpPort = DefaultHTTPPort
} }
httpRedirLnAddr := joinListenAddr(netw, host, strconv.Itoa(httpPort)) httpRedirLnAddr := caddy.JoinListenAddr(netw, host, strconv.Itoa(httpPort))
lnAddrMap[httpRedirLnAddr] = struct{}{} lnAddrMap[httpRedirLnAddr] = struct{}{}
if parts := strings.SplitN(port, "-", 2); len(parts) == 2 { if parts := strings.SplitN(port, "-", 2); len(parts) == 2 {
@ -339,7 +339,7 @@ func (app *App) automaticHTTPS() error {
var lnAddrs []string var lnAddrs []string
mapLoop: mapLoop:
for addr := range lnAddrMap { for addr := range lnAddrMap {
netw, addrs, err := parseListenAddr(addr) netw, addrs, err := caddy.ParseListenAddr(addr)
if err != nil { if err != nil {
continue continue
} }
@ -364,7 +364,7 @@ func (app *App) automaticHTTPS() error {
func (app *App) listenerTaken(network, address string) bool { func (app *App) listenerTaken(network, address string) bool {
for _, srv := range app.Servers { for _, srv := range app.Servers {
for _, addr := range srv.Listen { for _, addr := range srv.Listen {
netw, addrs, err := parseListenAddr(addr) netw, addrs, err := caddy.ParseListenAddr(addr)
if err != nil || netw != network { if err != nil || netw != network {
continue continue
} }
@ -425,59 +425,6 @@ func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
// sometimes better than a nil Handler pointer. // sometimes better than a nil Handler pointer.
var emptyHandler HandlerFunc = func(w http.ResponseWriter, r *http.Request) error { return nil } var emptyHandler HandlerFunc = func(w http.ResponseWriter, r *http.Request) error { return nil }
func parseListenAddr(a string) (network string, addrs []string, err error) {
var host, port string
network, host, port, err = splitListenAddr(a)
if network == "" {
network = "tcp"
}
if err != nil {
return
}
ports := strings.SplitN(port, "-", 2)
if len(ports) == 1 {
ports = append(ports, ports[0])
}
var start, end int
start, err = strconv.Atoi(ports[0])
if err != nil {
return
}
end, err = strconv.Atoi(ports[1])
if err != nil {
return
}
if end < start {
err = fmt.Errorf("end port must be greater than start port")
return
}
for p := start; p <= end; p++ {
addrs = append(addrs, net.JoinHostPort(host, fmt.Sprintf("%d", p)))
}
return
}
func splitListenAddr(a string) (network, host, port string, err error) {
if idx := strings.Index(a, "/"); idx >= 0 {
network = strings.ToLower(strings.TrimSpace(a[:idx]))
a = a[idx+1:]
}
host, port, err = net.SplitHostPort(a)
return
}
func joinListenAddr(network, host, port string) string {
var a string
if network != "" {
a = network + "/"
}
a += host
if port != "" {
a += ":" + port
}
return a
}
const ( const (
// DefaultHTTPPort is the default port for HTTP. // DefaultHTTPPort is the default port for HTTP.
DefaultHTTPPort = 80 DefaultHTTPPort = 80

View file

@ -156,7 +156,7 @@ func (s *Server) enforcementHandler(w http.ResponseWriter, r *http.Request, next
func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool { func (s *Server) listenersUseAnyPortOtherThan(otherPort int) bool {
for _, lnAddr := range s.Listen { for _, lnAddr := range s.Listen {
_, addrs, err := parseListenAddr(lnAddr) _, addrs, err := caddy.ParseListenAddr(lnAddr)
if err == nil { if err == nil {
for _, a := range addrs { for _, a := range addrs {
_, port, err := net.SplitHostPort(a) _, port, err := net.SplitHostPort(a)