From e0daa39cd3373d26dcf78ae0d30ac24d1df5dd57 Mon Sep 17 00:00:00 2001 From: WeidiDeng Date: Wed, 17 Apr 2024 23:00:37 +0800 Subject: [PATCH] caddyhttp: record num. bytes read when response writer is hijacked (#6173) * record the number of bytes read when response writer is hijacked * record body size when not nil --- modules/caddyhttp/responsewriter.go | 37 +++++++++++++++++++++++++++++ modules/caddyhttp/server.go | 5 ++++ 2 files changed, 42 insertions(+) diff --git a/modules/caddyhttp/responsewriter.go b/modules/caddyhttp/responsewriter.go index 51f672ee..12627d45 100644 --- a/modules/caddyhttp/responsewriter.go +++ b/modules/caddyhttp/responsewriter.go @@ -66,6 +66,8 @@ type responseRecorder struct { size int wroteHeader bool stream bool + + readSize *int } // NewResponseRecorder returns a new ResponseRecorder that can be @@ -240,6 +242,12 @@ func (rr *responseRecorder) FlushError() error { return nil } +// Private interface so it can only be used in this package +// #TODO: maybe export it later +func (rr *responseRecorder) setReadSize(size *int) { + rr.readSize = size +} + func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { //nolint:bodyclose conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack() @@ -249,6 +257,15 @@ func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not conn = &hijackedConn{conn, rr} brw.Writer.Reset(conn) + + buffered := brw.Reader.Buffered() + if buffered != 0 { + conn.(*hijackedConn).updateReadSize(buffered) + data, _ := brw.Peek(buffered) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(data), conn)) + } else { + brw.Reader.Reset(conn) + } return conn, brw, nil } @@ -258,6 +275,24 @@ type hijackedConn struct { rr *responseRecorder } +func (hc *hijackedConn) updateReadSize(n int) { + if hc.rr.readSize != nil { + *hc.rr.readSize += n + } +} + +func (hc *hijackedConn) Read(p []byte) (int, error) { + n, err := hc.Conn.Read(p) + hc.updateReadSize(n) + return n, err +} + +func (hc *hijackedConn) WriteTo(w io.Writer) (int64, error) { + n, err := io.Copy(w, hc.Conn) + hc.updateReadSize(int(n)) + return n, err +} + func (hc *hijackedConn) Write(p []byte) (int, error) { n, err := hc.Conn.Write(p) hc.rr.size += n @@ -298,4 +333,6 @@ var ( _ io.ReaderFrom = (*ResponseWriterWrapper)(nil) _ io.ReaderFrom = (*responseRecorder)(nil) _ io.ReaderFrom = (*hijackedConn)(nil) + + _ io.WriterTo = (*hijackedConn)(nil) ) diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index 0e88ef26..2418a590 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -326,6 +326,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Body != nil { bodyReader = &lengthReader{Source: r.Body} r.Body = bodyReader + + // should always be true, private interface can only be referenced in the same package + if setReadSizer, ok := wrec.(interface{ setReadSize(*int) }); ok { + setReadSizer.setReadSize(&bodyReader.Length) + } } // capture the original version of the request