caddyhttp: ensure ResponseWriterWrapper and ResponseRecorder use ReadFrom if the underlying response writer implements it. (#5022)

Doing so allows for splice/sendfile optimizations when available.
Fixes #4731

Co-authored-by: flga <flga@users.noreply.github.com>
Co-authored-by: Matthew Holt <mholt@users.noreply.github.com>
This commit is contained in:
fleandro 2022-09-07 21:13:35 +01:00 committed by GitHub
parent 1c9c8f6a13
commit dd9813c65b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 200 additions and 2 deletions

View file

@ -62,6 +62,16 @@ func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) er
return ErrNotImplemented return ErrNotImplemented
} }
// ReadFrom implements io.ReaderFrom. It simply calls the underlying
// ResponseWriter's ReadFrom method if there is one, otherwise it defaults
// to io.Copy.
func (rww *ResponseWriterWrapper) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := rww.ResponseWriter.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(rww.ResponseWriter, r)
}
// HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support. // HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support.
type HTTPInterfaces interface { type HTTPInterfaces interface {
http.ResponseWriter http.ResponseWriter
@ -188,9 +198,26 @@ func (rr *responseRecorder) Write(data []byte) (int, error) {
} else { } else {
n, err = rr.buf.Write(data) n, err = rr.buf.Write(data)
} }
if err == nil {
rr.size += n rr.size += n
return n, err
} }
func (rr *responseRecorder) ReadFrom(r io.Reader) (int64, error) {
rr.WriteHeader(http.StatusOK)
var n int64
var err error
if rr.stream {
if rf, ok := rr.ResponseWriter.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(r)
} else {
n, err = io.Copy(rr.ResponseWriter, r)
}
} else {
n, err = rr.buf.ReadFrom(r)
}
rr.size += int(n)
return n, err return n, err
} }
@ -251,4 +278,10 @@ type ShouldBufferFunc func(status int, header http.Header) bool
var ( var (
_ HTTPInterfaces = (*ResponseWriterWrapper)(nil) _ HTTPInterfaces = (*ResponseWriterWrapper)(nil)
_ ResponseRecorder = (*responseRecorder)(nil) _ ResponseRecorder = (*responseRecorder)(nil)
// Implementing ReaderFrom can be such a significant
// optimization that it should probably be required!
// see PR #5022 (25%-50% speedup)
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
) )

View file

@ -0,0 +1,165 @@
package caddyhttp
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"testing"
)
type responseWriterSpy interface {
http.ResponseWriter
Written() string
CalledReadFrom() bool
}
var (
_ responseWriterSpy = (*baseRespWriter)(nil)
_ responseWriterSpy = (*readFromRespWriter)(nil)
)
// a barebones http.ResponseWriter mock
type baseRespWriter []byte
func (brw *baseRespWriter) Write(d []byte) (int, error) {
*brw = append(*brw, d...)
return len(d), nil
}
func (brw *baseRespWriter) Header() http.Header { return nil }
func (brw *baseRespWriter) WriteHeader(statusCode int) {}
func (brw *baseRespWriter) Written() string { return string(*brw) }
func (brw *baseRespWriter) CalledReadFrom() bool { return false }
// an http.ResponseWriter mock that supports ReadFrom
type readFromRespWriter struct {
baseRespWriter
called bool
}
func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
rf.called = true
return io.Copy(&rf.baseRespWriter, r)
}
func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }
func TestResponseWriterWrapperReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
wantReadFrom bool
}{
"no ReadFrom": {
responseWriter: &baseRespWriter{},
wantReadFrom: false,
},
"has ReadFrom": {
responseWriter: &readFromRespWriter{},
wantReadFrom: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
// what we expect middlewares to do:
type myWrapper struct {
*ResponseWriterWrapper
}
wrapped := myWrapper{
ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter},
}
const srcData = "boo!"
// hides everything but Read, since strings.Reader implements WriteTo it would
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}
fmt.Println(name)
if _, err := io.Copy(wrapped, src); err != nil {
t.Errorf("Copy() err = %v", err)
}
if got := tt.responseWriter.Written(); got != srcData {
t.Errorf("data = %q, want %q", got, srcData)
}
if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
} else {
t.Errorf("ReadFrom() should not have been called")
}
}
})
}
}
func TestResponseRecorderReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
shouldBuffer bool
wantReadFrom bool
}{
"buffered plain": {
responseWriter: &baseRespWriter{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed plain": {
responseWriter: &baseRespWriter{},
shouldBuffer: false,
wantReadFrom: false,
},
"buffered ReadFrom": {
responseWriter: &readFromRespWriter{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed ReadFrom": {
responseWriter: &readFromRespWriter{},
shouldBuffer: false,
wantReadFrom: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
var buf bytes.Buffer
rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool {
return tt.shouldBuffer
})
const srcData = "boo!"
// hides everything but Read, since strings.Reader implements WriteTo it would
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}
if _, err := io.Copy(rr, src); err != nil {
t.Errorf("Copy() err = %v", err)
}
wantStreamed := srcData
wantBuffered := ""
if tt.shouldBuffer {
wantStreamed = ""
wantBuffered = srcData
}
if got := tt.responseWriter.Written(); got != wantStreamed {
t.Errorf("streamed data = %q, want %q", got, wantStreamed)
}
if got := buf.String(); got != wantBuffered {
t.Errorf("buffered data = %q, want %q", got, wantBuffered)
}
if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
} else {
t.Errorf("ReadFrom() should not have been called")
}
}
})
}
}