mox/http/webserver.go
Mechiel Lukkien 259928ab62
add reverse proxying websocket connections
if we recognize that a request for a WebForward is trying to turn the
connection into a websocket, we forward it to the backend and check if the
backend understands the websocket request. if so, we pass back the upgrade
response and get out of the way, copying bytes between the two. we do log the
total amount of bytes read from the client and written to the client. if the
backend doesn't respond with a websocket response, or an invalid one, we respond
with a regular non-websocket response. and we log details about the failed
connection, should help with debugging and any bug reports.

we don't try to parse the websocket framing, that's between the client and the
backend.  we could try to parse it, in part to protect the backend from bad
frames, but it would be a lot of work and could be brittle in the face of
extensions.

this doesn't yet handle websocket connections when a http proxy is configured.
we'll implement it when someone needs it. we do recognize it and fail the
connection.

for issue #25
2023-05-30 22:11:31 +02:00

815 lines
26 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package http
import (
"bufio"
"bytes"
"context"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
htmltemplate "html/template"
"io"
golog "log"
"net"
"net/http"
"net/http/httputil"
"net/textproto"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"syscall"
"time"
"github.com/mjl-/mox/config"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/moxio"
)
// recvid returns a suffix of the form " (id <received-id>)" for appending to
// error messages sent to the client, so a failed request can be correlated with
// the logs. If the request context carries no positive cid, the empty string is
// returned.
func recvid(r *http.Request) string {
	if cid := mox.CidFromCtx(r.Context()); cid > 0 {
		return " (id " + mox.ReceivedID(cid) + ")"
	}
	return ""
}
// WebHandle serves an HTTP request by going through the configured domain
// redirects and then the list of WebHandlers, checking for a domain+path match
// and running the handler if so.
// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
//
// A domain redirect always results in a permanent redirect to https on the
// target domain. For a matching WebHandler, plain HTTP requests are redirected to
// https unless DontRedirectPlainHTTP is set. Otherwise the static, redirect or
// forward handler configured for the WebHandler is run. A handler can decline
// the request (return false), in which case the next WebHandler is tried.
//
// If no handler handled the request, false is returned.
// WebHandle sets w.Handler to the name of the handler that handled the request.
func WebHandle(w *loggingWriter, r *http.Request, host dns.Domain) (handled bool) {
	redirects, handlers := mox.Conf.WebServer()

	for from, to := range redirects {
		if host != from {
			continue
		}
		// Modify a copy of the URL. r.URL is a pointer, and changing it in place would
		// alter the request as seen by the caller, e.g. for logging.
		u := *r.URL
		u.Scheme = "https"
		u.Host = to.Name()
		w.Handler = "(domainredirect)"
		http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
		return true
	}

	for _, h := range handlers {
		if host != h.DNSDomain {
			continue
		}
		loc := h.Path.FindStringIndex(r.URL.Path)
		if loc == nil {
			continue
		}
		// The matching part of the request path, passed on to the forward handler which
		// may strip it from the path sent to the backend.
		s := loc[0]
		e := loc[1]
		path := r.URL.Path[s:e]

		if r.TLS == nil && !h.DontRedirectPlainHTTP {
			u := *r.URL
			u.Scheme = "https"
			u.Host = h.DNSDomain.Name()
			w.Handler = h.Name
			http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
			return true
		}

		if h.WebStatic != nil && HandleStatic(h.WebStatic, w, r) {
			w.Handler = h.Name
			return true
		}
		if h.WebRedirect != nil && HandleRedirect(h.WebRedirect, w, r) {
			w.Handler = h.Name
			return true
		}
		if h.WebForward != nil && HandleForward(h.WebForward, w, r, path) {
			w.Handler = h.Name
			return true
		}
	}
	return false
}
// lsTemplate is the HTML template for directory listings written by
// HandleStatic. It is executed with a map holding a "Files" list; for each entry
// the template uses the fields Size, SizeReadable, SizePad, Modified and Name.
// Parsing happens at package initialization; a parse error causes a panic
// through htmltemplate.Must.
var lsTemplate = htmltemplate.Must(htmltemplate.New("ls").Parse(`<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>ls</title>
<style>
body, html { padding: 1em; font-size: 16px; }
* { font-size: inherit; font-family: ubuntu, lato, sans-serif; margin: 0; padding: 0; box-sizing: border-box; }
h1 { margin-bottom: 1ex; font-size: 1.2rem; }
table td, table th { padding: .2em .5em; }
table > tbody > tr:nth-child(odd) { background-color: #f8f8f8; }
[title] { text-decoration: underline; text-decoration-style: dotted; }
</style>
</head>
<body>
<h1>ls</h1>
<table>
<thead>
<tr>
<th>Size in MB</th>
<th>Modified (UTC)</th>
<th>Name</th>
</tr>
</thead>
<tbody>
{{ if not .Files }}
<tr><td colspan="3">No files.</td></tr>
{{ end }}
{{ range .Files }}
<tr>
<td title="{{ .Size }} bytes" style="text-align: right">{{ .SizeReadable }}{{ if .SizePad }}<span style="visibility:hidden">.</span>{{ end }}</td>
<td>{{ .Modified }}</td>
<td><a style="display: block" href="{{ .Name }}">{{ .Name }}</a></td>
</tr>
{{ end }}
</tbody>
</table>
</body>
</html>
`))
// HandleStatic serves static files. If a directory is requested and the URL
// path doesn't end with a slash, a response with a redirect to the URL path with trailing
// slash is written. If a directory is requested and an index.html exists, that
// file is returned. Otherwise, for directories with ListFiles configured, a
// directory listing is returned.
//
// Only GET and HEAD requests are served. The file system path is h.Root joined
// with the request URL path, after removing h.StripPrefix if configured.
// Headers from h.ResponseHeaders are added to file and listing responses.
//
// HandleStatic returns false, without writing a response, only when
// h.ContinueNotFound is set and the request cannot be served from the file
// system (wrong method, prefix mismatch, file not found, trailing slash on a
// regular file, or directory without index.html and without ListFiles). The
// caller can then try a next handler. In all other cases a response has been
// written and true is returned.
func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (handled bool) {
// Logger is created lazily, only when there is something to log.
log := func() *mlog.Log {
return xlog.WithContext(r.Context())
}
if r.Method != "GET" && r.Method != "HEAD" {
if h.ContinueNotFound {
// Give another handler that is presumably configured, for the same path, a chance.
// E.g. an app that may generate this file for future requests to pick up.
return false
}
http.Error(w, "405 - method not allowed", http.StatusMethodNotAllowed)
return true
}
var fspath string
if h.StripPrefix != "" {
if !strings.HasPrefix(r.URL.Path, h.StripPrefix) {
if h.ContinueNotFound {
// We haven't handled this request, try a next WebHandler in the list.
return false
}
http.NotFound(w, r)
return true
}
fspath = filepath.Join(h.Root, strings.TrimPrefix(r.URL.Path, h.StripPrefix))
} else {
fspath = filepath.Join(h.Root, r.URL.Path)
}
// fspath will not have a trailing slash anymore, we'll correct for it
// later when the path turns out to be file instead of a directory.
// serveFile adds the configured response headers and writes the file contents
// with http.ServeContent.
serveFile := func(name string, mtime time.Time, content *os.File) {
// ServeContent only sets a content-type if not already present in the response headers.
hdr := w.Header()
for k, v := range h.ResponseHeaders {
hdr.Add(k, v)
}
http.ServeContent(w, r, name, mtime, content)
}
f, err := os.Open(fspath)
if err != nil {
// A path component that is a regular file (ENOTDIR) is treated like a missing file.
if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) {
if h.ContinueNotFound {
// We haven't handled this request, try a next WebHandler in the list.
return false
}
http.NotFound(w, r)
return true
} else if os.IsPermission(err) {
// If we tried opening a directory, we may not have permission to read it, but
// still access files inside it (execute bit), such as index.html. So try to serve it.
index, err := os.Open(filepath.Join(fspath, "index.html"))
if err == nil {
defer index.Close()
var ifi os.FileInfo
ifi, err = index.Stat()
if err != nil {
log().Errorx("stat index.html in directory we cannot list", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
serveFile("index.html", ifi.ModTime(), index)
return true
}
http.Error(w, "403 - permission denied", http.StatusForbidden)
return true
}
log().Errorx("open file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
log().Errorx("stat file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true
}
// Redirect if the local path is a directory.
if fi.IsDir() && !strings.HasSuffix(r.URL.Path, "/") {
http.Redirect(w, r, r.URL.Path+"/", http.StatusTemporaryRedirect)
return true
} else if !fi.IsDir() && strings.HasSuffix(r.URL.Path, "/") {
// A regular file requested with a trailing slash does not exist as directory.
if h.ContinueNotFound {
return false
}
http.NotFound(w, r)
return true
}
if fi.IsDir() {
// Prefer serving index.html. Fall back to a listing only if it does not exist and
// listing files is enabled.
index, err := os.Open(filepath.Join(fspath, "index.html"))
if err != nil && os.IsPermission(err) {
http.Error(w, "403 - permission denied", http.StatusForbidden)
return true
} else if err != nil && os.IsNotExist(err) && !h.ListFiles {
if h.ContinueNotFound {
return false
}
http.Error(w, "403 - permission denied", http.StatusForbidden)
return true
} else if err == nil {
defer index.Close()
var ifi os.FileInfo
ifi, err = index.Stat()
if err == nil {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
serveFile("index.html", ifi.ModTime(), index)
return true
}
}
// At this point err is non-nil: either from opening or from stat'ing index.html.
// Anything other than "does not exist" is an unexpected error.
if !os.IsNotExist(err) {
log().Errorx("stat for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true
}
// File is a single row in the directory listing, as used by lsTemplate.
type File struct {
Name string
Size int64
SizeReadable string
SizePad bool // Whether the size needs padding because it has no decimal point.
Modified string
}
files := []File{}
// Link to the parent directory, except for the root of the URL space.
if r.URL.Path != "/" {
files = append(files, File{"..", 0, "", false, ""})
}
// Read the directory in batches. Entries returned along with an error are still
// processed before the error is evaluated.
for {
l, err := f.Readdir(1000)
for _, e := range l {
mb := float64(e.Size()) / (1024 * 1024)
var size string
var sizepad bool
// Directories get an empty readable size. Files of 10 MB and larger are shown as
// whole MBs, smaller files with two decimals.
if !e.IsDir() {
if mb >= 10 {
size = fmt.Sprintf("%d", int64(mb))
sizepad = true
} else {
size = fmt.Sprintf("%.2f", mb)
}
}
const dateTime = "2006-01-02 15:04:05" // time.DateTime, but only since go1.20.
modified := e.ModTime().UTC().Format(dateTime)
f := File{e.Name(), e.Size(), size, sizepad, modified}
if e.IsDir() {
f.Name += "/"
}
files = append(files, f)
}
if err == io.EOF {
break
} else if err != nil {
log().Errorx("reading directory for file listing", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true
}
}
sort.Slice(files, func(i, j int) bool {
return files[i].Name < files[j].Name
})
hdr := w.Header()
hdr.Set("Content-Type", "text/html; charset=utf-8")
// The listing is always HTML, so a configured content-type header is not added.
for k, v := range h.ResponseHeaders {
if !strings.EqualFold(k, "content-type") {
hdr.Add(k, v)
}
}
err = lsTemplate.Execute(w, map[string]any{"Files": files})
if err != nil && !moxio.IsClosed(err) {
log().Errorx("executing directory listing template", err)
}
return true
}
// Regular file.
serveFile(fspath, fi.ModTime(), f)
return true
}
// HandleRedirect writes a response with an HTTP redirect.
//
// The destination URL is based on the request URL. If h.OrigPath is set, the
// request path must match it (otherwise a "not found" response is written) and
// the path is rewritten using h.ReplacePath. If h.URL is set, its scheme, host,
// fragment and query string replace those of the request URL, with the query
// string of the request appended to the configured query string. The status
// code is h.StatusCode, or "permanent redirect" if not configured.
//
// HandleRedirect returns false, without writing a response, if the destination
// has the same scheme, host and path as the request. See the comment in the
// function body.
func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Request) (handled bool) {
	// Destination path is the request path, unless a rewrite is configured.
	target := r.URL.Path
	if h.OrigPath != nil {
		if !h.OrigPath.MatchString(r.URL.Path) {
			http.NotFound(w, r)
			return true
		}
		target = h.OrigPath.ReplaceAllString(r.URL.Path, h.ReplacePath)
	}

	dst := *r.URL
	dst.Opaque = ""
	dst.RawPath = ""
	dst.OmitHost = false
	dst.Path = target
	if base := h.URL; base != nil {
		dst.Scheme = base.Scheme
		dst.Host = base.Host
		dst.ForceQuery = base.ForceQuery
		dst.Fragment = base.Fragment

		// Configured query string goes first, followed by that of the request.
		switch {
		case r.URL.RawQuery == "":
			dst.RawQuery = base.RawQuery
		case base.RawQuery == "":
			dst.RawQuery = r.URL.RawQuery
		default:
			dst.RawQuery = base.RawQuery + "&" + r.URL.RawQuery
		}
	}

	// Redirecting to the same scheme, host and path would bring the client right back
	// here, causing a redirect loop. We treat this redirect as not matching instead,
	// so a next WebHandler can be tried. This makes it possible to redirect all plain
	// http requests to https.
	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	if scheme == dst.Scheme && r.Host == dst.Host && r.URL.Path == dst.Path {
		return false
	}

	status := h.StatusCode
	if status == 0 {
		status = http.StatusPermanentRedirect
	}
	http.Redirect(w, r, dst.String(), status)
	return true
}
// HandleForward handles a request by forwarding it to another webserver and
// passing the response on. I.e. a reverse proxy. It handles websocket
// connections by monitoring the websocket handshake and then just passing along the
// websocket frames.
//
// The path parameter is the part of the request path that matched the handler.
// With h.StripPath set, it is removed from the start of the path before
// forwarding. Forwarded and X-Forwarded* headers from the client are dropped and
// replaced with our own X-Forwarded-Host and X-Forwarded-Proto. Headers from
// h.ResponseHeaders are added to the response. HandleForward always returns
// true.
func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
	log := func() *mlog.Log {
		return xlog.WithContext(r.Context())
	}

	// Continue with a shallow copy, so changes to URL and headers below do not affect
	// the request of the caller.
	fwd := *r
	r = &fwd

	if h.StripPath {
		rest := r.URL.Path[len(path):]
		if !strings.HasPrefix(rest, "/") {
			rest = "/" + rest
		}
		stripped := *r.URL
		stripped.Path = rest
		stripped.RawPath = ""
		r.URL = &stripped
	}

	// Build a new set of headers without any forwarded headers passed in by the
	// client, and with our own X-Forwarded headers. ReverseProxy will add
	// X-Forwarded-For.
	// note: We are not using "ws" or "wss" for websocket. The request we are
	// forwarding is http(s), and we don't yet know if the backend even supports
	// websockets.
	// todo: add Forwarded header? is anyone using it?
	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	clean := http.Header{}
	for name, values := range r.Header {
		switch {
		case name == "Forwarded", name == "X-Forwarded", strings.HasPrefix(name, "X-Forwarded-"):
			// Not passed on.
		default:
			clean[name] = values
		}
	}
	clean["X-Forwarded-Host"] = []string{r.Host}
	clean["X-Forwarded-Proto"] = []string{scheme}
	r.Header = clean

	// If we see an Upgrade: websocket, we're going to assume the client needs
	// websocket and only attempt to talk websocket with the backend. If the backend
	// doesn't do websocket, we'll send back a "bad request" response. For other values
	// of Upgrade, we don't do anything special.
	// https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
	// Upgrade: ../rfc/9110:2798
	// Upgrade headers are not for http/1.0, ../rfc/9110:2880
	// Websocket client "handshake" is described at ../rfc/6455:1134
	if upgrade := r.Header.Get("Upgrade"); upgrade != "" && (r.ProtoMajor != 1 || r.ProtoMinor != 0) {
		for _, token := range strings.Split(upgrade, ",") {
			// Websockets have case-insensitive string "websocket".
			if !strings.EqualFold(textproto.TrimString(token), "websocket") {
				continue
			}
			forwardWebsocket(h, w, r, path)
			return true
		}
	}

	// ReverseProxy will append any remaining path to the configured target URL.
	proxy := httputil.NewSingleHostReverseProxy(h.TargetURL)
	proxy.FlushInterval = time.Duration(-1) // Flush after each write.
	proxylog := mlog.New("net/http/httputil").WithContext(r.Context())
	proxy.ErrorLog = golog.New(mlog.ErrWriter(proxylog, mlog.LevelDebug, "reverseproxy error"), "", 0)
	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		// A canceled context is logged at debug level only, and no response is written.
		if errors.Is(err, context.Canceled) {
			log().Debugx("forwarding request to backend webserver", err, mlog.Field("url", r.URL))
			return
		}
		log().Errorx("forwarding request to backend webserver", err, mlog.Field("url", r.URL))

		status, msg := http.StatusBadGateway, "502 - bad gateway"
		if os.IsTimeout(err) {
			status, msg = http.StatusGatewayTimeout, "504 - gateway timeout"
		}
		http.Error(w, msg+recvid(r), status)
	}

	out := w.Header()
	for name, value := range h.ResponseHeaders {
		out.Add(name, value)
	}
	proxy.ServeHTTP(w, r)
	return true
}
// Sentinel errors for websocket forwarding, to be checked with errors.Is.
var (
	// The backend response is not a valid acceptance of the websocket upgrade.
	errResponseNotWebsocket = errors.New("not a valid websocket response to request")

	// The request needs functionality we have not implemented (yet).
	errNotImplemented = errors.New("functionality not yet implemented")
)
// Request has an Upgrade: websocket header. Check more websocketiness about the
// request. If it looks good, we forward it to the backend. If the backend responds
// with a valid websocket response, indicating it is indeed a websocket server, we
// pass the response along and start copying data between the client and the
// backend. We don't look at the frames and payloads. The backend already needs to
// know enough websocket to handle the frames. It wouldn't necessarily hurt to
// monitor the frames too, and check if they are valid, but it's quite a bit of
// work for little benefit. Besides, the whole point of websockets is to exchange
// bytes without HTTP being in the way, so let's do that.
//
// The ResponseWriter w must be a *loggingWriter, it is used for recording errors,
// the status code and the number of bytes copied in each direction.
//
// True is returned when the request was rejected as invalid websocket request,
// and after a websocket session has completed. False is returned when the
// connection could not be hijacked, when the transaction with the backend failed,
// or when the response could not be written to the client. In all these cases a
// response has been written where that was still possible.
func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
	log := func() *mlog.Log {
		return xlog.WithContext(r.Context())
	}

	lw := w.(*loggingWriter)
	lw.WebsocketRequest = true // For correct protocol in metrics.

	// We check the requested websocket version first. A future websocket version may
	// have different request requirements.
	// ../rfc/6455:1160
	wsversion := r.Header.Get("Sec-WebSocket-Version")
	if wsversion != "13" {
		// Indicate we only support version 13. Should get a client from the future to fall back to version 13.
		// ../rfc/6455:1435
		w.Header().Set("Sec-WebSocket-Version", "13")
		http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
		lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
		return true
	}

	// ../rfc/6455:1143
	if r.Method != "GET" {
		http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
		lw.error(fmt.Errorf("websocket request only allowed with method GET"))
		return true
	}

	// ../rfc/6455:1153
	var connectionUpgrade bool
	for _, s := range strings.Split(r.Header.Get("Connection"), ",") {
		if strings.EqualFold(textproto.TrimString(s), "upgrade") {
			connectionUpgrade = true
			break
		}
	}
	if !connectionUpgrade {
		http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
		lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
		return true
	}

	// ../rfc/6455:1156
	wskey := r.Header.Get("Sec-WebSocket-Key")
	key, err := base64.StdEncoding.DecodeString(wskey)
	if err != nil || len(key) != 16 {
		http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
		lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
		return true
	}

	// ../rfc/6455:1162
	// We don't look at the origin header. The backend needs to handle it, if it thinks
	// that helps...
	// We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
	// backend can set them, but it doesn't influence our forwarding of the data.

	// If this is not a hijacker, there is no point in connecting to the backend.
	hj, ok := lw.W.(http.Hijacker)
	var cbr *bufio.ReadWriter
	if !ok {
		log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
		http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
		lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
		return
	}

	// The request to the backend is always written as HTTP/1.1, regardless of the
	// protocol version of the incoming request.
	freq := *r
	freq.Proto = "HTTP/1.1"
	freq.ProtoMajor = 1
	freq.ProtoMinor = 1
	fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
	if err != nil {
		if errors.Is(err, errResponseNotWebsocket) {
			http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
		} else if errors.Is(err, errNotImplemented) {
			http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
		} else if os.IsTimeout(err) {
			http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
		} else {
			http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
		}
		lw.error(err)
		return
	}
	defer func() {
		if beconn != nil {
			beconn.Close()
		}
	}()

	// Hijack the client connection so we can write the response ourselves, and start
	// copying the websocket frames.
	var cconn net.Conn
	cconn, cbr, err = hj.Hijack()
	if err != nil {
		log().Debugx("cannot turn http transaction into websocket connection", err)
		http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
		lw.error(err)
		return
	}
	defer func() {
		if cconn != nil {
			cconn.Close()
		}
	}()

	// Below this point, we can no longer write to the ResponseWriter.

	// Mark as websocket response, for logging.
	lw.WebsocketResponse = true
	lw.setStatusCode(fresp.StatusCode)

	for k, v := range h.ResponseHeaders {
		fresp.Header.Add(k, v)
	}

	// Write the response to the client, completing its websocket handshake.
	if err := fresp.Write(cconn); err != nil {
		lw.error(fmt.Errorf("writing websocket response to client: %w", err))
		return
	}

	// Each of the two goroutines below sends exactly one value on errc.
	errc := make(chan error, 1)

	// Copy from client to backend.
	go func() {
		// Data from the client may already have been read into the buffer of the hijacked
		// connection. Pass that on first.
		buf, err := cbr.Peek(cbr.Reader.Buffered())
		if err != nil {
			errc <- err
			return
		}
		if len(buf) > 0 {
			n, err := beconn.Write(buf)
			if err != nil {
				errc <- err
				return
			}
			lw.SizeFromClient += int64(n)
		}
		n, err := io.Copy(beconn, cconn)
		lw.SizeFromClient += n
		errc <- err
	}()

	// Copy from backend to client.
	go func() {
		n, err := io.Copy(cconn, beconn)
		lw.SizeToClient = n
		errc <- err
	}()

	// Stop and close connection on first error from either side, typically a closed
	// connection whose closing was already announced with a websocket frame.
	lw.error(<-errc)

	// Close connections so other goroutine stops as well.
	cconn.Close()
	beconn.Close()

	// Wait for goroutine so it has updated the logWriter.Size*Client fields before we
	// continue with logging.
	<-errc

	// Only clear cconn now that both goroutines are done. They read the cconn
	// variable, possibly only after the other goroutine has already finished, so
	// clearing it earlier would be a data race and could make them use a nil
	// connection. Clearing prevents the deferred function from closing again.
	cconn = nil
	return true
}
// websocketTransact connects to the backend at targetURL, with TLS for scheme
// https, writes the websocket upgrade request r to it, and reads and checks the
// response.
//
// The headers of r are modified: hop-by-hop headers are removed, and clean
// Connection and Upgrade headers are set.
//
// On success, the response and the connection to the backend are returned. Reads
// on the returned connection first return any data the backend sent right after
// its response. The caller must close the connection.
//
// On error, the connection to the backend has been closed. If the backend did
// not respond with a valid websocket response, the error wraps
// errResponseNotWebsocket. If a proxy would be needed to reach the backend, the
// error wraps errNotImplemented.
func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
	log := func() *mlog.Log {
		return xlog.WithContext(r.Context())
	}

	// Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
	// unmodified.
	transport := http.DefaultTransport.(*http.Transport)

	// We haven't implemented using a proxy for websocket requests yet. If we need one,
	// return an error instead of trying to connect directly, which would be a
	// potential security issue.
	treq := *r
	treq.URL = targetURL
	if purl, err := transport.Proxy(&treq); err != nil {
		return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
	} else if purl != nil {
		return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
	}

	// Hostname strips the brackets from an IPv6 literal, JoinHostPort adds them back
	// when needed. This also works for an IPv6 literal without port.
	host := targetURL.Hostname()
	port := targetURL.Port()
	if port == "" {
		if targetURL.Scheme == "https" {
			port = "443"
		} else {
			port = "80"
		}
	}
	addr := net.JoinHostPort(host, port)
	conn, err := transport.DialContext(r.Context(), "tcp", addr)
	if err != nil {
		return nil, nil, fmt.Errorf("dial: %w", err)
	}
	// Close the connection on any error below, including a failed TLS handshake. If
	// conn has been replaced by the TLS connection, closing it closes the underlying
	// connection too.
	defer func() {
		if rerr != nil {
			conn.Close()
		}
	}()

	if targetURL.Scheme == "https" {
		// A TLS client needs a ServerName (or InsecureSkipVerify) in its config, or the
		// handshake fails. Like net/http, we use the host we are connecting to, unless
		// the config of the transport already has a name. We change a copy, never the
		// config of the transport itself.
		tlsconf := transport.TLSClientConfig.Clone()
		if tlsconf == nil {
			tlsconf = &tls.Config{}
		}
		if tlsconf.ServerName == "" {
			tlsconf.ServerName = host
		}
		tlsconn := tls.Client(conn, tlsconf)
		hsctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
		defer cancel()
		if err := tlsconn.HandshakeContext(hsctx); err != nil {
			return nil, nil, fmt.Errorf("tls handshake: %w", err)
		}
		conn = tlsconn
	}

	// todo: make timeout configurable?
	if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
		log().Check(err, "set deadline for websocket request to backend")
	}

	// Set clean connection headers.
	removeHopByHopHeaders(r.Header)
	r.Header.Set("Connection", "Upgrade")
	r.Header.Set("Upgrade", "websocket")

	// Write the websocket request to the backend.
	if err := r.Write(conn); err != nil {
		return nil, nil, fmt.Errorf("writing request to backend: %w", err)
	}

	// Read response from backend.
	br := bufio.NewReader(conn)
	resp, err := http.ReadResponse(br, r)
	if err != nil {
		return nil, nil, fmt.Errorf("reading response from backend: %w", err)
	}
	defer func() {
		if rerr != nil {
			resp.Body.Close()
		}
	}()

	// The deadline was for the handshake only, the websocket connection can be
	// long-lived.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		log().Check(err, "clearing deadline on websocket connection to backend")
	}

	// Check that the response from the backend server indicates it is websocket. If
	// not, don't pass the backend response, but an error that websocket is not
	// appropriate.
	if err := checkWebsocketResponse(resp, r); err != nil {
		return resp, nil, err
	}

	// note: net/http.Response.Body documents that it implements io.Writer for a
	// status: 101 response. But that's not the case when the response has been read
	// with http.ReadResponse. We'll write to the connection directly.

	// The buffered reader may have read more than just the response. Make sure that
	// data is returned first when reading from the returned connection.
	buf, err := br.Peek(br.Buffered())
	if err != nil {
		return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
	}
	return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
}
// websocketConn is a net.Conn of which the reads are taken from a separate
// reader, while all other methods go to the embedded connection. The reader is
// an io multireader that first returns the data already buffered while reading
// the HTTP response with http.ReadResponse, and then reads from the connection.
type websocketConn struct {
	r io.Reader
	net.Conn
}

