mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-23 18:55:49 +03:00
59bf71c293
* Add a shutdown function and context to staticUpstream so that running goroutines can be cancelled. Add a GetShutdownFunc to Upstream interface to expose the shutdown function to the caddy Controller for performing it on restarts. * Make fakeUpstream implement new Upstream methods. Implement new Upstream method for fakeWSUpstream as well. * Rename GetShutdownFunc to Stop(). Add a waitgroup to the staticUpstream for controlling individual object's goroutines. Add the Stop function to OnRestart and OnShutdown. Add tests for checking to see if healthchecks continue hitting a backend server after stop has been called. * Go back to using a stop channel since the context adds no additional benefit. Only register stop function for onShutdown since it's called as part of restart. * Remove assignment to atomic value * Incrementing WaitGroup outside of goroutine to avoid race condition. Loading atomic values in test. * Linting: change counter to just use the default zero value instead of setting it * Clarify Stop method comments, add comments to stop channel and waitgroup and remove out of date comment about handling stopping the proxy. Stop the ticker when the stop signal is sent
456 lines
10 KiB
Go
456 lines
10 KiB
Go
package proxy
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/mholt/caddy/caddyfile"
|
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
|
)
|
|
|
|
var (
|
|
supportedPolicies = make(map[string]func() Policy)
|
|
)
|
|
|
|
type staticUpstream struct {
|
|
from string
|
|
upstreamHeaders http.Header
|
|
downstreamHeaders http.Header
|
|
stop chan struct{} // Signals running goroutines to stop.
|
|
wg sync.WaitGroup // Used to wait for running goroutines to stop.
|
|
Hosts HostPool
|
|
Policy Policy
|
|
KeepAlive int
|
|
FailTimeout time.Duration
|
|
TryDuration time.Duration
|
|
TryInterval time.Duration
|
|
MaxConns int64
|
|
HealthCheck struct {
|
|
Client http.Client
|
|
Path string
|
|
Interval time.Duration
|
|
Timeout time.Duration
|
|
}
|
|
WithoutPathPrefix string
|
|
IgnoredSubPaths []string
|
|
insecureSkipVerify bool
|
|
MaxFails int32
|
|
}
|
|
|
|
// NewStaticUpstreams parses the configuration input and sets up
|
|
// static upstreams for the proxy middleware.
|
|
func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) {
|
|
var upstreams []Upstream
|
|
for c.Next() {
|
|
|
|
upstream := &staticUpstream{
|
|
from: "",
|
|
stop: make(chan struct{}),
|
|
upstreamHeaders: make(http.Header),
|
|
downstreamHeaders: make(http.Header),
|
|
Hosts: nil,
|
|
Policy: &Random{},
|
|
MaxFails: 1,
|
|
TryInterval: 250 * time.Millisecond,
|
|
MaxConns: 0,
|
|
KeepAlive: http.DefaultMaxIdleConnsPerHost,
|
|
}
|
|
|
|
if !c.Args(&upstream.from) {
|
|
return upstreams, c.ArgErr()
|
|
}
|
|
|
|
var to []string
|
|
for _, t := range c.RemainingArgs() {
|
|
parsed, err := parseUpstream(t)
|
|
if err != nil {
|
|
return upstreams, err
|
|
}
|
|
to = append(to, parsed...)
|
|
}
|
|
|
|
for c.NextBlock() {
|
|
switch c.Val() {
|
|
case "upstream":
|
|
if !c.NextArg() {
|
|
return upstreams, c.ArgErr()
|
|
}
|
|
parsed, err := parseUpstream(c.Val())
|
|
if err != nil {
|
|
return upstreams, err
|
|
}
|
|
to = append(to, parsed...)
|
|
default:
|
|
if err := parseBlock(&c, upstream); err != nil {
|
|
return upstreams, err
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(to) == 0 {
|
|
return upstreams, c.ArgErr()
|
|
}
|
|
|
|
upstream.Hosts = make([]*UpstreamHost, len(to))
|
|
for i, host := range to {
|
|
uh, err := upstream.NewHost(host)
|
|
if err != nil {
|
|
return upstreams, err
|
|
}
|
|
upstream.Hosts[i] = uh
|
|
}
|
|
|
|
if upstream.HealthCheck.Path != "" {
|
|
upstream.HealthCheck.Client = http.Client{
|
|
Timeout: upstream.HealthCheck.Timeout,
|
|
}
|
|
upstream.wg.Add(1)
|
|
go func() {
|
|
defer upstream.wg.Done()
|
|
upstream.HealthCheckWorker(upstream.stop)
|
|
}()
|
|
}
|
|
upstreams = append(upstreams, upstream)
|
|
}
|
|
return upstreams, nil
|
|
}
|
|
|
|
func (u *staticUpstream) From() string {
|
|
return u.from
|
|
}
|
|
|
|
func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
|
if !strings.HasPrefix(host, "http") &&
|
|
!strings.HasPrefix(host, "unix:") {
|
|
host = "http://" + host
|
|
}
|
|
uh := &UpstreamHost{
|
|
Name: host,
|
|
Conns: 0,
|
|
Fails: 0,
|
|
FailTimeout: u.FailTimeout,
|
|
Unhealthy: 0,
|
|
UpstreamHeaders: u.upstreamHeaders,
|
|
DownstreamHeaders: u.downstreamHeaders,
|
|
CheckDown: func(u *staticUpstream) UpstreamHostDownFunc {
|
|
return func(uh *UpstreamHost) bool {
|
|
if atomic.LoadInt32(&uh.Unhealthy) != 0 {
|
|
return true
|
|
}
|
|
if atomic.LoadInt32(&uh.Fails) >= u.MaxFails {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
}(u),
|
|
WithoutPathPrefix: u.WithoutPathPrefix,
|
|
MaxConns: u.MaxConns,
|
|
}
|
|
|
|
baseURL, err := url.Parse(uh.Name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
|
|
if u.insecureSkipVerify {
|
|
uh.ReverseProxy.UseInsecureTransport()
|
|
}
|
|
|
|
return uh, nil
|
|
}
|
|
|
|
// parseUpstream expands a single upstream address into the list of
// addresses it stands for.
//
// An address whose port part has the form "first-last" (for example
// "localhost:8080-8082/api") is expanded into one address per port in the
// inclusive range, keeping everything before the port and everything from
// the next "/" onwards. Any other address, including every "unix:"
// address, is returned unchanged as the only element.
//
// An error is returned when either end of a range is not a number, when
// the range is empty or reversed (last <= first), or when the range ends
// above 65535, the largest valid port.
func parseUpstream(u string) ([]string, error) {
	const maxPort = 65535

	single := []string{u}
	if strings.HasPrefix(u, "unix:") {
		return single, nil
	}

	// The port, if any, follows the last colon. When that colon is the
	// one in "://" the address has no port at all.
	colonIdx := strings.LastIndex(u, ":")
	protoIdx := strings.Index(u, "://")
	if colonIdx == -1 || colonIdx == protoIdx {
		return single, nil
	}

	prefix := u[:colonIdx]
	suffix := ""
	portsEnd := len(u)
	if nextSlash := strings.Index(u[colonIdx:], "/"); nextSlash != -1 {
		portsEnd = colonIdx + nextSlash
		suffix = u[portsEnd:]
	}
	ports := u[colonIdx+1 : portsEnd]

	// Exactly one dash marks a range; anything else is left untouched.
	if strings.Count(ports, "-") != 1 {
		return single, nil
	}

	bounds := strings.Split(ports, "-")
	pIni, err := strconv.Atoi(bounds[0])
	if err != nil {
		return nil, err
	}
	pEnd, err := strconv.Atoi(bounds[1])
	if err != nil {
		return nil, err
	}

	if pEnd <= pIni {
		return nil, fmt.Errorf("port range [%s] is invalid", ports)
	}
	// Ports above 65535 cannot be dialed; rejecting them also keeps the
	// expansion below bounded.
	if pEnd > maxPort {
		return nil, fmt.Errorf("port range [%s] exceeds the maximum port %d", ports, maxPort)
	}

	hosts := make([]string, 0, pEnd-pIni+1)
	for p := pIni; p <= pEnd; p++ {
		hosts = append(hosts, fmt.Sprintf("%s:%d%s", prefix, p, suffix))
	}
	return hosts, nil
}
|
|
|
|
func parseBlock(c *caddyfile.Dispenser, u *staticUpstream) error {
|
|
switch c.Val() {
|
|
case "policy":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
policyCreateFunc, ok := supportedPolicies[c.Val()]
|
|
if !ok {
|
|
return c.ArgErr()
|
|
}
|
|
u.Policy = policyCreateFunc()
|
|
case "fail_timeout":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
dur, err := time.ParseDuration(c.Val())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.FailTimeout = dur
|
|
case "max_fails":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
n, err := strconv.Atoi(c.Val())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n < 1 {
|
|
return c.Err("max_fails must be at least 1")
|
|
}
|
|
u.MaxFails = int32(n)
|
|
case "try_duration":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
dur, err := time.ParseDuration(c.Val())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.TryDuration = dur
|
|
case "try_interval":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
interval, err := time.ParseDuration(c.Val())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.TryInterval = interval
|
|
case "max_conns":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
n, err := strconv.ParseInt(c.Val(), 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.MaxConns = n
|
|
case "health_check":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
u.HealthCheck.Path = c.Val()
|
|
|
|
// Set defaults
|
|
if u.HealthCheck.Interval == 0 {
|
|
u.HealthCheck.Interval = 30 * time.Second
|
|
}
|
|
if u.HealthCheck.Timeout == 0 {
|
|
u.HealthCheck.Timeout = 60 * time.Second
|
|
}
|
|
case "health_check_interval":
|
|
var interval string
|
|
if !c.Args(&interval) {
|
|
return c.ArgErr()
|
|
}
|
|
dur, err := time.ParseDuration(interval)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.HealthCheck.Interval = dur
|
|
case "health_check_timeout":
|
|
var interval string
|
|
if !c.Args(&interval) {
|
|
return c.ArgErr()
|
|
}
|
|
dur, err := time.ParseDuration(interval)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u.HealthCheck.Timeout = dur
|
|
case "header_upstream":
|
|
var header, value string
|
|
if !c.Args(&header, &value) {
|
|
// When removing a header, the value can be optional.
|
|
if !strings.HasPrefix(header, "-") {
|
|
return c.ArgErr()
|
|
}
|
|
}
|
|
u.upstreamHeaders.Add(header, value)
|
|
case "header_downstream":
|
|
var header, value string
|
|
if !c.Args(&header, &value) {
|
|
// When removing a header, the value can be optional.
|
|
if !strings.HasPrefix(header, "-") {
|
|
return c.ArgErr()
|
|
}
|
|
}
|
|
u.downstreamHeaders.Add(header, value)
|
|
case "transparent":
|
|
u.upstreamHeaders.Add("Host", "{host}")
|
|
u.upstreamHeaders.Add("X-Real-IP", "{remote}")
|
|
u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
|
|
u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
|
|
case "websocket":
|
|
u.upstreamHeaders.Add("Connection", "{>Connection}")
|
|
u.upstreamHeaders.Add("Upgrade", "{>Upgrade}")
|
|
case "without":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
u.WithoutPathPrefix = c.Val()
|
|
case "except":
|
|
ignoredPaths := c.RemainingArgs()
|
|
if len(ignoredPaths) == 0 {
|
|
return c.ArgErr()
|
|
}
|
|
u.IgnoredSubPaths = ignoredPaths
|
|
case "insecure_skip_verify":
|
|
u.insecureSkipVerify = true
|
|
case "keepalive":
|
|
if !c.NextArg() {
|
|
return c.ArgErr()
|
|
}
|
|
n, err := strconv.Atoi(c.Val())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if n < 0 {
|
|
return c.ArgErr()
|
|
}
|
|
u.KeepAlive = n
|
|
default:
|
|
return c.Errf("unknown property '%s'", c.Val())
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *staticUpstream) healthCheck() {
|
|
for _, host := range u.Hosts {
|
|
hostURL := host.Name + u.HealthCheck.Path
|
|
var unhealthy bool
|
|
if r, err := u.HealthCheck.Client.Get(hostURL); err == nil {
|
|
io.Copy(ioutil.Discard, r.Body)
|
|
r.Body.Close()
|
|
unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
|
|
} else {
|
|
unhealthy = true
|
|
}
|
|
if unhealthy {
|
|
atomic.StoreInt32(&host.Unhealthy, 1)
|
|
} else {
|
|
atomic.StoreInt32(&host.Unhealthy, 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
|
|
ticker := time.NewTicker(u.HealthCheck.Interval)
|
|
u.healthCheck()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
u.healthCheck()
|
|
case <-stop:
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *staticUpstream) Select(r *http.Request) *UpstreamHost {
|
|
pool := u.Hosts
|
|
if len(pool) == 1 {
|
|
if !pool[0].Available() {
|
|
return nil
|
|
}
|
|
return pool[0]
|
|
}
|
|
allUnavailable := true
|
|
for _, host := range pool {
|
|
if host.Available() {
|
|
allUnavailable = false
|
|
break
|
|
}
|
|
}
|
|
if allUnavailable {
|
|
return nil
|
|
}
|
|
if u.Policy == nil {
|
|
return (&Random{}).Select(pool, r)
|
|
}
|
|
return u.Policy.Select(pool, r)
|
|
}
|
|
|
|
func (u *staticUpstream) AllowedPath(requestPath string) bool {
|
|
for _, ignoredSubPath := range u.IgnoredSubPaths {
|
|
if httpserver.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// GetTryDuration returns u.TryDuration.
|
|
func (u *staticUpstream) GetTryDuration() time.Duration {
|
|
return u.TryDuration
|
|
}
|
|
|
|
// GetTryInterval returns u.TryInterval.
|
|
func (u *staticUpstream) GetTryInterval() time.Duration {
|
|
return u.TryInterval
|
|
}
|
|
|
|
func (u *staticUpstream) GetHostCount() int {
|
|
return len(u.Hosts)
|
|
}
|
|
|
|
// Stop sends a signal to all goroutines started by this staticUpstream to exit
|
|
// and waits for them to finish before returning.
|
|
func (u *staticUpstream) Stop() error {
|
|
close(u.stop)
|
|
u.wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
// RegisterPolicy adds a custom policy to the proxy.
|
|
func RegisterPolicy(name string, policy func() Policy) {
|
|
supportedPolicies[name] = policy
|
|
}
|