reverseproxy: Expand SRV/A addrs for cache key

Hopefully fix #4645
This commit is contained in:
Matthew Holt 2022-03-18 13:42:29 -06:00
parent 93c99f6734
commit dc4d147388
No known key found for this signature in database
GPG key ID: 2A349DD577D586A5

View file

@ -72,11 +72,6 @@ func (SRVUpstreams) CaddyModule() caddy.ModuleInfo {
}
}
// String returns the RFC 2782 representation of the SRV domain.
func (su SRVUpstreams) String() string {
return fmt.Sprintf("_%s._%s.%s", su.Service, su.Proto, su.Name)
}
func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
su.logger = ctx.Logger(su)
if su.Refresh == 0 {
@ -109,11 +104,11 @@ func (su *SRVUpstreams) Provision(ctx caddy.Context) error {
}
func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
suStr := su.String()
suAddr, service, proto, name := su.expandedAddr(r)
// first, use a cheap read-lock to return a cached result quickly
srvsMu.RLock()
cached := srvs[suStr]
cached := srvs[suAddr]
srvsMu.RUnlock()
if cached.isFresh() {
return cached.upstreams, nil
@ -126,17 +121,11 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
// check to see if it's still stale, since we're now in a different
// lock from when we first checked freshness; another goroutine might
// have refreshed it in the meantime before we re-obtained our lock
cached = srvs[suStr]
cached = srvs[suAddr]
if cached.isFresh() {
return cached.upstreams, nil
}
// prepare parameters and perform the SRV lookup
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
service := repl.ReplaceAll(su.Service, "")
proto := repl.ReplaceAll(su.Proto, "")
name := repl.ReplaceAll(su.Name, "")
su.logger.Debug("refreshing SRV upstreams",
zap.String("service", service),
zap.String("proto", proto),
@ -172,7 +161,7 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
}
}
srvs[suStr] = srvLookup{
srvs[suAddr] = srvLookup{
srvUpstreams: su,
freshness: time.Now(),
upstreams: upstreams,
@ -181,6 +170,37 @@ func (su SRVUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
return upstreams, nil
}
func (su SRVUpstreams) String() string {
if su.Service == "" && su.Proto == "" {
return su.Name
}
return su.formattedAddr(su.Service, su.Proto, su.Name)
}
// expandedAddr expands placeholders in the configured SRV domain labels.
// The return values are: addr, the RFC 2782 representation of the SRV domain;
// service, the service; proto, the protocol; and name, the name.
// If su.Service and su.Proto are empty, name will be returned as addr instead.
func (su SRVUpstreams) expandedAddr(r *http.Request) (addr, service, proto, name string) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
name = repl.ReplaceAll(su.Name, "")
if su.Service == "" && su.Proto == "" {
addr = name
name = ""
return
}
service = repl.ReplaceAll(su.Service, "")
proto = repl.ReplaceAll(su.Proto, "")
addr = su.formattedAddr(service, proto, name)
return
}
// formattedAddr the RFC 2782 representation of the SRV domain, in
// the form "_service._proto.name".
func (SRVUpstreams) formattedAddr(service, proto, name string) string {
return fmt.Sprintf("_%s._%s.%s", service, proto, name)
}
type srvLookup struct {
srvUpstreams SRVUpstreams
freshness time.Time
@ -234,8 +254,6 @@ func (AUpstreams) CaddyModule() caddy.ModuleInfo {
}
}
func (au AUpstreams) String() string { return au.Name }
func (au *AUpstreams) Provision(_ caddy.Context) error {
if au.Refresh == 0 {
au.Refresh = caddy.Duration(time.Minute)
@ -270,7 +288,8 @@ func (au *AUpstreams) Provision(_ caddy.Context) error {
}
func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
auStr := au.String()
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
auStr := repl.ReplaceAll(au.String(), "")
// first, use a cheap read-lock to return a cached result quickly
aAaaaMu.RLock()
@ -292,7 +311,6 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
return cached.upstreams, nil
}
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
name := repl.ReplaceAll(au.Name, "")
port := repl.ReplaceAll(au.Port, "")
@ -325,6 +343,8 @@ func (au AUpstreams) GetUpstreams(r *http.Request) ([]*Upstream, error) {
return upstreams, nil
}
func (au AUpstreams) String() string { return au.Name }
type aLookup struct {
aUpstreams AUpstreams
freshness time.Time