// Read reads from the reader instead of directly from the embedded connection.
func (wc websocketConn) Read(p []byte) (int, error) {
	return wc.r.Read(p)
}
// checkWebsocketResponse verifies that resp, the HTTP response from a backend to
// websocket request req, accepts the WebSocket "upgrade": status 101, an
// "Upgrade: websocket" and "Connection: upgrade" header (case-insensitive), and a
// Sec-WebSocket-Accept header with the base64-encoded SHA-1 hash of the
// Sec-WebSocket-Key from the request followed by the fixed websocket GUID.
// All returned errors wrap errResponseNotWebsocket.
// ../rfc/6455:1299
func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
	if resp.StatusCode != http.StatusSwitchingProtocols {
		return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
	}

	upgrade := resp.Header.Get("Upgrade")
	if !strings.EqualFold(upgrade, "websocket") {
		return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
	}

	connection := resp.Header.Get("Connection")
	if !strings.EqualFold(connection, "upgrade") {
		return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
	}

	got, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
	if err != nil {
		return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
	}

	// The backend proves it understood the websocket request by hashing our key
	// together with the GUID.
	const guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	want := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + guid))
	if bytes.Equal(got, want[:]) {
		// We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
		return nil
	}
	return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
}
// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
// hopHeaders are the hop-by-hop headers, in canonical form. They are removed
// from a request before it is sent to the backend. As of RFC 7230, hop-by-hop
// headers must be listed in the Connection header field. The headers below are
// those defined by the obsoleted RFC 2616 (section 13.5.1), still removed for
// backward compatibility.
// ../rfc/2616:5128
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",      // canonicalized version of "TE"
	"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
	"Transfer-Encoding",
	"Upgrade",
}

// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
// removeHopByHopHeaders deletes hop-by-hop headers from h, modifying h in place:
// first all headers named in the Connection header, then the fixed set in
// hopHeaders.
func removeHopByHopHeaders(h http.Header) {
	// RFC 7230, section 6.1: headers named in the "Connection" header are meant for
	// this hop only.
	// ../rfc/7230:2817
	for _, line := range h["Connection"] {
		for _, token := range strings.Split(line, ",") {
			name := textproto.TrimString(token)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}

	// RFC 2616, section 13.5.1: the set of known hop-by-hop headers. Superseded by
	// the Connection header of RFC 7230, but still removed for backwards
	// compatibility.
	// ../rfc/2616:5128
	for i := range hopHeaders {
		h.Del(hopHeaders[i])
	}
}