diff --git a/listen_unix.go b/listen_unix.go index 7011a1e6..e31ecac7 100644 --- a/listen_unix.go +++ b/listen_unix.go @@ -108,6 +108,15 @@ func listenTCPOrUnix(ctx context.Context, lnKey string, network, address string, listenerPool.LoadOrStore(lnKey, nil) } + // if new listener is a unix socket, make sure we can reuse it later + // (we do our own "unlink on close" -- not required, but more tidy) + one := int32(1) + if unix, ok := ln.(*net.UnixListener); ok { + unix.SetUnlinkOnClose(false) + ln = &unixListener{unix, lnKey, &one} + unixSockets[lnKey] = ln.(*unixListener) + } + // lightly wrap the listener so that when it is closed, // we can decrement the usage pool counter return deleteListener{ln, lnKey}, err diff --git a/listeners.go b/listeners.go index 1429b14e..08bdbcf7 100644 --- a/listeners.go +++ b/listeners.go @@ -186,19 +186,11 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net return nil, fmt.Errorf("unsupported network type: %s", na.Network) } - // if new listener is a unix socket, make sure we can reuse it later - // (we do our own "unlink on close" -- not required, but more tidy) - one := int32(1) - switch lnValue := ln.(type) { - case deleteListener: - if unix, ok := lnValue.Listener.(*net.UnixListener); ok { - unix.SetUnlinkOnClose(false) - ln = &unixListener{unix, lnKey, &one} - unixSockets[lnKey] = ln.(*unixListener) - } - case *net.UnixConn: - ln = &unixConn{lnValue, address, lnKey, &one} - unixSockets[lnKey] = ln.(*unixConn) + // TODO: Not 100% sure this is necessary, but we do this for net.UnixListener in listen_unix.go, so... + if unix, ok := ln.(*net.UnixConn); ok { + one := int32(1) + ln = &unixConn{unix, address, lnKey, &one} + unixSockets[lnKey] = unix } return ln, nil