mirror of
https://github.com/caddyserver/caddy.git
synced 2025-03-30 17:09:05 +03:00
reverseproxy: Fix host and port on requests; fix Caddyfile parser
This commit is contained in:
parent
b4dce74e59
commit
758269124e
2 changed files with 128 additions and 94 deletions
|
@ -376,108 +376,110 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||||
// }
|
// }
|
||||||
//
|
//
|
||||||
func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
func (h *HTTPTransport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||||
for d.NextBlock(0) {
|
for d.Next() {
|
||||||
switch d.Val() {
|
for d.NextBlock(0) {
|
||||||
case "read_buffer":
|
switch d.Val() {
|
||||||
if !d.NextArg() {
|
case "read_buffer":
|
||||||
return d.ArgErr()
|
if !d.NextArg() {
|
||||||
}
|
return d.ArgErr()
|
||||||
size, err := humanize.ParseBytes(d.Val())
|
}
|
||||||
if err != nil {
|
size, err := humanize.ParseBytes(d.Val())
|
||||||
return d.Errf("invalid read buffer size '%s': %v", d.Val(), err)
|
if err != nil {
|
||||||
}
|
return d.Errf("invalid read buffer size '%s': %v", d.Val(), err)
|
||||||
h.ReadBufferSize = int(size)
|
}
|
||||||
|
h.ReadBufferSize = int(size)
|
||||||
|
|
||||||
case "write_buffer":
|
case "write_buffer":
|
||||||
if !d.NextArg() {
|
if !d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
size, err := humanize.ParseBytes(d.Val())
|
size, err := humanize.ParseBytes(d.Val())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return d.Errf("invalid write buffer size '%s': %v", d.Val(), err)
|
return d.Errf("invalid write buffer size '%s': %v", d.Val(), err)
|
||||||
}
|
}
|
||||||
h.WriteBufferSize = int(size)
|
h.WriteBufferSize = int(size)
|
||||||
|
|
||||||
case "dial_timeout":
|
case "dial_timeout":
|
||||||
if !d.NextArg() {
|
if !d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
dur, err := time.ParseDuration(d.Val())
|
dur, err := time.ParseDuration(d.Val())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return d.Errf("bad timeout value '%s': %v", d.Val(), err)
|
return d.Errf("bad timeout value '%s': %v", d.Val(), err)
|
||||||
}
|
}
|
||||||
h.DialTimeout = caddy.Duration(dur)
|
h.DialTimeout = caddy.Duration(dur)
|
||||||
|
|
||||||
case "tls_client_auth":
|
case "tls_client_auth":
|
||||||
args := d.RemainingArgs()
|
args := d.RemainingArgs()
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
if h.TLS == nil {
|
if h.TLS == nil {
|
||||||
h.TLS = new(TLSConfig)
|
h.TLS = new(TLSConfig)
|
||||||
}
|
}
|
||||||
h.TLS.ClientCertificateFile = args[0]
|
h.TLS.ClientCertificateFile = args[0]
|
||||||
h.TLS.ClientCertificateKeyFile = args[1]
|
h.TLS.ClientCertificateKeyFile = args[1]
|
||||||
|
|
||||||
case "tls":
|
case "tls":
|
||||||
if h.TLS == nil {
|
if h.TLS == nil {
|
||||||
h.TLS = new(TLSConfig)
|
h.TLS = new(TLSConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "tls_insecure_skip_verify":
|
case "tls_insecure_skip_verify":
|
||||||
if d.NextArg() {
|
if d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
if h.TLS == nil {
|
if h.TLS == nil {
|
||||||
h.TLS = new(TLSConfig)
|
h.TLS = new(TLSConfig)
|
||||||
}
|
}
|
||||||
h.TLS.InsecureSkipVerify = true
|
h.TLS.InsecureSkipVerify = true
|
||||||
|
|
||||||
case "tls_timeout":
|
case "tls_timeout":
|
||||||
if !d.NextArg() {
|
if !d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
dur, err := time.ParseDuration(d.Val())
|
dur, err := time.ParseDuration(d.Val())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return d.Errf("bad timeout value '%s': %v", d.Val(), err)
|
return d.Errf("bad timeout value '%s': %v", d.Val(), err)
|
||||||
}
|
}
|
||||||
if h.TLS == nil {
|
if h.TLS == nil {
|
||||||
h.TLS = new(TLSConfig)
|
h.TLS = new(TLSConfig)
|
||||||
}
|
}
|
||||||
h.TLS.HandshakeTimeout = caddy.Duration(dur)
|
h.TLS.HandshakeTimeout = caddy.Duration(dur)
|
||||||
|
|
||||||
case "keepalive":
|
case "keepalive":
|
||||||
if !d.NextArg() {
|
if !d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
if h.KeepAlive == nil {
|
if h.KeepAlive == nil {
|
||||||
h.KeepAlive = new(KeepAlive)
|
h.KeepAlive = new(KeepAlive)
|
||||||
}
|
}
|
||||||
if d.Val() == "off" {
|
if d.Val() == "off" {
|
||||||
var disable bool
|
var disable bool
|
||||||
h.KeepAlive.Enabled = &disable
|
h.KeepAlive.Enabled = &disable
|
||||||
}
|
}
|
||||||
dur, err := time.ParseDuration(d.Val())
|
dur, err := time.ParseDuration(d.Val())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return d.Errf("bad duration value '%s': %v", d.Val(), err)
|
return d.Errf("bad duration value '%s': %v", d.Val(), err)
|
||||||
}
|
}
|
||||||
h.KeepAlive.IdleConnTimeout = caddy.Duration(dur)
|
h.KeepAlive.IdleConnTimeout = caddy.Duration(dur)
|
||||||
|
|
||||||
case "keepalive_idle_conns":
|
case "keepalive_idle_conns":
|
||||||
if !d.NextArg() {
|
if !d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
}
|
}
|
||||||
num, err := strconv.Atoi(d.Val())
|
num, err := strconv.Atoi(d.Val())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return d.Errf("bad integer value '%s': %v", d.Val(), err)
|
return d.Errf("bad integer value '%s': %v", d.Val(), err)
|
||||||
}
|
}
|
||||||
if h.KeepAlive == nil {
|
if h.KeepAlive == nil {
|
||||||
h.KeepAlive = new(KeepAlive)
|
h.KeepAlive = new(KeepAlive)
|
||||||
}
|
}
|
||||||
h.KeepAlive.MaxIdleConns = num
|
h.KeepAlive.MaxIdleConns = num
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return d.Errf("unrecognized subdirective %s", d.Val())
|
return d.Errf("unrecognized subdirective %s", d.Val())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -145,6 +145,26 @@ func (h *Handler) Provision(ctx caddy.Context) error {
|
||||||
|
|
||||||
var allUpstreams []*Upstream
|
var allUpstreams []*Upstream
|
||||||
for _, upstream := range h.Upstreams {
|
for _, upstream := range h.Upstreams {
|
||||||
|
// if a port was not specified (and the network type uses
|
||||||
|
// ports), then maybe we can figure out the default port
|
||||||
|
netw, host, port, err := caddy.SplitNetworkAddress(upstream.Dial)
|
||||||
|
if err != nil && port == "" && !strings.Contains(netw, "unix") {
|
||||||
|
if host == "" {
|
||||||
|
// assume all that was given was the host, no port
|
||||||
|
host = upstream.Dial
|
||||||
|
}
|
||||||
|
// a port was not specified, but we may be able to
|
||||||
|
// infer it if we know the standard ports on which
|
||||||
|
// the transport protocol operates
|
||||||
|
if ht, ok := h.Transport.(*HTTPTransport); ok {
|
||||||
|
defaultPort := "80"
|
||||||
|
if ht.TLS != nil {
|
||||||
|
defaultPort = "443"
|
||||||
|
}
|
||||||
|
upstream.Dial = caddy.JoinNetworkAddress(netw, host, defaultPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// upstreams are allowed to map to only a single host,
|
// upstreams are allowed to map to only a single host,
|
||||||
// but an upstream's address may semantically represent
|
// but an upstream's address may semantically represent
|
||||||
// multiple addresses, so make sure to handle each
|
// multiple addresses, so make sure to handle each
|
||||||
|
@ -474,7 +494,19 @@ func (h Handler) tryAgain(start time.Time, proxyErr error) bool {
|
||||||
// given upstream host. It must modify ONLY the request URL.
|
// given upstream host. It must modify ONLY the request URL.
|
||||||
func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
|
func (h Handler) directRequest(req *http.Request, upstream *Upstream) {
|
||||||
if req.URL.Host == "" {
|
if req.URL.Host == "" {
|
||||||
req.URL.Host = upstream.dialInfo.Address
|
// we need a host, so set the upstream's host address
|
||||||
|
fullHost := upstream.dialInfo.Address
|
||||||
|
|
||||||
|
// but if the port matches the scheme, strip the port because
|
||||||
|
// it's weird to make a request like http://example.com:80/.
|
||||||
|
host, port, err := net.SplitHostPort(fullHost)
|
||||||
|
if err == nil &&
|
||||||
|
(req.URL.Scheme == "http" && port == "80") ||
|
||||||
|
(req.URL.Scheme == "https" && port == "443") {
|
||||||
|
fullHost = host
|
||||||
|
}
|
||||||
|
|
||||||
|
req.URL.Host = fullHost
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue