fastcgi: Only perform extra copy if necessary; added tests

This commit is contained in:
Matthew Holt 2016-02-24 16:41:45 -07:00
parent 367397dbd6
commit 737c7c4372
2 changed files with 79 additions and 33 deletions

View file

@ -72,7 +72,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// Connect to FastCGI gateway // Connect to FastCGI gateway
network, address := rule.parseAddress() network, address := rule.parseAddress()
fcgi, err := Dial(network, address) fcgiBackend, err := Dial(network, address)
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
@ -81,19 +81,19 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length")) contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
switch r.Method { switch r.Method {
case "HEAD": case "HEAD":
resp, err = fcgi.Head(env) resp, err = fcgiBackend.Head(env)
case "GET": case "GET":
resp, err = fcgi.Get(env) resp, err = fcgiBackend.Get(env)
case "OPTIONS": case "OPTIONS":
resp, err = fcgi.Options(env) resp, err = fcgiBackend.Options(env)
case "POST": case "POST":
resp, err = fcgi.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength)
case "PUT": case "PUT":
resp, err = fcgi.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength)
case "PATCH": case "PATCH":
resp, err = fcgi.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength)
case "DELETE": case "DELETE":
resp, err = fcgi.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength) resp, err = fcgiBackend.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength)
default: default:
return http.StatusMethodNotAllowed, nil return http.StatusMethodNotAllowed, nil
} }
@ -106,29 +106,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
// Write the response body to a buffer var responseBody io.Reader = resp.Body
// To explicitly set Content-Length
// For FastCGI app that don't set it
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
if r.Header.Get("Content-Length") == "" { if r.Header.Get("Content-Length") == "" {
// If the upstream app didn't set a Content-Length (shame on them),
// we need to do it to prevent error messages being appended to
// an already-written response, and other problematic behavior.
// So we copy it to a buffer and read its size before flushing
// the response out to the client. See issues #567 and #614.
buf := new(bytes.Buffer)
_, err := io.Copy(buf, resp.Body)
if err != nil {
return http.StatusBadGateway, err
}
w.Header().Set("Content-Length", strconv.Itoa(buf.Len())) w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
responseBody = buf
} }
// Write the status code and header fields
writeHeader(w, resp) writeHeader(w, resp)
// Write the response body // Write the response body
// TODO: If this has an error, the response will already be _, err = io.Copy(w, responseBody)
// partly written. We should copy out of resp.Body into a buffer
// first, then write it to the response...
_, err = io.Copy(w, &buf)
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
// FastCGI stderr outputs // FastCGI stderr outputs
if fcgi.stderr.Len() != 0 { if fcgiBackend.stderr.Len() != 0 {
// Remove trailing newline, error logger already does this. // Remove trailing newline, error logger already does this.
err = LogError(strings.TrimSuffix(fcgi.stderr.String(), "\n")) err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
} }
return resp.StatusCode, err return resp.StatusCode, err

View file

@ -1,13 +1,61 @@
package fastcgi package fastcgi
import ( import (
"net"
"net/http" "net/http"
"net/http/fcgi"
"net/http/httptest"
"net/url" "net/url"
"strconv"
"testing" "testing"
) )
func TestRuleParseAddress(t *testing.T) { func TestServeHTTPContentLength(t *testing.T) {
testWithBackend := func(body string, setContentLength bool) {
bodyLenStr := strconv.Itoa(len(body))
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("BackendSetsContentLength=%v: Unable to create listener for test: %v", setContentLength, err)
}
defer listener.Close()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if setContentLength {
w.Header().Set("Content-Length", bodyLenStr)
}
w.Write([]byte(body))
}))
handler := Handler{
Next: nil,
Rules: []Rule{{Path: "/", Address: listener.Addr().String()}},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("BackendSetsContentLength=%v: Unable to create request: %v", setContentLength, err)
}
w := httptest.NewRecorder()
status, err := handler.ServeHTTP(w, r)
if got, want := status, http.StatusOK; got != want {
t.Errorf("BackendSetsContentLength=%v: Expected returned status code to be %d, got %d", setContentLength, want, got)
}
if err != nil {
t.Errorf("BackendSetsContentLength=%v: Expected nil error, got: %v", setContentLength, err)
}
if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want {
t.Errorf("BackendSetsContentLength=%v: Expected Content-Length to be '%s', got: '%s'", setContentLength, want, got)
}
if got, want := w.Body.String(), body; got != want {
t.Errorf("BackendSetsContentLength=%v: Expected response body to be '%s', got: '%s'", setContentLength, want, got)
}
}
testWithBackend("Backend does NOT set Content-Length", false)
testWithBackend("Backend sets Content-Length", true)
}
func TestRuleParseAddress(t *testing.T) {
getClientTestTable := []struct { getClientTestTable := []struct {
rule *Rule rule *Rule
expectednetwork string expectednetwork string
@ -27,28 +75,21 @@ func TestRuleParseAddress(t *testing.T) {
if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress { if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress {
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress) t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress)
} }
} }
} }
func TestBuildEnv(t *testing.T) { func TestBuildEnv(t *testing.T) {
testBuildEnv := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string) {
buildEnvSingle := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string, t *testing.T) { var h Handler
h := Handler{}
env, err := h.buildEnv(r, rule, fpath) env, err := h.buildEnv(r, rule, fpath)
if err != nil { if err != nil {
t.Error("Unexpected error:", err.Error()) t.Error("Unexpected error:", err.Error())
} }
for k, v := range envExpected { for k, v := range envExpected {
if env[k] != v { if env[k] != v {
t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v) t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v)
} }
} }
} }
rule := Rule{} rule := Rule{}
@ -80,16 +121,15 @@ func TestBuildEnv(t *testing.T) {
} }
// 1. Test for full canonical IPv6 address // 1. Test for full canonical IPv6 address
buildEnvSingle(&r, rule, fpath, envExpected, t) testBuildEnv(&r, rule, fpath, envExpected)
// 2. Test for shorthand notation of IPv6 address // 2. Test for shorthand notation of IPv6 address
r.RemoteAddr = "[::1]:51688" r.RemoteAddr = "[::1]:51688"
envExpected["REMOTE_ADDR"] = "[::1]" envExpected["REMOTE_ADDR"] = "[::1]"
buildEnvSingle(&r, rule, fpath, envExpected, t) testBuildEnv(&r, rule, fpath, envExpected)
// 3. Test for IPv4 address // 3. Test for IPv4 address
r.RemoteAddr = "192.168.0.10:51688" r.RemoteAddr = "192.168.0.10:51688"
envExpected["REMOTE_ADDR"] = "192.168.0.10" envExpected["REMOTE_ADDR"] = "192.168.0.10"
buildEnvSingle(&r, rule, fpath, envExpected, t) testBuildEnv(&r, rule, fpath, envExpected)
} }