fastcgi: Only perform extra copy if necessary; added tests

This commit is contained in:
Matthew Holt 2016-02-24 16:41:45 -07:00
parent 367397dbd6
commit 737c7c4372
2 changed files with 79 additions and 33 deletions

View file

@ -72,7 +72,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
// Connect to FastCGI gateway
network, address := rule.parseAddress()
fcgi, err := Dial(network, address)
fcgiBackend, err := Dial(network, address)
if err != nil {
return http.StatusBadGateway, err
}
@ -81,19 +81,19 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
contentLength, _ := strconv.Atoi(r.Header.Get("Content-Length"))
switch r.Method {
case "HEAD":
resp, err = fcgi.Head(env)
resp, err = fcgiBackend.Head(env)
case "GET":
resp, err = fcgi.Get(env)
resp, err = fcgiBackend.Get(env)
case "OPTIONS":
resp, err = fcgi.Options(env)
resp, err = fcgiBackend.Options(env)
case "POST":
resp, err = fcgi.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength)
resp, err = fcgiBackend.Post(env, r.Header.Get("Content-Type"), r.Body, contentLength)
case "PUT":
resp, err = fcgi.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength)
resp, err = fcgiBackend.Put(env, r.Header.Get("Content-Type"), r.Body, contentLength)
case "PATCH":
resp, err = fcgi.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength)
resp, err = fcgiBackend.Patch(env, r.Header.Get("Content-Type"), r.Body, contentLength)
case "DELETE":
resp, err = fcgi.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength)
resp, err = fcgiBackend.Delete(env, r.Header.Get("Content-Type"), r.Body, contentLength)
default:
return http.StatusMethodNotAllowed, nil
}
@ -106,29 +106,35 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error)
return http.StatusBadGateway, err
}
// Write the response body to a buffer
// To explicitly set Content-Length
// For FastCGI app that don't set it
var buf bytes.Buffer
io.Copy(&buf, resp.Body)
var responseBody io.Reader = resp.Body
if r.Header.Get("Content-Length") == "" {
// If the upstream app didn't set a Content-Length (shame on them),
// we need to do it to prevent error messages being appended to
// an already-written response, and other problematic behavior.
// So we copy it to a buffer and read its size before flushing
// the response out to the client. See issues #567 and #614.
buf := new(bytes.Buffer)
_, err := io.Copy(buf, resp.Body)
if err != nil {
return http.StatusBadGateway, err
}
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
responseBody = buf
}
// Write the status code and header fields
writeHeader(w, resp)
// Write the response body
// TODO: If this has an error, the response will already be
// partly written. We should copy out of resp.Body into a buffer
// first, then write it to the response...
_, err = io.Copy(w, &buf)
_, err = io.Copy(w, responseBody)
if err != nil {
return http.StatusBadGateway, err
}
// FastCGI stderr outputs
if fcgi.stderr.Len() != 0 {
if fcgiBackend.stderr.Len() != 0 {
// Remove trailing newline, error logger already does this.
err = LogError(strings.TrimSuffix(fcgi.stderr.String(), "\n"))
err = LogError(strings.TrimSuffix(fcgiBackend.stderr.String(), "\n"))
}
return resp.StatusCode, err

View file

@ -1,13 +1,61 @@
package fastcgi
import (
"net"
"net/http"
"net/http/fcgi"
"net/http/httptest"
"net/url"
"strconv"
"testing"
)
func TestRuleParseAddress(t *testing.T) {
func TestServeHTTPContentLength(t *testing.T) {
testWithBackend := func(body string, setContentLength bool) {
bodyLenStr := strconv.Itoa(len(body))
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("BackendSetsContentLength=%v: Unable to create listener for test: %v", setContentLength, err)
}
defer listener.Close()
go fcgi.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if setContentLength {
w.Header().Set("Content-Length", bodyLenStr)
}
w.Write([]byte(body))
}))
handler := Handler{
Next: nil,
Rules: []Rule{{Path: "/", Address: listener.Addr().String()}},
}
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatalf("BackendSetsContentLength=%v: Unable to create request: %v", setContentLength, err)
}
w := httptest.NewRecorder()
status, err := handler.ServeHTTP(w, r)
if got, want := status, http.StatusOK; got != want {
t.Errorf("BackendSetsContentLength=%v: Expected returned status code to be %d, got %d", setContentLength, want, got)
}
if err != nil {
t.Errorf("BackendSetsContentLength=%v: Expected nil error, got: %v", setContentLength, err)
}
if got, want := w.Header().Get("Content-Length"), bodyLenStr; got != want {
t.Errorf("BackendSetsContentLength=%v: Expected Content-Length to be '%s', got: '%s'", setContentLength, want, got)
}
if got, want := w.Body.String(), body; got != want {
t.Errorf("BackendSetsContentLength=%v: Expected response body to be '%s', got: '%s'", setContentLength, want, got)
}
}
testWithBackend("Backend does NOT set Content-Length", false)
testWithBackend("Backend sets Content-Length", true)
}
func TestRuleParseAddress(t *testing.T) {
getClientTestTable := []struct {
rule *Rule
expectednetwork string
@ -27,28 +75,21 @@ func TestRuleParseAddress(t *testing.T) {
if _, actualaddress := entry.rule.parseAddress(); actualaddress != entry.expectedaddress {
t.Errorf("Unexpected parsed address for address string %v. Got %v, expected %v", entry.rule.Address, actualaddress, entry.expectedaddress)
}
}
}
func TestBuildEnv(t *testing.T) {
buildEnvSingle := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string, t *testing.T) {
h := Handler{}
testBuildEnv := func(r *http.Request, rule Rule, fpath string, envExpected map[string]string) {
var h Handler
env, err := h.buildEnv(r, rule, fpath)
if err != nil {
t.Error("Unexpected error:", err.Error())
}
for k, v := range envExpected {
if env[k] != v {
t.Errorf("Unexpected %v. Got %v, expected %v", k, env[k], v)
}
}
}
rule := Rule{}
@ -80,16 +121,15 @@ func TestBuildEnv(t *testing.T) {
}
// 1. Test for full canonical IPv6 address
buildEnvSingle(&r, rule, fpath, envExpected, t)
testBuildEnv(&r, rule, fpath, envExpected)
// 2. Test for shorthand notation of IPv6 address
r.RemoteAddr = "[::1]:51688"
envExpected["REMOTE_ADDR"] = "[::1]"
buildEnvSingle(&r, rule, fpath, envExpected, t)
testBuildEnv(&r, rule, fpath, envExpected)
// 3. Test for IPv4 address
r.RemoteAddr = "192.168.0.10:51688"
envExpected["REMOTE_ADDR"] = "192.168.0.10"
buildEnvSingle(&r, rule, fpath, envExpected, t)
testBuildEnv(&r, rule, fpath, envExpected)
}