Add send_timeout property to fastcgi directive.

* Convert rwc field on FCGIClient from type io.ReadWriteCloser to net.Conn.
* Return HTTP 504 to the client when a timeout occurs.
* In Handler.ServeHTTP(), close the connection before returning an HTTP
502/504.
* Refactor tests and add coverage.
This commit is contained in:
ericdreeves 2016-11-29 20:26:51 -06:00
parent 17e7e6076a
commit 5874fbeb7e
7 changed files with 219 additions and 81 deletions

View file

@ -19,8 +19,11 @@ type basicDialer struct {
timeout time.Duration
}
func (b basicDialer) Dial() (Client, error) { return Dial(b.network, b.address, b.timeout) }
func (b basicDialer) Close(c Client) error { return c.Close() }
func (b basicDialer) Dial() (Client, error) {
return DialTimeout(b.network, b.address, b.timeout)
}
func (b basicDialer) Close(c Client) error { return c.Close() }
// persistentDialer keeps a pool of fcgi connections.
// connections are not closed after use, rather added back to the pool for reuse.
@ -47,7 +50,7 @@ func (p *persistentDialer) Dial() (Client, error) {
p.Unlock()
// no connection available, create new one
return Dial(p.network, p.address, p.timeout)
return DialTimeout(p.network, p.address, p.timeout)
}
func (p *persistentDialer) Close(client Client) error {

View file

@ -6,6 +6,7 @@ package fastcgi
import (
"errors"
"io"
"net"
"net/http"
"os"
"path"
@ -82,7 +83,9 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
if err != nil {
return http.StatusBadGateway, err
}
defer fcgiBackend.Close()
fcgiBackend.SetReadTimeout(rule.ReadTimeout)
fcgiBackend.SetSendTimeout(rule.SendTimeout)
var resp *http.Response
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
@ -97,8 +100,12 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
resp, err = fcgiBackend.Post(env, r.Method, r.Header.Get("Content-Type"), r.Body, contentLength)
}
if err != nil && err != io.EOF {
return http.StatusBadGateway, err
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
return http.StatusGatewayTimeout, err
} else if err != io.EOF {
return http.StatusBadGateway, err
}
}
// Write response header
@ -110,8 +117,6 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
return http.StatusBadGateway, err
}
defer rule.dialer.Close(fcgiBackend)
// Log any stderr output from upstream
if stderr := fcgiBackend.StdErr(); stderr.Len() != 0 {
// Remove trailing newline, error logger already does this.
@ -306,6 +311,9 @@ type Rule struct {
// The duration used to set a deadline when reading from the FastCGI server.
ReadTimeout time.Duration
// The duration used to set a deadline when sending to the FastCGI server.
SendTimeout time.Duration
// FCGI dialer
dialer dialer
}

View file

@ -327,38 +327,124 @@ func TestBuildEnv(t *testing.T) {
}
func TestReadTimeout(t *testing.T) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Unable to create listener for test: %v", err)
tests := []struct {
sleep time.Duration
readTimeout time.Duration
shouldErr bool
}{
{75 * time.Millisecond, 50 * time.Millisecond, true},
{0, -1 * time.Second, true},
{0, time.Minute, false},
}
defer listener.Close()
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
ReadTimeout: time.Millisecond * 100,
var wg sync.WaitGroup
for i, test := range tests {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Test %d: Unable to create listener for test: %v", i, err)
}
defer listener.Close()
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
ReadTimeout: test.readTimeout,
},
},
},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Unable to create request: %v", err)
}
w := httptest.NewRecorder()
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Test %d: Unable to create request: %v", i, err)
}
w := httptest.NewRecorder()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 130)
}))
wg.Add(1)
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(test.sleep)
w.WriteHeader(http.StatusOK)
wg.Done()
}))
_, err = handler.ServeHTTP(w, r)
if err == nil {
t.Error("Expected i/o timeout error but had none")
} else if err, ok := err.(net.Error); !ok || !err.Timeout() {
t.Errorf("Expected i/o timeout error, got: '%s'", err.Error())
got, err := handler.ServeHTTP(w, r)
if test.shouldErr {
if err == nil {
t.Errorf("Test %d: Expected i/o timeout error but had none", i)
} else if err, ok := err.(net.Error); !ok || !err.Timeout() {
t.Errorf("Test %d: Expected i/o timeout error, got: '%s'", i, err.Error())
}
want := http.StatusGatewayTimeout
if got != want {
t.Errorf("Test %d: Expected returned status code to be %d, got: %d",
i, want, got)
}
} else if err != nil {
t.Errorf("Test %d: Expected nil error, got: %v", i, err)
}
wg.Wait()
}
}
func TestSendTimeout(t *testing.T) {
tests := []struct {
sendTimeout time.Duration
shouldErr bool
}{
{-1 * time.Second, true},
{time.Minute, false},
}
for i, test := range tests {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Test %d: Unable to create listener for test: %v", i, err)
}
defer listener.Close()
network, address := parseAddress(listener.Addr().String())
handler := Handler{
Next: nil,
Rules: []Rule{
{
Path: "/",
Address: listener.Addr().String(),
dialer: basicDialer{network: network, address: address},
SendTimeout: test.sendTimeout,
},
},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("Test %d: Unable to create request: %v", i, err)
}
w := httptest.NewRecorder()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
got, err := handler.ServeHTTP(w, r)
if test.shouldErr {
if err == nil {
t.Errorf("Test %d: Expected i/o timeout error but had none", i)
} else if err, ok := err.(net.Error); !ok || !err.Timeout() {
t.Errorf("Test %d: Expected i/o timeout error, got: '%s'", i, err.Error())
}
want := http.StatusGatewayTimeout
if got != want {
t.Errorf("Test %d: Expected returned status code to be %d, got: %d",
i, want, got)
}
} else if err != nil {
t.Errorf("Test %d: Expected nil error, got: %v", i, err)
}
}
}

