Initial implementation of buffering for FastCGI requests with unknown content length

This commit is contained in:
WeidiDeng 2024-10-17 15:34:13 +08:00
parent c6f2979986
commit 4715bbfd64
No known key found for this signature in database
GPG key ID: 25F87CE1741EC7CD
5 changed files with 342 additions and 12 deletions

View file

@ -19,6 +19,8 @@ import (
"net/http"
"strconv"
"github.com/dustin/go-humanize"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
@ -43,7 +45,10 @@ func init() {
// dial_timeout <duration>
// read_timeout <duration>
// write_timeout <duration>
// capture_stderr
// body_buffer_disabled
// body_buffer_memory_limit <size>
// file_buffer_size_limit <size>
// file_buffer_filepath <path>
// }
func (t *Transport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
d.Next() // consume transport name
@ -113,6 +118,35 @@ func (t *Transport) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
t.CaptureStderr = true
case "body_buffer_disabled":
t.BodyBufferDisabled = true
case "body_buffer_memory_limit":
if !d.NextArg() {
return d.ArgErr()
}
size, err := humanize.ParseBytes(d.Val())
if err != nil {
return d.Errf("bad buffer size %s: %v", d.Val(), err)
}
t.BodyBufferMemoryLimit = int64(size)
case "file_buffer_size_limit":
if !d.NextArg() {
return d.ArgErr()
}
size, err := humanize.ParseBytes(d.Val())
if err != nil {
return d.Errf("bad buffer size %s: %v", d.Val(), err)
}
t.FileBufferSizeLimit = int64(size)
case "file_buffer_filepath":
if !d.NextArg() {
return d.ArgErr()
}
t.FileBufferFilepath = d.Val()
default:
return d.Errf("unrecognized subdirective %s", d.Val())
}
@ -294,6 +328,35 @@ func parsePHPFastCGI(h httpcaddyfile.Helper) ([]httpcaddyfile.ConfigValue, error
args := dispenser.RemainingArgs()
dispenser.DeleteN(len(args) + 1)
fcgiTransport.CaptureStderr = true
case "body_buffer_disabled":
fcgiTransport.BodyBufferDisabled = true
case "body_buffer_memory_limit":
if !dispenser.NextArg() {
return nil, dispenser.ArgErr()
}
size, err := humanize.ParseBytes(dispenser.Val())
if err != nil {
return nil, dispenser.Errf("bad buffer size %s: %v", dispenser.Val(), err)
}
fcgiTransport.BodyBufferMemoryLimit = int64(size)
case "file_buffer_size_limit":
if !dispenser.NextArg() {
return nil, dispenser.ArgErr()
}
size, err := humanize.ParseBytes(dispenser.Val())
if err != nil {
return nil, dispenser.Errf("bad buffer size %s: %v", dispenser.Val(), err)
}
fcgiTransport.FileBufferSizeLimit = int64(size)
case "file_buffer_filepath":
if !dispenser.NextArg() {
return nil, dispenser.ArgErr()
}
fcgiTransport.FileBufferFilepath = dispenser.Val()
}
}
}

View file

