mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-21 01:45:45 +03:00
Merge branch 'master' into macros
This commit is contained in:
commit
8658e189e1
15 changed files with 648 additions and 167 deletions
5
caddy.go
5
caddy.go
|
@ -518,6 +518,11 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
|
||||||
}
|
}
|
||||||
if !Quiet {
|
if !Quiet {
|
||||||
for _, srvln := range inst.servers {
|
for _, srvln := range inst.servers {
|
||||||
|
// only show FD notice if the listener is not nil.
|
||||||
|
// This can happen when only serving UDP or TCP
|
||||||
|
if srvln.listener == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !IsLoopback(srvln.listener.Addr().String()) {
|
if !IsLoopback(srvln.listener.Addr().String()) {
|
||||||
checkFdlimit()
|
checkFdlimit()
|
||||||
break
|
break
|
||||||
|
|
|
@ -214,6 +214,9 @@ func SameNext(next1, next2 Handler) bool {
|
||||||
|
|
||||||
// Context key constants.
|
// Context key constants.
|
||||||
const (
|
const (
|
||||||
|
// ReplacerCtxKey is the context key for a per-request replacer.
|
||||||
|
ReplacerCtxKey caddy.CtxKey = "replacer"
|
||||||
|
|
||||||
// RemoteUserCtxKey is the key for the remote user of the request, if any (basicauth).
|
// RemoteUserCtxKey is the key for the remote user of the request, if any (basicauth).
|
||||||
RemoteUserCtxKey caddy.CtxKey = "remote_user"
|
RemoteUserCtxKey caddy.CtxKey = "remote_user"
|
||||||
|
|
||||||
|
|
|
@ -102,20 +102,30 @@ func (lw *limitWriter) String() string {
|
||||||
// emptyValue should be the string that is used in place
|
// emptyValue should be the string that is used in place
|
||||||
// of empty string (can still be empty string).
|
// of empty string (can still be empty string).
|
||||||
func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer {
|
func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer {
|
||||||
rb := newLimitWriter(MaxLogBodySize)
|
repl := &replacer{
|
||||||
if r.Body != nil {
|
request: r,
|
||||||
r.Body = struct {
|
responseRecorder: rr,
|
||||||
io.Reader
|
emptyValue: emptyValue,
|
||||||
io.Closer
|
|
||||||
}{io.TeeReader(r.Body, rb), io.Closer(r.Body)}
|
|
||||||
}
|
}
|
||||||
return &replacer{
|
|
||||||
request: r,
|
// extract customReplacements from a request replacer when present.
|
||||||
requestBody: rb,
|
if existing, ok := r.Context().Value(ReplacerCtxKey).(*replacer); ok {
|
||||||
responseRecorder: rr,
|
repl.requestBody = existing.requestBody
|
||||||
customReplacements: make(map[string]string),
|
repl.customReplacements = existing.customReplacements
|
||||||
emptyValue: emptyValue,
|
} else {
|
||||||
|
// if there is no existing replacer, build one from scratch.
|
||||||
|
rb := newLimitWriter(MaxLogBodySize)
|
||||||
|
if r.Body != nil {
|
||||||
|
r.Body = struct {
|
||||||
|
io.Reader
|
||||||
|
io.Closer
|
||||||
|
}{io.TeeReader(r.Body, rb), io.Closer(r.Body)}
|
||||||
|
}
|
||||||
|
repl.requestBody = rb
|
||||||
|
repl.customReplacements = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return repl
|
||||||
}
|
}
|
||||||
|
|
||||||
func canLogRequest(r *http.Request) bool {
|
func canLogRequest(r *http.Request) bool {
|
||||||
|
|
|
@ -356,6 +356,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
c := context.WithValue(r.Context(), OriginalURLCtxKey, urlCopy)
|
c := context.WithValue(r.Context(), OriginalURLCtxKey, urlCopy)
|
||||||
r = r.WithContext(c)
|
r = r.WithContext(c)
|
||||||
|
|
||||||
|
// Setup a replacer for the request that keeps track of placeholder
|
||||||
|
// values across plugins.
|
||||||
|
replacer := NewReplacer(r, nil, "")
|
||||||
|
c = context.WithValue(r.Context(), ReplacerCtxKey, replacer)
|
||||||
|
r = r.WithContext(c)
|
||||||
|
|
||||||
w.Header().Set("Server", caddy.AppName)
|
w.Header().Set("Server", caddy.AppName)
|
||||||
|
|
||||||
status, _ := s.serveHTTP(w, r)
|
status, _ := s.serveHTTP(w, r)
|
||||||
|
|
|
@ -82,7 +82,8 @@ type UpstreamHost struct {
|
||||||
// This is an int32 so that we can use atomic operations to do concurrent
|
// This is an int32 so that we can use atomic operations to do concurrent
|
||||||
// reads & writes to this value. The default value of 0 indicates that it
|
// reads & writes to this value. The default value of 0 indicates that it
|
||||||
// is healthy and any non-zero value indicates unhealthy.
|
// is healthy and any non-zero value indicates unhealthy.
|
||||||
Unhealthy int32
|
Unhealthy int32
|
||||||
|
HealthCheckResult atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Down checks whether the upstream host is down or not.
|
// Down checks whether the upstream host is down or not.
|
||||||
|
|
|
@ -26,7 +26,9 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -91,6 +93,8 @@ type ReverseProxy struct {
|
||||||
// response body.
|
// response body.
|
||||||
// If zero, no periodic flushing is done.
|
// If zero, no periodic flushing is done.
|
||||||
FlushInterval time.Duration
|
FlushInterval time.Duration
|
||||||
|
|
||||||
|
srvResolver srvResolver
|
||||||
}
|
}
|
||||||
|
|
||||||
// Though the relevant directive prefix is just "unix:", url.Parse
|
// Though the relevant directive prefix is just "unix:", url.Parse
|
||||||
|
@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
|
||||||
|
service := locator
|
||||||
|
if strings.HasPrefix(locator, "srv://") {
|
||||||
|
service = locator[6:]
|
||||||
|
} else if strings.HasPrefix(locator, "srv+https://") {
|
||||||
|
service = locator[12:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(network, addr string) (conn net.Conn, err error) {
|
||||||
|
_, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func singleJoiningSlash(a, b string) string {
|
func singleJoiningSlash(a, b string) string {
|
||||||
aslash := strings.HasSuffix(a, "/")
|
aslash := strings.HasSuffix(a, "/")
|
||||||
bslash := strings.HasPrefix(b, "/")
|
bslash := strings.HasPrefix(b, "/")
|
||||||
|
@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
// scheme and host have to be faked
|
// scheme and host have to be faked
|
||||||
req.URL.Scheme = "http"
|
req.URL.Scheme = "http"
|
||||||
req.URL.Host = "socket"
|
req.URL.Host = "socket"
|
||||||
|
} else if target.Scheme == "srv" {
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
req.URL.Host = target.Host
|
||||||
|
} else if target.Scheme == "srv+https" {
|
||||||
|
req.URL.Scheme = "https"
|
||||||
|
req.URL.Host = target.Host
|
||||||
} else {
|
} else {
|
||||||
req.URL.Scheme = target.Scheme
|
req.URL.Scheme = target.Scheme
|
||||||
req.URL.Host = target.Host
|
req.URL.Host = target.Host
|
||||||
|
@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
|
rp := &ReverseProxy{
|
||||||
|
Director: director,
|
||||||
|
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
|
||||||
|
srvResolver: net.DefaultResolver,
|
||||||
|
}
|
||||||
|
|
||||||
if target.Scheme == "unix" {
|
if target.Scheme == "unix" {
|
||||||
rp.Transport = &http.Transport{
|
rp.Transport = &http.Transport{
|
||||||
Dial: socketDial(target.String()),
|
Dial: socketDial(target.String()),
|
||||||
|
@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||||
HandshakeTimeout: defaultCryptoHandshakeTimeout,
|
HandshakeTimeout: defaultCryptoHandshakeTimeout,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
} else if keepalive != http.DefaultMaxIdleConnsPerHost {
|
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
|
||||||
// if keepalive is equal to the default,
|
dialFunc := defaultDialer.Dial
|
||||||
// just use default transport, to avoid creating
|
if strings.HasPrefix(target.Scheme, "srv") {
|
||||||
// a brand new transport
|
dialFunc = rp.srvDialerFunc(target.String())
|
||||||
|
}
|
||||||
|
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
Dial: defaultDialer.Dial,
|
Dial: dialFunc,
|
||||||
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
|
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
94
caddyhttp/proxy/reverseproxy_test.go
Normal file
94
caddyhttp/proxy/reverseproxy_test.go
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
// Copyright 2015 Light Code Labs, LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
expectedResponse = "response from request proxied to upstream"
|
||||||
|
expectedStatus = http.StatusOK
|
||||||
|
)
|
||||||
|
|
||||||
|
var upstreamHost *httptest.Server
|
||||||
|
|
||||||
|
func setupTest() {
|
||||||
|
upstreamHost = httptest.NewServer(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/test-path" {
|
||||||
|
w.WriteHeader(expectedStatus)
|
||||||
|
w.Write([]byte(expectedResponse))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
w.Write([]byte("Not found"))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func tearDownTest() {
|
||||||
|
upstreamHost.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingleSRVHostReverseProxy(t *testing.T) {
|
||||||
|
setupTest()
|
||||||
|
defer tearDownTest()
|
||||||
|
|
||||||
|
target, err := url.Parse("srv://test.upstream.service")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse target URL. %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream, err := url.Parse(upstreamHost.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error())
|
||||||
|
}
|
||||||
|
pp, err := strconv.Atoi(upstream.Port())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error())
|
||||||
|
}
|
||||||
|
port := uint16(pp)
|
||||||
|
|
||||||
|
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
|
||||||
|
rp.srvResolver = testResolver{
|
||||||
|
result: []*net.SRV{
|
||||||
|
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
req, err := http.NewRequest("GET", "http://test.host/test-path", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to create new request. %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rp.ServeHTTP(resp, req, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Body.String() != expectedResponse {
|
||||||
|
t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Code != expectedStatus {
|
||||||
|
t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code)
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,7 @@ package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -65,6 +66,11 @@ type staticUpstream struct {
|
||||||
IgnoredSubPaths []string
|
IgnoredSubPaths []string
|
||||||
insecureSkipVerify bool
|
insecureSkipVerify bool
|
||||||
MaxFails int32
|
MaxFails int32
|
||||||
|
resolver srvResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
type srvResolver interface {
|
||||||
|
LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStaticUpstreams parses the configuration input and sets up
|
// NewStaticUpstreams parses the configuration input and sets up
|
||||||
|
@ -86,6 +92,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
|
||||||
TryInterval: 250 * time.Millisecond,
|
TryInterval: 250 * time.Millisecond,
|
||||||
MaxConns: 0,
|
MaxConns: 0,
|
||||||
KeepAlive: http.DefaultMaxIdleConnsPerHost,
|
KeepAlive: http.DefaultMaxIdleConnsPerHost,
|
||||||
|
resolver: net.DefaultResolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.Args(&upstream.from) {
|
if !c.Args(&upstream.from) {
|
||||||
|
@ -93,7 +100,21 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
var to []string
|
var to []string
|
||||||
|
hasSrv := false
|
||||||
|
|
||||||
for _, t := range c.RemainingArgs() {
|
for _, t := range c.RemainingArgs() {
|
||||||
|
if len(to) > 0 && hasSrv {
|
||||||
|
return upstreams, c.Err("only one upstream is supported when using SRV locator")
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(t, "srv://") || strings.HasPrefix(t, "srv+https://") {
|
||||||
|
if len(to) > 0 {
|
||||||
|
return upstreams, c.Err("service locator upstreams can not be mixed with host names")
|
||||||
|
}
|
||||||
|
|
||||||
|
hasSrv = true
|
||||||
|
}
|
||||||
|
|
||||||
parsed, err := parseUpstream(t)
|
parsed, err := parseUpstream(t)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return upstreams, err
|
return upstreams, err
|
||||||
|
@ -107,13 +128,18 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
|
||||||
if !c.NextArg() {
|
if !c.NextArg() {
|
||||||
return upstreams, c.ArgErr()
|
return upstreams, c.ArgErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasSrv {
|
||||||
|
return upstreams, c.Err("upstream directive is not supported when backend is service locator")
|
||||||
|
}
|
||||||
|
|
||||||
parsed, err := parseUpstream(c.Val())
|
parsed, err := parseUpstream(c.Val())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return upstreams, err
|
return upstreams, err
|
||||||
}
|
}
|
||||||
to = append(to, parsed...)
|
to = append(to, parsed...)
|
||||||
default:
|
default:
|
||||||
if err := parseBlock(&c, upstream); err != nil {
|
if err := parseBlock(&c, upstream, hasSrv); err != nil {
|
||||||
return upstreams, err
|
return upstreams, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -165,7 +191,9 @@ func (u *staticUpstream) From() string {
|
||||||
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
||||||
if !strings.HasPrefix(host, "http") &&
|
if !strings.HasPrefix(host, "http") &&
|
||||||
!strings.HasPrefix(host, "unix:") &&
|
!strings.HasPrefix(host, "unix:") &&
|
||||||
!strings.HasPrefix(host, "quic:") {
|
!strings.HasPrefix(host, "quic:") &&
|
||||||
|
!strings.HasPrefix(host, "srv://") &&
|
||||||
|
!strings.HasPrefix(host, "srv+https://") {
|
||||||
host = "http://" + host
|
host = "http://" + host
|
||||||
}
|
}
|
||||||
uh := &UpstreamHost{
|
uh := &UpstreamHost{
|
||||||
|
@ -189,6 +217,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
||||||
}(u),
|
}(u),
|
||||||
WithoutPathPrefix: u.WithoutPathPrefix,
|
WithoutPathPrefix: u.WithoutPathPrefix,
|
||||||
MaxConns: u.MaxConns,
|
MaxConns: u.MaxConns,
|
||||||
|
HealthCheckResult: atomic.Value{},
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL, err := url.Parse(uh.Name)
|
baseURL, err := url.Parse(uh.Name)
|
||||||
|
@ -205,50 +234,65 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseUpstream(u string) ([]string, error) {
|
func parseUpstream(u string) ([]string, error) {
|
||||||
if !strings.HasPrefix(u, "unix:") {
|
if strings.HasPrefix(u, "unix:") {
|
||||||
colonIdx := strings.LastIndex(u, ":")
|
return []string{u}, nil
|
||||||
protoIdx := strings.Index(u, "://")
|
|
||||||
|
|
||||||
if colonIdx != -1 && colonIdx != protoIdx {
|
|
||||||
us := u[:colonIdx]
|
|
||||||
ue := ""
|
|
||||||
portsEnd := len(u)
|
|
||||||
if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 {
|
|
||||||
portsEnd = colonIdx + nextSlash
|
|
||||||
ue = u[portsEnd:]
|
|
||||||
}
|
|
||||||
ports := u[len(us)+1 : portsEnd]
|
|
||||||
|
|
||||||
if separators := strings.Count(ports, "-"); separators == 1 {
|
|
||||||
portsStr := strings.Split(ports, "-")
|
|
||||||
pIni, err := strconv.Atoi(portsStr[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pEnd, err := strconv.Atoi(portsStr[1])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if pEnd <= pIni {
|
|
||||||
return nil, fmt.Errorf("port range [%s] is invalid", ports)
|
|
||||||
}
|
|
||||||
|
|
||||||
hosts := []string{}
|
|
||||||
for p := pIni; p <= pEnd; p++ {
|
|
||||||
hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue))
|
|
||||||
}
|
|
||||||
return hosts, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{u}, nil
|
isSrv := strings.HasPrefix(u, "srv://") || strings.HasPrefix(u, "srv+https://")
|
||||||
|
colonIdx := strings.LastIndex(u, ":")
|
||||||
|
protoIdx := strings.Index(u, "://")
|
||||||
|
|
||||||
|
if colonIdx == -1 || colonIdx == protoIdx {
|
||||||
|
return []string{u}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isSrv {
|
||||||
|
return nil, fmt.Errorf("service locator %s can not have port specified", u)
|
||||||
|
}
|
||||||
|
|
||||||
|
us := u[:colonIdx]
|
||||||
|
ue := ""
|
||||||
|
portsEnd := len(u)
|
||||||
|
if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 {
|
||||||
|
portsEnd = colonIdx + nextSlash
|
||||||
|
ue = u[portsEnd:]
|
||||||
|
}
|
||||||
|
|
||||||
|
ports := u[len(us)+1 : portsEnd]
|
||||||
|
separators := strings.Count(ports, "-")
|
||||||
|
|
||||||
|
if separators == 0 {
|
||||||
|
return []string{u}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if separators > 1 {
|
||||||
|
return nil, fmt.Errorf("port range [%s] has %d separators", ports, separators)
|
||||||
|
}
|
||||||
|
|
||||||
|
portsStr := strings.Split(ports, "-")
|
||||||
|
pIni, err := strconv.Atoi(portsStr[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pEnd, err := strconv.Atoi(portsStr[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if pEnd <= pIni {
|
||||||
|
return nil, fmt.Errorf("port range [%s] is invalid", ports)
|
||||||
|
}
|
||||||
|
|
||||||
|
hosts := []string{}
|
||||||
|
for p := pIni; p <= pEnd; p++ {
|
||||||
|
hosts = append(hosts, fmt.Sprintf("%s:%d%s", us, p, ue))
|
||||||
|
}
|
||||||
|
|
||||||
|
return hosts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
|
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
|
||||||
switch c.Val() {
|
switch c.Val() {
|
||||||
case "policy":
|
case "policy":
|
||||||
if !c.NextArg() {
|
if !c.NextArg() {
|
||||||
|
@ -348,6 +392,11 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
|
||||||
if !c.NextArg() {
|
if !c.NextArg() {
|
||||||
return c.ArgErr()
|
return c.ArgErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasSrv {
|
||||||
|
return c.Err("health_check_port directive is not allowed when upstream is SRV locator")
|
||||||
|
}
|
||||||
|
|
||||||
port := c.Val()
|
port := c.Val()
|
||||||
n, err := strconv.Atoi(port)
|
n, err := strconv.Atoi(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -420,54 +469,94 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *staticUpstream) resolveHost(h string) ([]string, bool, error) {
|
||||||
|
names := []string{}
|
||||||
|
proto := "http"
|
||||||
|
if !strings.HasPrefix(h, "srv://") && !strings.HasPrefix(h, "srv+https://") {
|
||||||
|
return []string{h}, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(h, "srv+https://") {
|
||||||
|
proto = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
_, addrs, err := u.resolver.LookupSRV(context.Background(), "", "", h)
|
||||||
|
if err != nil {
|
||||||
|
return names, true, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
names = append(names, fmt.Sprintf("%s://%s:%d", proto, addr.Target, addr.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
return names, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (u *staticUpstream) healthCheck() {
|
func (u *staticUpstream) healthCheck() {
|
||||||
for _, host := range u.Hosts {
|
for _, host := range u.Hosts {
|
||||||
hostURL := host.Name
|
candidates, isSrv, err := u.resolveHost(host.Name)
|
||||||
if u.HealthCheck.Port != "" {
|
if err != nil {
|
||||||
hostURL = replacePort(host.Name, u.HealthCheck.Port)
|
host.HealthCheckResult.Store(err.Error())
|
||||||
}
|
|
||||||
hostURL += u.HealthCheck.Path
|
|
||||||
|
|
||||||
unhealthy := func() bool {
|
|
||||||
// set up request, needed to be able to modify headers
|
|
||||||
// possible errors are bad HTTP methods or un-parsable urls
|
|
||||||
req, err := http.NewRequest("GET", hostURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// set host for request going upstream
|
|
||||||
if u.HealthCheck.Host != "" {
|
|
||||||
req.Host = u.HealthCheck.Host
|
|
||||||
}
|
|
||||||
r, err := u.HealthCheck.Client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
io.Copy(ioutil.Discard, r.Body)
|
|
||||||
r.Body.Close()
|
|
||||||
}()
|
|
||||||
if r.StatusCode < 200 || r.StatusCode >= 400 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if u.HealthCheck.ContentString == "" { // don't check for content string
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// TODO ReadAll will be replaced if deemed necessary
|
|
||||||
// See https://github.com/mholt/caddy/pull/1691
|
|
||||||
buf, err := ioutil.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}()
|
|
||||||
if unhealthy {
|
|
||||||
atomic.StoreInt32(&host.Unhealthy, 1)
|
atomic.StoreInt32(&host.Unhealthy, 1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
unhealthyCount := 0
|
||||||
|
for _, addr := range candidates {
|
||||||
|
hostURL := addr
|
||||||
|
if !isSrv && u.HealthCheck.Port != "" {
|
||||||
|
hostURL = replacePort(hostURL, u.HealthCheck.Port)
|
||||||
|
}
|
||||||
|
hostURL += u.HealthCheck.Path
|
||||||
|
|
||||||
|
unhealthy := func() bool {
|
||||||
|
// set up request, needed to be able to modify headers
|
||||||
|
// possible errors are bad HTTP methods or un-parsable urls
|
||||||
|
req, err := http.NewRequest("GET", hostURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// set host for request going upstream
|
||||||
|
if u.HealthCheck.Host != "" {
|
||||||
|
req.Host = u.HealthCheck.Host
|
||||||
|
}
|
||||||
|
r, err := u.HealthCheck.Client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
io.Copy(ioutil.Discard, r.Body)
|
||||||
|
r.Body.Close()
|
||||||
|
}()
|
||||||
|
if r.StatusCode < 200 || r.StatusCode >= 400 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if u.HealthCheck.ContentString == "" { // don't check for content string
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// TODO ReadAll will be replaced if deemed necessary
|
||||||
|
// See https://github.com/mholt/caddy/pull/1691
|
||||||
|
buf, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if bytes.Contains(buf, []byte(u.HealthCheck.ContentString)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}()
|
||||||
|
|
||||||
|
if unhealthy {
|
||||||
|
unhealthyCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if unhealthyCount == len(candidates) {
|
||||||
|
atomic.StoreInt32(&host.Unhealthy, 1)
|
||||||
|
host.HealthCheckResult.Store("Failed")
|
||||||
} else {
|
} else {
|
||||||
atomic.StoreInt32(&host.Unhealthy, 0)
|
atomic.StoreInt32(&host.Unhealthy, 0)
|
||||||
|
host.HealthCheckResult.Store("OK")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,10 +15,15 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -187,7 +192,7 @@ func TestParseBlockHealthCheck(t *testing.T) {
|
||||||
u := staticUpstream{}
|
u := staticUpstream{}
|
||||||
c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config))
|
c := caddyfile.NewDispenser("Testfile", strings.NewReader(test.config))
|
||||||
for c.Next() {
|
for c.Next() {
|
||||||
parseBlock(&c, &u)
|
parseBlock(&c, &u, false)
|
||||||
}
|
}
|
||||||
if u.HealthCheck.Interval.String() != test.interval {
|
if u.HealthCheck.Interval.String() != test.interval {
|
||||||
t.Errorf(
|
t.Errorf(
|
||||||
|
@ -551,3 +556,216 @@ func TestQuicHost(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseSRVBlock(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
config string
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{"proxy / srv://bogus.service", false},
|
||||||
|
{"proxy / srv://bogus.service:80", true},
|
||||||
|
{"proxy / srv://bogus.service srv://bogus.service.fallback", true},
|
||||||
|
{"proxy / srv://bogus.service http://bogus.service.fallback", true},
|
||||||
|
{"proxy / http://bogus.service srv://bogus.service.fallback", true},
|
||||||
|
{"proxy / srv://bogus.service bogus.service.fallback", true},
|
||||||
|
{`proxy / srv://bogus.service {
|
||||||
|
upstream srv://bogus.service
|
||||||
|
}`, true},
|
||||||
|
{"proxy / srv+https://bogus.service", false},
|
||||||
|
{"proxy / srv+https://bogus.service:80", true},
|
||||||
|
{"proxy / srv+https://bogus.service srv://bogus.service.fallback", true},
|
||||||
|
{"proxy / srv+https://bogus.service http://bogus.service.fallback", true},
|
||||||
|
{"proxy / http://bogus.service srv+https://bogus.service.fallback", true},
|
||||||
|
{"proxy / srv+https://bogus.service bogus.service.fallback", true},
|
||||||
|
{`proxy / srv+https://bogus.service {
|
||||||
|
upstream srv://bogus.service
|
||||||
|
}`, true},
|
||||||
|
{`proxy / srv+https://bogus.service {
|
||||||
|
health_check_port 96
|
||||||
|
}`, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
_, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(test.config)), "")
|
||||||
|
if err == nil && test.shouldErr {
|
||||||
|
t.Errorf("Case %d - Expected an error. got nothing", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && !test.shouldErr {
|
||||||
|
t.Errorf("Case %d - Expected no error. got %s", i, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testResolver struct {
|
||||||
|
errOn string
|
||||||
|
result []*net.SRV
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r testResolver) LookupSRV(ctx context.Context, _, _, service string) (string, []*net.SRV, error) {
|
||||||
|
if service == r.errOn {
|
||||||
|
return "", nil, errors.New("an error occurred")
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", r.result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveHost(t *testing.T) {
|
||||||
|
upstream := &staticUpstream{
|
||||||
|
resolver: testResolver{
|
||||||
|
errOn: "srv://problematic.service.name",
|
||||||
|
result: []*net.SRV{
|
||||||
|
{Target: "target-1.fqdn", Port: 85, Priority: 1, Weight: 1},
|
||||||
|
{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
|
||||||
|
{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
host string
|
||||||
|
expect []string
|
||||||
|
isSrv bool
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
// Static DNS records
|
||||||
|
{"http://subdomain.domain.service",
|
||||||
|
[]string{"http://subdomain.domain.service"},
|
||||||
|
false,
|
||||||
|
false},
|
||||||
|
{"https://subdomain.domain.service",
|
||||||
|
[]string{"https://subdomain.domain.service"},
|
||||||
|
false,
|
||||||
|
false},
|
||||||
|
{"http://subdomain.domain.service:76",
|
||||||
|
[]string{"http://subdomain.domain.service:76"},
|
||||||
|
false,
|
||||||
|
false},
|
||||||
|
{"https://subdomain.domain.service:65",
|
||||||
|
[]string{"https://subdomain.domain.service:65"},
|
||||||
|
false,
|
||||||
|
false},
|
||||||
|
|
||||||
|
// SRV lookups
|
||||||
|
{"srv://service.name", []string{
|
||||||
|
"http://target-1.fqdn:85",
|
||||||
|
"http://target-2.fqdn:33",
|
||||||
|
"http://target-3.fqdn:94",
|
||||||
|
}, true, false},
|
||||||
|
{"srv+https://service.name", []string{
|
||||||
|
"https://target-1.fqdn:85",
|
||||||
|
"https://target-2.fqdn:33",
|
||||||
|
"https://target-3.fqdn:94",
|
||||||
|
}, true, false},
|
||||||
|
{"srv://problematic.service.name", []string{}, true, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
results, isSrv, err := upstream.resolveHost(test.host)
|
||||||
|
if err == nil && test.shouldErr {
|
||||||
|
t.Errorf("Test %d - expected an error, got none", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && !test.shouldErr {
|
||||||
|
t.Errorf("Test %d - unexpected error %s", i, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.isSrv && !isSrv {
|
||||||
|
t.Errorf("Test %d - expecting resolution to be SRV lookup but it isn't", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isSrv && !test.isSrv {
|
||||||
|
t.Errorf("Test %d - expecting resolution to be normal lookup, got SRV", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(results, test.expect) {
|
||||||
|
t.Errorf("Test %d - resolution result %#v does not match expected value %#v", i, results, test.expect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSRVHealthCheck(t *testing.T) {
|
||||||
|
serverURL, err := url.Parse(workableServer.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse test server URL: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
pp, err := strconv.Atoi(serverURL.Port())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to parse test server port [%s]: %s", serverURL.Port(), err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
port := uint16(pp)
|
||||||
|
|
||||||
|
allGoodResolver := testResolver{
|
||||||
|
result: []*net.SRV{
|
||||||
|
{Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
partialFailureResolver := testResolver{
|
||||||
|
result: []*net.SRV{
|
||||||
|
{Target: serverURL.Hostname(), Port: port, Priority: 1, Weight: 1},
|
||||||
|
{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
|
||||||
|
{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fullFailureResolver := testResolver{
|
||||||
|
result: []*net.SRV{
|
||||||
|
{Target: "target-1.fqdn", Port: 876, Priority: 1, Weight: 1},
|
||||||
|
{Target: "target-2.fqdn", Port: 33, Priority: 1, Weight: 1},
|
||||||
|
{Target: "target-3.fqdn", Port: 94, Priority: 1, Weight: 1},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resolutionErrorResolver := testResolver{
|
||||||
|
errOn: "srv://tag.service.consul",
|
||||||
|
result: []*net.SRV{},
|
||||||
|
}
|
||||||
|
|
||||||
|
upstream := &staticUpstream{
|
||||||
|
Hosts: []*UpstreamHost{
|
||||||
|
{Name: "srv://tag.service.consul"},
|
||||||
|
},
|
||||||
|
FailTimeout: 10 * time.Second,
|
||||||
|
MaxFails: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
resolver testResolver
|
||||||
|
shouldFail bool
|
||||||
|
shouldErr bool
|
||||||
|
}{
|
||||||
|
{allGoodResolver, false, false},
|
||||||
|
{partialFailureResolver, false, false},
|
||||||
|
{fullFailureResolver, true, false},
|
||||||
|
{resolutionErrorResolver, true, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
upstream.resolver = test.resolver
|
||||||
|
upstream.healthCheck()
|
||||||
|
if upstream.Hosts[0].Down() && !test.shouldFail {
|
||||||
|
t.Errorf("Test %d - expected all healthchecks to pass, all failing", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.shouldFail && !upstream.Hosts[0].Down() {
|
||||||
|
t.Errorf("Test %d - expected all healthchecks to fail, all passing", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
status := fmt.Sprintf("%s", upstream.Hosts[0].HealthCheckResult.Load())
|
||||||
|
|
||||||
|
if test.shouldFail && !test.shouldErr && status != "Failed" {
|
||||||
|
t.Errorf("Test %d - Expected health check result to be 'Failed', got '%s'", i, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !test.shouldFail && status != "OK" {
|
||||||
|
t.Errorf("Test %d - Expected health check result to be 'OK', got '%s'", i, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if test.shouldErr && status != "an error occurred" {
|
||||||
|
t.Errorf("Test %d - Expected health check result to be 'an error occured', got '%s'", i, status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ type ACMEClient struct {
|
||||||
AllowPrompts bool
|
AllowPrompts bool
|
||||||
config *Config
|
config *Config
|
||||||
acmeClient *acme.Client
|
acmeClient *acme.Client
|
||||||
|
locker Locker
|
||||||
}
|
}
|
||||||
|
|
||||||
// newACMEClient creates a new ACMEClient given an email and whether
|
// newACMEClient creates a new ACMEClient given an email and whether
|
||||||
|
@ -120,6 +121,10 @@ var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error)
|
||||||
AllowPrompts: allowPrompts,
|
AllowPrompts: allowPrompts,
|
||||||
config: config,
|
config: config,
|
||||||
acmeClient: client,
|
acmeClient: client,
|
||||||
|
locker: &syncLock{
|
||||||
|
nameLocks: make(map[string]*sync.WaitGroup),
|
||||||
|
nameLocksMu: sync.Mutex{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.DNSProvider == "" {
|
if config.DNSProvider == "" {
|
||||||
|
@ -210,7 +215,7 @@ func (c *ACMEClient) Obtain(name string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
waiter, err := storage.TryLock(name)
|
waiter, err := c.locker.TryLock(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -220,7 +225,7 @@ func (c *ACMEClient) Obtain(name string) error {
|
||||||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := storage.Unlock(name); err != nil {
|
if err := c.locker.Unlock(name); err != nil {
|
||||||
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
|
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -286,7 +291,7 @@ func (c *ACMEClient) Renew(name string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
waiter, err := storage.TryLock(name)
|
waiter, err := c.locker.TryLock(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -296,7 +301,7 @@ func (c *ACMEClient) Renew(name string) error {
|
||||||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := storage.Unlock(name); err != nil {
|
if err := c.locker.Unlock(name); err != nil {
|
||||||
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
|
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/mholt/caddy"
|
"github.com/mholt/caddy"
|
||||||
)
|
)
|
||||||
|
@ -40,8 +39,7 @@ var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
|
||||||
// instance is guaranteed to be non-nil if there is no error.
|
// instance is guaranteed to be non-nil if there is no error.
|
||||||
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
||||||
return &FileStorage{
|
return &FileStorage{
|
||||||
Path: filepath.Join(storageBasePath, caURL.Host),
|
Path: filepath.Join(storageBasePath, caURL.Host),
|
||||||
nameLocks: make(map[string]*sync.WaitGroup),
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,9 +47,7 @@ func NewFileStorage(caURL *url.URL) (Storage, error) {
|
||||||
// directory. It is used to get file paths in a consistent,
|
// directory. It is used to get file paths in a consistent,
|
||||||
// cross-platform way or persisting ACME assets on the file system.
|
// cross-platform way or persisting ACME assets on the file system.
|
||||||
type FileStorage struct {
|
type FileStorage struct {
|
||||||
Path string
|
Path string
|
||||||
nameLocks map[string]*sync.WaitGroup
|
|
||||||
nameLocksMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sites gets the directory that stores site certificate and keys.
|
// sites gets the directory that stores site certificate and keys.
|
||||||
|
@ -254,36 +250,6 @@ func (s *FileStorage) StoreUser(email string, data *UserData) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TryLock attempts to get a lock for name, otherwise it returns
|
|
||||||
// a Waiter value to wait until the other process is finished.
|
|
||||||
func (s *FileStorage) TryLock(name string) (Waiter, error) {
|
|
||||||
s.nameLocksMu.Lock()
|
|
||||||
defer s.nameLocksMu.Unlock()
|
|
||||||
wg, ok := s.nameLocks[name]
|
|
||||||
if ok {
|
|
||||||
// lock already obtained, let caller wait on it
|
|
||||||
return wg, nil
|
|
||||||
}
|
|
||||||
// caller gets lock
|
|
||||||
wg = new(sync.WaitGroup)
|
|
||||||
wg.Add(1)
|
|
||||||
s.nameLocks[name] = wg
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unlock unlocks name.
|
|
||||||
func (s *FileStorage) Unlock(name string) error {
|
|
||||||
s.nameLocksMu.Lock()
|
|
||||||
defer s.nameLocksMu.Unlock()
|
|
||||||
wg, ok := s.nameLocks[name]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
delete(s.nameLocks, name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the
|
// MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the
|
||||||
// most recently written sub directory in the users' directory. It is named
|
// most recently written sub directory in the users' directory. It is named
|
||||||
// after the email address. This corresponds to the most recent call to
|
// after the email address. This corresponds to the most recent call to
|
||||||
|
|
|
@ -39,24 +39,9 @@ type UserData struct {
|
||||||
Key []byte
|
Key []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage is an interface abstracting all storage used by Caddy's TLS
|
// Locker provides support for mutual exclusion
|
||||||
// subsystem. Implementations of this interface store both site and
|
type Locker interface {
|
||||||
// user data.
|
// TryLock will return immediatedly with or without acquiring the lock.
|
||||||
type Storage interface {
|
|
||||||
// SiteExists returns true if this site exists in storage.
|
|
||||||
// Site data is considered present when StoreSite has been called
|
|
||||||
// successfully (without DeleteSite having been called, of course).
|
|
||||||
SiteExists(domain string) (bool, error)
|
|
||||||
|
|
||||||
// TryLock is called before Caddy attempts to obtain or renew a
|
|
||||||
// certificate for a certain name and store it. From the perspective
|
|
||||||
// of this method and its companion Unlock, the actions of
|
|
||||||
// obtaining/renewing and then storing the certificate are atomic,
|
|
||||||
// and both should occur within a lock. This prevents multiple
|
|
||||||
// processes -- maybe distributed ones -- from stepping on each
|
|
||||||
// other's space in the same shared storage, and from spamming
|
|
||||||
// certificate providers with multiple, redundant requests.
|
|
||||||
//
|
|
||||||
// If a lock could be obtained, (nil, nil) is returned and you may
|
// If a lock could be obtained, (nil, nil) is returned and you may
|
||||||
// continue normally. If not (meaning another process is already
|
// continue normally. If not (meaning another process is already
|
||||||
// working on that name), a Waiter value will be returned upon
|
// working on that name), a Waiter value will be returned upon
|
||||||
|
@ -75,6 +60,16 @@ type Storage interface {
|
||||||
// the obtain/renew and store are finished, even if there was
|
// the obtain/renew and store are finished, even if there was
|
||||||
// an error (or a timeout).
|
// an error (or a timeout).
|
||||||
Unlock(name string) error
|
Unlock(name string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Storage is an interface abstracting all storage used by Caddy's TLS
|
||||||
|
// subsystem. Implementations of this interface store both site and
|
||||||
|
// user data.
|
||||||
|
type Storage interface {
|
||||||
|
// SiteExists returns true if this site exists in storage.
|
||||||
|
// Site data is considered present when StoreSite has been called
|
||||||
|
// successfully (without DeleteSite having been called, of course).
|
||||||
|
SiteExists(domain string) (bool, error)
|
||||||
|
|
||||||
// LoadSite obtains the site data from storage for the given domain and
|
// LoadSite obtains the site data from storage for the given domain and
|
||||||
// returns it. If data for the domain does not exist, an error value
|
// returns it. If data for the domain does not exist, an error value
|
||||||
|
|
57
caddytls/sync_locker.go
Normal file
57
caddytls/sync_locker.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
// Copyright 2015 Light Code Labs, LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package caddytls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ Locker = &syncLock{}
|
||||||
|
|
||||||
|
type syncLock struct {
|
||||||
|
nameLocks map[string]*sync.WaitGroup
|
||||||
|
nameLocksMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryLock attempts to get a lock for name, otherwise it returns
|
||||||
|
// a Waiter value to wait until the other process is finished.
|
||||||
|
func (s *syncLock) TryLock(name string) (Waiter, error) {
|
||||||
|
s.nameLocksMu.Lock()
|
||||||
|
defer s.nameLocksMu.Unlock()
|
||||||
|
wg, ok := s.nameLocks[name]
|
||||||
|
if ok {
|
||||||
|
// lock already obtained, let caller wait on it
|
||||||
|
return wg, nil
|
||||||
|
}
|
||||||
|
// caller gets lock
|
||||||
|
wg = new(sync.WaitGroup)
|
||||||
|
wg.Add(1)
|
||||||
|
s.nameLocks[name] = wg
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlock unlocks name.
|
||||||
|
func (s *syncLock) Unlock(name string) error {
|
||||||
|
s.nameLocksMu.Lock()
|
||||||
|
defer s.nameLocksMu.Unlock()
|
||||||
|
wg, ok := s.nameLocks[name]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
delete(s.nameLocks, name)
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -16,7 +16,6 @@ package caddytls
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/xenolf/lego/acme"
|
"github.com/xenolf/lego/acme"
|
||||||
|
@ -94,7 +93,7 @@ func TestQualifiesForManagedTLS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSaveCertResource(t *testing.T) {
|
func TestSaveCertResource(t *testing.T) {
|
||||||
storage := &FileStorage{Path: "./le_test_save", nameLocks: make(map[string]*sync.WaitGroup)}
|
storage := &FileStorage{Path: "./le_test_save"}
|
||||||
defer func() {
|
defer func() {
|
||||||
err := os.RemoveAll(storage.Path)
|
err := os.RemoveAll(storage.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -140,7 +139,7 @@ func TestSaveCertResource(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExistingCertAndKey(t *testing.T) {
|
func TestExistingCertAndKey(t *testing.T) {
|
||||||
storage := &FileStorage{Path: "./le_test_existing", nameLocks: make(map[string]*sync.WaitGroup)}
|
storage := &FileStorage{Path: "./le_test_existing"}
|
||||||
defer func() {
|
defer func() {
|
||||||
err := os.RemoveAll(storage.Path)
|
err := os.RemoveAll(storage.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -21,7 +21,6 @@ import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -196,7 +195,7 @@ func TestGetEmail(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var testStorage = &FileStorage{Path: "./testdata", nameLocks: make(map[string]*sync.WaitGroup)}
|
var testStorage = &FileStorage{Path: "./testdata"}
|
||||||
|
|
||||||
func (s *FileStorage) clean() error {
|
func (s *FileStorage) clean() error {
|
||||||
return os.RemoveAll(s.Path)
|
return os.RemoveAll(s.Path)
|
||||||
|
|
Loading…
Reference in a new issue