View file

@ -15,7 +15,6 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
@ -116,8 +115,8 @@ type Client interface {
Post(pairs map[string]string, method string, bodyType string, body io.Reader, contentLength int) (response *http.Response, err error)
Close() error
StdErr() bytes.Buffer
ReadTimeout() time.Duration
SetReadTimeout(time.Duration) error
SetSendTimeout(time.Duration) error
}
type header struct {
@ -174,57 +173,32 @@ func (rec *record) read(r io.Reader) (buf []byte, err error) {
// interfacing external applications with Web servers.
type FCGIClient struct {
mutex sync.Mutex
rwc io.ReadWriteCloser
conn net.Conn
h header
buf bytes.Buffer
stderr bytes.Buffer
keepAlive bool
reqID uint16
readTimeout time.Duration
sendTimeout time.Duration
}
// DialWithDialer connects to the fcgi responder at the specified network address, using custom net.Dialer.
// DialTimeout connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters.
func DialWithDialer(network, address string, dialer net.Dialer) (fcgi *FCGIClient, err error) {
var conn net.Conn
conn, err = dialer.Dial(network, address)
func DialTimeout(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) {
conn, err := net.DialTimeout(network, address, timeout)
if err != nil {
return
}
fcgi = &FCGIClient{
rwc: conn,
keepAlive: false,
reqID: 1,
}
fcgi = &FCGIClient{conn: conn, keepAlive: false, reqID: 1}
return
}
// Dial connects to the fcgi responder at the specified network address, using default net.Dialer.
// See func net.Dial for a description of the network and address parameters.
func Dial(network string, address string, timeout time.Duration) (fcgi *FCGIClient, err error) {
return DialWithDialer(network, address, net.Dialer{Timeout: timeout})
return fcgi, nil
}
// Close closes fcgi connnection.
func (c *FCGIClient) Close() error {
return c.rwc.Close()
}
// setReadDeadline sets a read deadline on FCGIClient based on the configured
// readTimeout. A zero value for readTimeout means no deadline will be set.
func (c *FCGIClient) setReadDeadline() error {
if c.readTimeout > 0 {
conn, ok := c.rwc.(net.Conn)
if ok {
conn.SetReadDeadline(time.Now().Add(c.readTimeout))
} else {
return fmt.Errorf("Could not set Client ReadTimeout")
}
}
return nil
return c.conn.Close()
}
func (c *FCGIClient) writeRecord(recType uint8, content []byte) error {
@ -245,7 +219,13 @@ func (c *FCGIClient) writeRecord(recType uint8, content []byte) error {
return err
}
if _, err := c.rwc.Write(c.buf.Bytes()); err != nil {
if c.sendTimeout != 0 {
if err := c.conn.SetWriteDeadline(time.Now().Add(c.sendTimeout)); err != nil {
return err
}
}
if _, err := c.conn.Write(c.buf.Bytes()); err != nil {
return err
}
@ -369,7 +349,7 @@ func (w *streamReader) Read(p []byte) (n int, err error) {
for {
rec := &record{}
var buf []byte
buf, err = rec.read(w.c.rwc)
buf, err = rec.read(w.c.conn)
if err == errInvalidHeaderVersion {
continue
} else if err != nil {
@ -436,7 +416,6 @@ func (c clientCloser) Close() error { return c.f.Close() }
// Request returns a HTTP Response with Header and Body
// from fcgi responder
func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) {
r, err := c.Do(p, req)
if err != nil {
return
@ -446,8 +425,10 @@ func (c *FCGIClient) Request(p map[string]string, req io.Reader) (resp *http.Res
tp := textproto.NewReader(rb)
resp = new(http.Response)
if err = c.setReadDeadline(); err != nil {
return
if c.readTimeout != 0 {
if err = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
return
}
}
// Parse the response headers.
@ -582,10 +563,6 @@ func (c *FCGIClient) PostFile(p map[string]string, data url.Values, file map[str
return c.Post(p, "POST", bodyType, buf, buf.Len())
}
// ReadTimeout returns the read timeout for future calls that read from the
// fcgi responder.
func (c *FCGIClient) ReadTimeout() time.Duration { return c.readTimeout }
// SetReadTimeout sets the read timeout for future calls that read from the
// fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
@ -593,6 +570,13 @@ func (c *FCGIClient) SetReadTimeout(t time.Duration) error {
return nil
}
// SetSendTimeout sets the read timeout for future calls that send data to
// the fcgi responder. A zero value for t means no timeout will be set.
func (c *FCGIClient) SetSendTimeout(t time.Duration) error {
c.sendTimeout = t
return nil
}
// Checks whether chunked is part of the encodings stack
func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }

View file

@ -103,7 +103,7 @@ func (s FastCGIServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
}
func sendFcgi(reqType int, fcgiParams map[string]string, data []byte, posts map[string]string, files map[string]string) (content []byte) {
fcgi, err := Dial("tcp", ipPort, 0)
fcgi, err := DialTimeout("tcp", ipPort, 0)
if err != nil {
log.Println("err:", err)
return

View file

@ -59,7 +59,7 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, c.ArgErr()
}
rule := Rule{Path: args[0], ReadTimeout: 60 * time.Second}
rule := Rule{Path: args[0], ReadTimeout: 60 * time.Second, SendTimeout: 60 * time.Second}
upstreams := []string{args[1]}
if len(args) == 3 {
@ -144,6 +144,15 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
return rules, err
}
rule.ReadTimeout = readTimeout
case "send_timeout":
if !c.NextArg() {
return rules, c.ArgErr()
}
sendTimeout, err := time.ParseDuration(c.Val())
if err != nil {
return rules, err
}
rule.SendTimeout = sendTimeout
}
}

View file

@ -80,6 +80,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}},
IndexFiles: []string{"index.php"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi /blog 127.0.0.1:9000 php {
upstream 127.0.0.1:9001
@ -92,6 +93,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}, basicDialer{network: "tcp", address: "127.0.0.1:9001", timeout: 60 * time.Second}}},
IndexFiles: []string{"index.php"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi /blog 127.0.0.1:9000 {
upstream 127.0.0.1:9001
@ -104,6 +106,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}, basicDialer{network: "tcp", address: "127.0.0.1:9001", timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
split .html
@ -116,6 +119,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
split .html
@ -130,6 +134,7 @@ func TestFastcgiParse(t *testing.T) {
IndexFiles: []string{},
IgnoredSubPaths: []string{"/admin", "/user"},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
pool 0
@ -142,6 +147,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 0, network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / 127.0.0.1:8080 {
upstream 127.0.0.1:9000
@ -155,6 +161,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{&persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:8080", timeout: 60 * time.Second}, &persistentDialer{size: 5, network: "tcp", address: "127.0.0.1:9000", timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
split .php
@ -167,6 +174,7 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{`fastcgi / ` + defaultAddress + ` {
connect_timeout 5s
@ -179,7 +187,13 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 5 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{
`fastcgi / ` + defaultAddress + ` { connect_timeout BADVALUE }`,
true,
[]Rule{},
},
{`fastcgi / ` + defaultAddress + ` {
read_timeout 5s
}`,
@ -191,7 +205,31 @@ func TestFastcgiParse(t *testing.T) {
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 5 * time.Second,
SendTimeout: 60 * time.Second,
}}},
{
`fastcgi / ` + defaultAddress + ` { read_timeout BADVALUE }`,
true,
[]Rule{},
},
{`fastcgi / ` + defaultAddress + ` {
send_timeout 5s
}`,
false, []Rule{{
Path: "/",
Address: defaultAddress,
Ext: "",
SplitPath: "",
dialer: &loadBalancingDialer{dialers: []dialer{basicDialer{network: network, address: address, timeout: 60 * time.Second}}},
IndexFiles: []string{},
ReadTimeout: 60 * time.Second,
SendTimeout: 5 * time.Second,
}}},
{
`fastcgi / ` + defaultAddress + ` { send_timeout BADVALUE }`,
true,
[]Rule{},
},
{`fastcgi / {
}`,
@ -251,6 +289,16 @@ func TestFastcgiParse(t *testing.T) {
t.Errorf("Test %d expected %dth FastCGI IgnoredSubPaths to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].IgnoredSubPaths, actualFastcgiConfig.IgnoredSubPaths)
}
if fmt.Sprint(actualFastcgiConfig.ReadTimeout) != fmt.Sprint(test.expectedFastcgiConfig[j].ReadTimeout) {
t.Errorf("Test %d expected %dth FastCGI ReadTimeout to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].ReadTimeout, actualFastcgiConfig.ReadTimeout)
}
if fmt.Sprint(actualFastcgiConfig.SendTimeout) != fmt.Sprint(test.expectedFastcgiConfig[j].SendTimeout) {
t.Errorf("Test %d expected %dth FastCGI SendTimeout to be %s , but got %s",
i, j, test.expectedFastcgiConfig[j].SendTimeout, actualFastcgiConfig.SendTimeout)
}
}
}
}