reverseproxy: Log status code and byte count for websockets (#5140)

* log response size for websocket request

* record size when using hijack bufio.Writer
This commit is contained in:
WeidiDeng 2023-02-07 07:14:59 +08:00 committed by GitHub
parent 12bcbe2c49
commit c77a6bea66
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 14 deletions

View file

@ -211,11 +211,7 @@ func (rr *responseRecorder) ReadFrom(r io.Reader) (int64, error) {
var n int64 var n int64
var err error var err error
if rr.stream { if rr.stream {
if rf, ok := rr.ResponseWriter.(io.ReaderFrom); ok { n, err = rr.ResponseWriterWrapper.ReadFrom(r)
n, err = rf.ReadFrom(r)
} else {
n, err = io.Copy(rr.ResponseWriter, r)
}
} else { } else {
n, err = rr.buf.ReadFrom(r) n, err = rr.buf.ReadFrom(r)
} }
@ -260,6 +256,35 @@ func (rr *responseRecorder) WriteResponse() error {
return err return err
} }
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn, brw, err := rr.ResponseWriterWrapper.Hijack()
if err != nil {
return nil, nil, err
}
// Per http documentation, returned bufio.Writer is empty, but bufio.Read maybe not
conn = &hijackedConn{conn, rr}
brw.Writer.Reset(conn)
return conn, brw, nil
}
// used to track the size of hijacked response writers
type hijackedConn struct {
net.Conn
rr *responseRecorder
}
func (hc *hijackedConn) Write(p []byte) (int, error) {
n, err := hc.Conn.Write(p)
hc.rr.size += n
return n, err
}
func (hc *hijackedConn) ReadFrom(r io.Reader) (int64, error) {
n, err := io.Copy(hc.Conn, r)
hc.rr.size += int(n)
return n, err
}
// ResponseRecorder is a http.ResponseWriter that records // ResponseRecorder is a http.ResponseWriter that records
// responses instead of writing them to the client. See // responses instead of writing them to the client. See
// docs for NewResponseRecorder for proper usage. // docs for NewResponseRecorder for proper usage.

View file

@ -74,27 +74,24 @@ func (h Handler) handleUpgradeResponse(logger *zap.Logger, rw http.ResponseWrite
}() }()
defer close(backConnCloseCh) defer close(backConnCloseCh)
// write header first, response headers should not be counted in size
// like the rest of handler chain.
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
logger.Debug("upgrading connection") logger.Debug("upgrading connection")
conn, brw, err := hj.Hijack() conn, brw, err := hj.Hijack()
if err != nil { if err != nil {
h.logger.Error("hijack failed on protocol switch", zap.Error(err)) h.logger.Error("hijack failed on protocol switch", zap.Error(err))
return return
} }
defer conn.Close()
start := time.Now() start := time.Now()
defer func() { defer func() {
conn.Close()
logger.Debug("connection closed", zap.Duration("duration", time.Since(start))) logger.Debug("connection closed", zap.Duration("duration", time.Since(start)))
}() }()
copyHeader(rw.Header(), res.Header)
res.Header = rw.Header()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
h.logger.Debug("response write", zap.Error(err))
return
}
if err := brw.Flush(); err != nil { if err := brw.Flush(); err != nil {
h.logger.Debug("response flush", zap.Error(err)) h.logger.Debug("response flush", zap.Error(err))
return return