@ -128,9 +128,11 @@ var pad [maxPad]byte
type client struct {
rwc net.Conn
// keepAlive bool // TODO: implement
reqID uint16
stderr bool
logger *zap.Logger
reqID uint16
stderr bool
logger *zap.Logger
buffer bool
bufferFunc func(io.Reader) (int64, io.ReadCloser, error)
}
// Do made the request and returns a io.Reader that translates the data read
@ -155,6 +157,10 @@ func (c *client) Do(p map[string]string, req io.Reader) (r io.Reader, err error)
writer.recType = Stdin
if req != nil {
_, err = io.Copy(writer, req)
// body length mismatch
if lr, ok := req.(*io.LimitedReader); ok && lr.N > 0 {
return nil, io.ErrUnexpectedEOF
}
if err != nil {
return nil, err
}
@ -197,9 +203,67 @@ func (f clientCloser) Close() error {
return f.rwc.Close()
}
// create a response that describes the error message, it's passed to the client.
// caller should close the connection to fastcgi server
// newErrorResponse creates a synthetic plain-text response describing the
// given status code; it is returned to the client when the request cannot
// be forwarded to the FastCGI responder. The caller should close the
// connection to the fastcgi server.
func newErrorResponse(status int) *http.Response {
	statusText := http.StatusText(status)
	resp := &http.Response{
		// fix: StatusCode was previously never set (left at zero), so
		// downstream code reading resp.StatusCode saw an invalid status
		StatusCode:    status,
		Status:        statusText,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        make(http.Header),
		Body:          io.NopCloser(strings.NewReader(statusText)),
		ContentLength: int64(len(statusText)),
	}
	resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10))
	resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
	// prevent browsers from sniffing the error body into another type
	resp.Header.Set("X-Content-Type-Options", "nosniff")
	return resp
}
// check for invalid content_length to determine if the request should be buffered
// checkContentLength reports whether the FastCGI params carry a usable
// CONTENT_LENGTH value, returning the parsed length when they do. A
// missing, malformed, or negative value yields (0, false), signaling
// that the request body length is unknown and may need buffering.
func checkContentLength(p map[string]string) (int64, bool) {
	raw, present := p["CONTENT_LENGTH"]
	if !present {
		return 0, false
	}
	length, err := strconv.ParseInt(raw, 10, 64)
	if err != nil || length < 0 {
		return 0, false
	}
	return length, true
}
// Request returns a HTTP Response with Header and Body
// from fcgi responder
func (c *client) Request(p map[string]string, req io.Reader) (resp *http.Response, err error) {
// defer closing the request body if it's an io.Closer
if closer, ok := req.(io.Closer); ok {
defer closer.Close()
}
// check for content_length and buffer the request if needed
cl, ok := checkContentLength(p)
if !ok {
// buffering disabled
if !c.buffer {
c.rwc.Close()
return newErrorResponse(http.StatusLengthRequired), nil
} else {
// buffer the request
size, rc, err := c.bufferFunc(req)
if err != nil {
if err == errFileBufferExceeded {
c.rwc.Close()
return newErrorResponse(http.StatusRequestEntityTooLarge), nil
}
return nil, err
}
defer rc.Close()
p["CONTENT_LENGTH"] = strconv.FormatInt(size, 10)
}
} else {
req = io.LimitReader(req, cl)
}
r, err := c.Do(p, req)
if err != nil {
return
@ -302,7 +366,7 @@ func (c *client) Post(p map[string]string, method string, bodyType string, body
// PostForm issues a POST to the fcgi responder, with form
// as a string key to a list values (url.Values)
func (c *client) PostForm(p map[string]string, data url.Values) (resp *http.Response, err error) {
body := bytes.NewReader([]byte(data.Encode()))
body := strings.NewReader(data.Encode())
return c.Post(p, "POST", "application/x-www-form-urlencoded", body, int64(body.Len()))
}

View file

@ -15,10 +15,14 @@
package fastcgi
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
@ -80,8 +84,18 @@ type Transport struct {
// be used instead.
CaptureStderr bool `json:"capture_stderr,omitempty"`
serverSoftware string
logger *zap.Logger
// disable buffering of the request body that doesn't have a content length
BodyBufferDisabled bool `json:"body_buffer_disabled,omitempty"`
// memory limit for buffering the request body, the rest will be buffered by temporary files
BodyBufferMemoryLimit int64 `json:"body_buffer_memory_limit,omitempty"`
// total disk storage allowed by the request body buffer
FileBufferSizeLimit int64 `json:"file_buffer_size_limit,omitempty"`
// the path to store the temporary files for the request body buffer
FileBufferFilepath string `json:"file_buffer_filepath"`
serverSoftware string
logger *zap.Logger
tempFileLimiter *fileQuotaLimiter
}
// CaddyModule returns the Caddy module information.
@ -92,6 +106,15 @@ func (Transport) CaddyModule() caddy.ModuleInfo {
}
}
const (
	// short default dial timeout keeps load-balancer retries speedy
	defaultDialTimeout = 3 * time.Second

	// nginx default for 64bit platforms
	// https://nginx.org/en/docs/http/ngx_http_core_module.html#client_body_buffer_size
	defaultMemBufferSize = 1 << 14 // 16KB

	// nginx doesn't have an option to limit the total file buffer size
	defaultFileBufferSize = 100 << 20 // 100MB
)
// Provision sets up t.
func (t *Transport) Provision(ctx caddy.Context) error {
t.logger = ctx.Logger()
@ -106,12 +129,143 @@ func (t *Transport) Provision(ctx caddy.Context) error {
// Set a relatively short default dial timeout.
// This is helpful to make load-balancer retries more speedy.
if t.DialTimeout == 0 {
t.DialTimeout = caddy.Duration(3 * time.Second)
t.DialTimeout = caddy.Duration(defaultDialTimeout)
}
if !t.BodyBufferDisabled {
if t.FileBufferFilepath == "" {
t.FileBufferFilepath = os.TempDir()
}
// test if temporary file can be created
file, err := os.CreateTemp(t.FileBufferFilepath, "caddy-fastcgi-buffer-")
if err != nil {
return fmt.Errorf("failed to create temporary file: %v", err)
}
file.Close()
os.Remove(file.Name())
if t.BodyBufferMemoryLimit == 0 {
t.BodyBufferMemoryLimit = defaultMemBufferSize
}
if t.FileBufferSizeLimit == 0 {
t.FileBufferSizeLimit = defaultFileBufferSize
}
t.tempFileLimiter = newFileQuotaLimiter(t.FileBufferSizeLimit)
}
return nil
}
// bufferedBody replays a request body that was buffered partly in memory
// and partly in a temporary file. Read serves the memory portion first,
// then the file; Close releases both.
type bufferedBody struct {
	memBuf          *bytes.Buffer     // in-memory head of the body, taken from bufPool
	fileBuf         *os.File          // overflow spilled to disk; nil when fully in memory
	filesize        int64             // bytes written to fileBuf, charged against the quota
	tempFileLimiter *fileQuotaLimiter // shared disk-usage accountant to refund on release
}
// Read drains the in-memory buffer first, then the temporary file,
// releasing each resource eagerly once it is exhausted.
func (b *bufferedBody) Read(p []byte) (int, error) {
	if b.memBuf != nil {
		if b.memBuf.Len() != 0 {
			return b.memBuf.Read(p)
		}
		// memory portion fully consumed; return the buffer to the pool
		bufPool.Put(b.memBuf)
		b.memBuf = nil
	}
	if b.fileBuf != nil {
		n, err := b.fileBuf.Read(p)
		if err != nil {
			// close the file and remove it, refunding its disk quota
			b.fileBuf.Close()
			os.Remove(b.fileBuf.Name())
			b.tempFileLimiter.release(b.filesize)
			b.fileBuf = nil
		}
		// fix: a successful file read previously fell through and
		// returned (0, io.EOF), discarding the bytes just read;
		// always propagate (n, err) from the file
		return n, err
	}
	return 0, io.EOF
}
// Close releases whatever buffering resources are still held: the
// in-memory buffer goes back to the pool, and the temporary spill file
// is closed and removed with its disk quota refunded. Close is safe to
// call after a partial Read. It always returns nil.
func (b *bufferedBody) Close() error {
	if mb := b.memBuf; mb != nil {
		b.memBuf = nil
		bufPool.Put(mb)
	}
	if fb := b.fileBuf; fb != nil {
		b.fileBuf = nil
		fb.Close()
		os.Remove(fb.Name())
		b.tempFileLimiter.release(b.filesize)
	}
	return nil
}
var errFileBufferExceeded = errors.New("temporary file buffer limit exceeded")
// bufferBodyToFile copies the remainder of req into file, charging the
// shared temporary-file quota chunk by chunk. It returns the number of
// bytes written (which remains charged against the quota; the caller
// releases it when the buffered body is closed) and errFileBufferExceeded
// when the quota is exhausted.
func (t Transport) bufferBodyToFile(file *os.File, req io.Reader) (int64, error) {
	buf := streamingBufPool.Get().(*[]byte)
	defer streamingBufPool.Put(buf)

	var size int64
	for {
		// reserve a full chunk up front; the unused part is refunded below
		if !t.tempFileLimiter.acquire(readBufSize) {
			return size, errFileBufferExceeded
		}
		n, er := req.Read(*buf)
		var nw int
		var ew error
		if n > 0 {
			nw, ew = file.Write((*buf)[:n])
			size += int64(nw)
		}
		// fix: refund the unwritten part of the reservation on every
		// iteration — previously a read that returned 0 bytes (e.g. a
		// bare io.EOF) leaked a full readBufSize of quota per request
		t.tempFileLimiter.release(int64(readBufSize - nw))
		if ew != nil {
			return size, ew
		}
		if er != nil {
			if er == io.EOF {
				return size, nil
			}
			return size, er
		}
	}
}
// bufferBody spools a request body of unknown length into memory, up to
// BodyBufferMemoryLimit bytes, overflowing the remainder to a temporary
// file under FileBufferFilepath. It returns the total buffered size and
// a ReadCloser that replays the body; the caller must Close it to
// release the memory buffer and the disk quota.
func (t Transport) bufferBody(req io.Reader) (int64, io.ReadCloser, error) {
	// the original body is fully consumed here, so close it if possible
	if closer, ok := req.(io.Closer); ok {
		defer closer.Close()
	}

	memBuf := bufPool.Get().(*bytes.Buffer)
	memBuf.Reset()
	// read at most the memory limit; err == io.EOF means the whole body fit
	size, err := io.CopyN(memBuf, req, t.BodyBufferMemoryLimit)

	var body bufferedBody // should be closed in case buffering fails
	body.memBuf = memBuf
	// error while reading the body
	if err != nil {
		// fully buffered in memory
		if err == io.EOF {
			return size, &body, nil
		}
		body.Close()
		return 0, nil, err
	}

	// temporary file is needed here.
	fileBuf, err := os.CreateTemp(t.FileBufferFilepath, "caddy-fastcgi-buffer-")
	if err != nil {
		body.Close()
		return 0, nil, err
	}
	body.fileBuf = fileBuf

	// buffer the rest of the body to the file
	fSize, err := t.bufferBodyToFile(fileBuf, req)
	// record the spilled size before any Close so the quota is refunded
	body.filesize = fSize
	if err != nil {
		body.Close()
		return 0, nil, err
	}

	return size + fSize, &body, nil
}
// RoundTrip implements http.RoundTripper.
func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
server := r.Context().Value(caddyhttp.ServerCtxKey).(*caddyhttp.Server)
@ -171,10 +325,12 @@ func (t Transport) RoundTrip(r *http.Request) (*http.Response, error) {
// create the client that will facilitate the protocol
client := client{
rwc: conn,
reqID: 1,
logger: logger,
stderr: t.CaptureStderr,
rwc: conn,
reqID: 1,
logger: logger,
stderr: t.CaptureStderr,
buffer: !t.BodyBufferDisabled,
bufferFunc: t.bufferBody,
}
// read/write timeouts

View file

@ -24,3 +24,16 @@ var bufPool = sync.Pool{
return new(bytes.Buffer)
},
}
// readBufSize is the chunk size used when streaming request bodies to a
// temporary file; it is also the unit in which disk quota is reserved.
const readBufSize = 4096

// streamingBufPool recycles fixed-size scratch buffers for the
// copy loop that spills request bodies to disk.
var streamingBufPool = sync.Pool{
	New: func() any {
		// The Pool's New function should generally only return pointer
		// types, since a pointer can be put into the return interface
		// value without an allocation
		// - (from the package docs)
		b := make([]byte, readBufSize)
		return &b
	},
}

View file

@ -0,0 +1,34 @@
package fastcgi
import "sync"
// fileQuotaLimiter tracks how much temporary-file storage is in use
// across all in-flight requests, enforcing a fixed ceiling.
type fileQuotaLimiter struct {
	maxUsage     int64 // total bytes allowed
	currentUsage int64 // bytes currently reserved
	mu           sync.Mutex
}

// newFileQuotaLimiter returns a limiter that allows up to maxUsage bytes
// of temporary-file storage in total.
func newFileQuotaLimiter(maxUsage int64) *fileQuotaLimiter {
	l := new(fileQuotaLimiter)
	l.maxUsage = maxUsage
	return l
}

// acquire attempts to reserve n bytes of quota, reporting whether the
// reservation fit under the ceiling. Nothing is reserved on failure.
func (l *fileQuotaLimiter) acquire(n int64) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	remaining := l.maxUsage - l.currentUsage
	if n > remaining {
		return false
	}
	l.currentUsage += n
	return true
}

// release returns n previously acquired bytes to the pool.
func (l *fileQuotaLimiter) release(n int64) {
	l.mu.Lock()
	l.currentUsage -= n
	l.mu.Unlock()
}