mirror of
https://github.com/mjl-/mox.git
synced 2025-01-28 07:15:55 +03:00
a69887bfab
this simplifies some of the code that makes modifications to the config file. a few protected functions can make changes to the dynamic config, which webadmin can use, instead of having separate functions in mox-/admin.go for each type of change. this also exports the parsed full dynamic config to webadmin, so we need fewer functions for specific config fields too.
834 lines
27 KiB
Go
834 lines
27 KiB
Go
package http
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"crypto/sha1"
|
||
"crypto/tls"
|
||
"encoding/base64"
|
||
"errors"
|
||
"fmt"
|
||
htmltemplate "html/template"
|
||
"io"
|
||
"io/fs"
|
||
golog "log"
|
||
"log/slog"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httputil"
|
||
"net/textproto"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"sort"
|
||
"strings"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/mjl-/mox/config"
|
||
"github.com/mjl-/mox/dns"
|
||
"github.com/mjl-/mox/mlog"
|
||
"github.com/mjl-/mox/mox-"
|
||
"github.com/mjl-/mox/moxio"
|
||
)
|
||
|
||
func recvid(r *http.Request) string {
|
||
cid := mox.CidFromCtx(r.Context())
|
||
if cid <= 0 {
|
||
return ""
|
||
}
|
||
return " (id " + mox.ReceivedID(cid) + ")"
|
||
}
|
||
|
||
// WebHandle serves an HTTP request by going through the list of WebHandlers,
|
||
// check if there is a domain+path match, and running the handler if so.
|
||
// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
|
||
// If no handler matched, false is returned.
|
||
// WebHandle sets w.Name to that of the matching handler.
|
||
func WebHandle(w *loggingWriter, r *http.Request, host dns.Domain) (handled bool) {
|
||
conf := mox.Conf.DynamicConfig()
|
||
redirects := conf.WebDNSDomainRedirects
|
||
handlers := conf.WebHandlers
|
||
|
||
for from, to := range redirects {
|
||
if host != from {
|
||
continue
|
||
}
|
||
u := r.URL
|
||
u.Scheme = "https"
|
||
u.Host = to.Name()
|
||
w.Handler = "(domainredirect)"
|
||
http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
|
||
return true
|
||
}
|
||
|
||
for _, h := range handlers {
|
||
if host != h.DNSDomain {
|
||
continue
|
||
}
|
||
loc := h.Path.FindStringIndex(r.URL.Path)
|
||
if loc == nil {
|
||
continue
|
||
}
|
||
s := loc[0]
|
||
e := loc[1]
|
||
path := r.URL.Path[s:e]
|
||
|
||
if r.TLS == nil && !h.DontRedirectPlainHTTP {
|
||
u := *r.URL
|
||
u.Scheme = "https"
|
||
u.Host = h.DNSDomain.Name()
|
||
w.Handler = h.Name
|
||
w.Compress = h.Compress
|
||
http.Redirect(w, r, u.String(), http.StatusPermanentRedirect)
|
||
return true
|
||
}
|
||
|
||
// We don't want the loggingWriter to override the static handler's decisions to compress.
|
||
w.Compress = h.Compress
|
||
if h.WebStatic != nil && HandleStatic(h.WebStatic, h.Compress, w, r) {
|
||
w.Handler = h.Name
|
||
return true
|
||
}
|
||
if h.WebRedirect != nil && HandleRedirect(h.WebRedirect, w, r) {
|
||
w.Handler = h.Name
|
||
return true
|
||
}
|
||
if h.WebForward != nil && HandleForward(h.WebForward, w, r, path) {
|
||
w.Handler = h.Name
|
||
return true
|
||
}
|
||
}
|
||
w.Compress = false
|
||
return false
|
||
}
|
||
|
||
// lsTemplate renders the directory listing of HandleStatic. It is executed with
// a map with key "Files", holding a list of entries with fields Name (with a
// trailing slash for directories), Size, SizeReadable, SizePad and Modified.
//
// Links are written as "./" followed by the path-escaped name. Without the "./"
// prefix, a name with a colon (e.g. "a:b") looks like a URL with a scheme and
// html/template replaces it with "#ZgotmplZ". Without escaping, names with
// characters like "#", "?" or "%" would link to the wrong resource.
var lsTemplate = htmltemplate.Must(htmltemplate.New("ls").Funcs(htmltemplate.FuncMap{
	// pathescape escapes a file name for use as a relative URL path. A trailing
	// slash (directories) is kept as a path separator instead of being escaped.
	"pathescape": func(name string) string {
		if strings.HasSuffix(name, "/") {
			return url.PathEscape(strings.TrimSuffix(name, "/")) + "/"
		}
		return url.PathEscape(name)
	},
}).Parse(`<!doctype html>
<html>
	<head>
		<meta charset="utf-8" />
		<meta name="viewport" content="width=device-width, initial-scale=1" />
		<title>ls</title>
		<style>
body, html { padding: 1em; font-size: 16px; }
* { font-size: inherit; font-family: ubuntu, lato, sans-serif; margin: 0; padding: 0; box-sizing: border-box; }
h1 { margin-bottom: 1ex; font-size: 1.2rem; }
table td, table th { padding: .2em .5em; }
table > tbody > tr:nth-child(odd) { background-color: #f8f8f8; }
[title] { text-decoration: underline; text-decoration-style: dotted; }
		</style>
	</head>
	<body>
		<h1>ls</h1>
		<table>
			<thead>
				<tr>
					<th>Size in MB</th>
					<th>Modified (UTC)</th>
					<th>Name</th>
				</tr>
			</thead>
			<tbody>
			{{ if not .Files }}
				<tr><td colspan="3">No files.</td></tr>
			{{ end }}
			{{ range .Files }}
				<tr>
					<td title="{{ .Size }} bytes" style="text-align: right">{{ .SizeReadable }}{{ if .SizePad }}<span style="visibility:hidden">. </span>{{ end }}</td>
					<td>{{ .Modified }}</td>
					<td><a style="display: block" href="./{{ pathescape .Name }}">{{ .Name }}</a></td>
				</tr>
			{{ end }}
			</tbody>
		</table>
	</body>
</html>
`))
|
||
|
||
// HandleStatic serves static files. If a directory is requested and the URL
|
||
// path doesn't end with a slash, a response with a redirect to the URL path with trailing
|
||
// slash is written. If a directory is requested and an index.html exists, that
|
||
// file is returned. Otherwise, for directories with ListFiles configured, a
|
||
// directory listing is returned.
|
||
func HandleStatic(h *config.WebStatic, compress bool, w http.ResponseWriter, r *http.Request) (handled bool) {
|
||
log := func() mlog.Log {
|
||
return pkglog.WithContext(r.Context())
|
||
}
|
||
if r.Method != "GET" && r.Method != "HEAD" {
|
||
if h.ContinueNotFound {
|
||
// Give another handler that is presumbly configured, for the same path, a chance.
|
||
// E.g. an app that may generate this file for future requests to pick up.
|
||
return false
|
||
}
|
||
http.Error(w, "405 - method not allowed", http.StatusMethodNotAllowed)
|
||
return true
|
||
}
|
||
|
||
var fspath string
|
||
if h.StripPrefix != "" {
|
||
if !strings.HasPrefix(r.URL.Path, h.StripPrefix) {
|
||
if h.ContinueNotFound {
|
||
// We haven't handled this request, try a next WebHandler in the list.
|
||
return false
|
||
}
|
||
http.NotFound(w, r)
|
||
return true
|
||
}
|
||
fspath = filepath.Join(h.Root, strings.TrimPrefix(r.URL.Path, h.StripPrefix))
|
||
} else {
|
||
fspath = filepath.Join(h.Root, r.URL.Path)
|
||
}
|
||
// fspath will not have a trailing slash anymore, we'll correct for it
|
||
// later when the path turns out to be file instead of a directory.
|
||
|
||
serveFile := func(name string, fi fs.FileInfo, content *os.File) {
|
||
// ServeContent only sets a content-type if not already present in the response headers.
|
||
hdr := w.Header()
|
||
for k, v := range h.ResponseHeaders {
|
||
hdr.Add(k, v)
|
||
}
|
||
// We transparently compress here, but still use ServeContent, because it handles
|
||
// conditional requests, range requests. It's a bit of a hack, but on first write
|
||
// to staticgzcacheReplacer where we are compressing, we write the full compressed
|
||
// file instead, and return an error to ServeContent so it stops. We still have all
|
||
// the useful behaviour (status code and headers) from ServeContent.
|
||
xw := w
|
||
if compress && acceptsGzip(r) && compressibleContent(content) {
|
||
xw = &staticgzcacheReplacer{w, r, content.Name(), content, fi.ModTime(), fi.Size(), 0, false}
|
||
} else {
|
||
w.(*loggingWriter).Compress = false
|
||
}
|
||
http.ServeContent(xw, r, name, fi.ModTime(), content)
|
||
}
|
||
|
||
f, err := os.Open(fspath)
|
||
if err != nil {
|
||
if os.IsNotExist(err) || errors.Is(err, syscall.ENOTDIR) {
|
||
if h.ContinueNotFound {
|
||
// We haven't handled this request, try a next WebHandler in the list.
|
||
return false
|
||
}
|
||
http.NotFound(w, r)
|
||
return true
|
||
} else if os.IsPermission(err) {
|
||
// If we tried opening a directory, we may not have permission to read it, but
|
||
// still access files inside it (execute bit), such as index.html. So try to serve it.
|
||
index, err := os.Open(filepath.Join(fspath, "index.html"))
|
||
if err == nil {
|
||
defer index.Close()
|
||
var ifi os.FileInfo
|
||
ifi, err = index.Stat()
|
||
if err != nil {
|
||
log().Errorx("stat index.html in directory we cannot list", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
|
||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||
return true
|
||
}
|
||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||
serveFile("index.html", ifi, index)
|
||
return true
|
||
}
|
||
http.Error(w, "403 - permission denied", http.StatusForbidden)
|
||
return true
|
||
}
|
||
log().Errorx("open file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
|
||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||
return true
|
||
}
|
||
defer f.Close()
|
||
|
||
fi, err := f.Stat()
|
||
if err != nil {
|
||
log().Errorx("stat file for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
|
||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||
return true
|
||
}
|
||
// Redirect if the local path is a directory.
|
||
if fi.IsDir() && !strings.HasSuffix(r.URL.Path, "/") {
|
||
http.Redirect(w, r, r.URL.Path+"/", http.StatusTemporaryRedirect)
|
||
return true
|
||
} else if !fi.IsDir() && strings.HasSuffix(r.URL.Path, "/") {
|
||
if h.ContinueNotFound {
|
||
return false
|
||
}
|
||
http.NotFound(w, r)
|
||
return true
|
||
}
|
||
|
||
if fi.IsDir() {
|
||
index, err := os.Open(filepath.Join(fspath, "index.html"))
|
||
if err != nil && os.IsPermission(err) {
|
||
http.Error(w, "403 - permission denied", http.StatusForbidden)
|
||
return true
|
||
} else if err != nil && os.IsNotExist(err) && !h.ListFiles {
|
||
if h.ContinueNotFound {
|
||
return false
|
||
}
|
||
http.Error(w, "403 - permission denied", http.StatusForbidden)
|
||
return true
|
||
} else if err == nil {
|
||
defer index.Close()
|
||
var ifi os.FileInfo
|
||
ifi, err = index.Stat()
|
||
if err == nil {
|
||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||
serveFile("index.html", ifi, index)
|
||
return true
|
||
}
|
||
}
|
||
if !os.IsNotExist(err) {
|
||
log().Errorx("stat for static file serving", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
|
||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||
return true
|
||
}
|
||
|
||
type File struct {
|
||
Name string
|
||
Size int64
|
||
SizeReadable string
|
||
SizePad bool // Whether the size needs padding because it has no decimal point.
|
||
Modified string
|
||
}
|
||
files := []File{}
|
||
if r.URL.Path != "/" {
|
||
files = append(files, File{"..", 0, "", false, ""})
|
||
}
|
||
for {
|
||
l, err := f.Readdir(1000)
|
||
for _, e := range l {
|
||
mb := float64(e.Size()) / (1024 * 1024)
|
||
var size string
|
||
var sizepad bool
|
||
if !e.IsDir() {
|
||
if mb >= 10 {
|
||
size = fmt.Sprintf("%d", int64(mb))
|
||
sizepad = true
|
||
} else {
|
||
size = fmt.Sprintf("%.2f", mb)
|
||
}
|
||
}
|
||
const dateTime = "2006-01-02 15:04:05" // time.DateTime, but only since go1.20.
|
||
modified := e.ModTime().UTC().Format(dateTime)
|
||
f := File{e.Name(), e.Size(), size, sizepad, modified}
|
||
if e.IsDir() {
|
||
f.Name += "/"
|
||
}
|
||
files = append(files, f)
|
||
}
|
||
if err == io.EOF {
|
||
break
|
||
} else if err != nil {
|
||
log().Errorx("reading directory for file listing", err, slog.Any("url", r.URL), slog.String("fspath", fspath))
|
||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||
return true
|
||
}
|
||
}
|
||
sort.Slice(files, func(i, j int) bool {
|
||
return files[i].Name < files[j].Name
|
||
})
|
||
hdr := w.Header()
|
||
hdr.Set("Content-Type", "text/html; charset=utf-8")
|
||
for k, v := range h.ResponseHeaders {
|
||
if !strings.EqualFold(k, "content-type") {
|
||
hdr.Add(k, v)
|
||
}
|
||
}
|
||
err = lsTemplate.Execute(w, map[string]any{"Files": files})
|
||
if err != nil && !moxio.IsClosed(err) {
|
||
log().Errorx("executing directory listing template", err)
|
||
}
|
||
return true
|
||
}
|
||
|
||
serveFile(fspath, fi, f)
|
||
return true
|
||
}
|
||
|
||
// HandleRedirect writes a response with an HTTP redirect.
|
||
func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Request) (handled bool) {
|
||
var dstpath string
|
||
if h.OrigPath == nil {
|
||
// No path rewrite necessary.
|
||
dstpath = r.URL.Path
|
||
} else if !h.OrigPath.MatchString(r.URL.Path) {
|
||
http.NotFound(w, r)
|
||
return true
|
||
} else {
|
||
dstpath = h.OrigPath.ReplaceAllString(r.URL.Path, h.ReplacePath)
|
||
}
|
||
|
||
u := *r.URL
|
||
u.Opaque = ""
|
||
u.RawPath = ""
|
||
u.OmitHost = false
|
||
if h.URL != nil {
|
||
u.Scheme = h.URL.Scheme
|
||
u.Host = h.URL.Host
|
||
u.ForceQuery = h.URL.ForceQuery
|
||
u.RawQuery = h.URL.RawQuery
|
||
u.Fragment = h.URL.Fragment
|
||
if r.URL.RawQuery != "" {
|
||
if u.RawQuery != "" {
|
||
u.RawQuery += "&"
|
||
}
|
||
u.RawQuery += r.URL.RawQuery
|
||
}
|
||
}
|
||
u.Path = dstpath
|
||
code := http.StatusPermanentRedirect
|
||
if h.StatusCode != 0 {
|
||
code = h.StatusCode
|
||
}
|
||
|
||
// If we would be redirecting to the same scheme,host,path, we would get here again
|
||
// causing a redirect loop. Instead, this causes this redirect to not match,
|
||
// allowing to try the next WebHandler. This can be used to redirect all plain http
|
||
// requests to https.
|
||
reqscheme := "http"
|
||
if r.TLS != nil {
|
||
reqscheme = "https"
|
||
}
|
||
if reqscheme == u.Scheme && r.Host == u.Host && r.URL.Path == u.Path {
|
||
return false
|
||
}
|
||
|
||
http.Redirect(w, r, u.String(), code)
|
||
return true
|
||
}
|
||
|
||
// HandleForward handles a request by forwarding it to another webserver and
|
||
// passing the response on. I.e. a reverse proxy. It handles websocket
|
||
// connections by monitoring the websocket handshake and then just passing along the
|
||
// websocket frames.
|
||
func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
|
||
log := func() mlog.Log {
|
||
return pkglog.WithContext(r.Context())
|
||
}
|
||
|
||
xr := *r
|
||
r = &xr
|
||
if h.StripPath {
|
||
u := *r.URL
|
||
u.Path = r.URL.Path[len(path):]
|
||
if !strings.HasPrefix(u.Path, "/") {
|
||
u.Path = "/" + u.Path
|
||
}
|
||
u.RawPath = ""
|
||
r.URL = &u
|
||
}
|
||
|
||
// Remove any forwarded headers passed in by client.
|
||
hdr := http.Header{}
|
||
for k, vl := range r.Header {
|
||
if k == "Forwarded" || k == "X-Forwarded" || strings.HasPrefix(k, "X-Forwarded-") {
|
||
continue
|
||
}
|
||
hdr[k] = vl
|
||
}
|
||
r.Header = hdr
|
||
|
||
// Add our own X-Forwarded headers. ReverseProxy will add X-Forwarded-For.
|
||
r.Header["X-Forwarded-Host"] = []string{r.Host}
|
||
proto := "http"
|
||
if r.TLS != nil {
|
||
proto = "https"
|
||
}
|
||
r.Header["X-Forwarded-Proto"] = []string{proto}
|
||
// note: We are not using "ws" or "wss" for websocket. The request we are
|
||
// forwarding is http(s), and we don't yet know if the backend even supports
|
||
// websockets.
|
||
|
||
// todo: add Forwarded header? is anyone using it?
|
||
|
||
// If we see an Upgrade: websocket, we're going to assume the client needs
|
||
// websocket and only attempt to talk websocket with the backend. If the backend
|
||
// doesn't do websocket, we'll send back a "bad request" response. For other values
|
||
// of Upgrade, we don't do anything special.
|
||
// https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
|
||
// Upgrade: ../rfc/9110:2798
|
||
// Upgrade headers are not for http/1.0, ../rfc/9110:2880
|
||
// Websocket client "handshake" is described at ../rfc/6455:1134
|
||
upgrade := r.Header.Get("Upgrade")
|
||
if upgrade != "" && !(r.ProtoMajor == 1 && r.ProtoMinor == 0) {
|
||
// Websockets have case-insensitive string "websocket".
|
||
for _, s := range strings.Split(upgrade, ",") {
|
||
if strings.EqualFold(textproto.TrimString(s), "websocket") {
|
||
forwardWebsocket(h, w, r, path)
|
||
return true
|
||
}
|
||
}
|
||
}
|
||
|
||
// ReverseProxy will append any remaining path to the configured target URL.
|
||
proxy := httputil.NewSingleHostReverseProxy(h.TargetURL)
|
||
proxy.FlushInterval = time.Duration(-1) // Flush after each write.
|
||
proxy.ErrorLog = golog.New(mlog.LogWriter(mlog.New("net/http/httputil", nil).WithContext(r.Context()), mlog.LevelDebug, "reverseproxy error"), "", 0)
|
||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||
if errors.Is(err, context.Canceled) {
|
||
log().Debugx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
|
||
return
|
||
}
|
||
log().Errorx("forwarding request to backend webserver", err, slog.Any("url", r.URL))
|
||
if os.IsTimeout(err) {
|
||
http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
|
||
} else {
|
||
http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
|
||
}
|
||
}
|
||
whdr := w.Header()
|
||
for k, v := range h.ResponseHeaders {
|
||
whdr.Add(k, v)
|
||
}
|
||
proxy.ServeHTTP(w, r)
|
||
return true
|
||
}
|
||
|
||
var (
	// errResponseNotWebsocket is wrapped by checkWebsocketResponse when a backend
	// response is not a valid websocket handshake response. forwardWebsocket
	// turns it into a "400 - bad request" for the client.
	errResponseNotWebsocket = errors.New("not a valid websocket response to request")

	// errNotImplemented is wrapped when forwarding needs functionality we don't
	// have yet, such as a proxy for the websocket backend connection.
	// forwardWebsocket turns it into a "501 - not implemented" for the client.
	errNotImplemented = errors.New("functionality not yet implemented")
)
|
||
|
||
// Request has an Upgrade: websocket header. Check more websocketiness about the
|
||
// request. If it looks good, we forward it to the backend. If the backend responds
|
||
// with a valid websocket response, indicating it is indeed a websocket server, we
|
||
// pass the response along and start copying data between the client and the
|
||
// backend. We don't look at the frames and payloads. The backend already needs to
|
||
// know enough websocket to handle the frames. It wouldn't necessarily hurt to
|
||
// monitor the frames too, and check if they are valid, but it's quite a bit of
|
||
// work for little benefit. Besides, the whole point of websockets is to exchange
|
||
// bytes without HTTP being in the way, so let's do that.
|
||
func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
|
||
log := func() mlog.Log {
|
||
return pkglog.WithContext(r.Context())
|
||
}
|
||
|
||
lw := w.(*loggingWriter)
|
||
lw.WebsocketRequest = true // For correct protocol in metrics.
|
||
|
||
// We check the requested websocket version first. A future websocket version may
|
||
// have different request requirements.
|
||
// ../rfc/6455:1160
|
||
wsversion := r.Header.Get("Sec-WebSocket-Version")
|
||
if wsversion != "13" {
|
||
// Indicate we only support version 13. Should get a client from the future to fall back to version 13.
|
||
// ../rfc/6455:1435
|
||
w.Header().Set("Sec-WebSocket-Version", "13")
|
||
http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
|
||
lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
|
||
return true
|
||
}
|
||
|
||
// ../rfc/6455:1143
|
||
if r.Method != "GET" {
|
||
http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
|
||
lw.error(fmt.Errorf("websocket request only allowed with method GET"))
|
||
return true
|
||
}
|
||
|
||
// ../rfc/6455:1153
|
||
var connectionUpgrade bool
|
||
for _, s := range strings.Split(r.Header.Get("Connection"), ",") {
|
||
if strings.EqualFold(textproto.TrimString(s), "upgrade") {
|
||
connectionUpgrade = true
|
||
break
|
||
}
|
||
}
|
||
if !connectionUpgrade {
|
||
http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
|
||
lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
|
||
return true
|
||
}
|
||
|
||
// ../rfc/6455:1156
|
||
wskey := r.Header.Get("Sec-WebSocket-Key")
|
||
key, err := base64.StdEncoding.DecodeString(wskey)
|
||
if err != nil || len(key) != 16 {
|
||
http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
|
||
lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
|
||
return true
|
||
}
|
||
|
||
// ../rfc/6455:1162
|
||
// We don't look at the origin header. The backend needs to handle it, if it thinks
|
||
// that helps...
|
||
// We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
|
||
// backend can set them, but it doesn't influence our forwarding of the data.
|
||
|
||
// If this is not a hijacker, there is not point in connecting to the backend.
|
||
hj, ok := lw.W.(http.Hijacker)
|
||
var cbr *bufio.ReadWriter
|
||
if !ok {
|
||
log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
|
||
http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
|
||
lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
|
||
return
|
||
}
|
||
|
||
freq := *r
|
||
freq.Proto = "HTTP/1.1"
|
||
freq.ProtoMajor = 1
|
||
freq.ProtoMinor = 1
|
||
fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
|
||
if err != nil {
|
||
if errors.Is(err, errResponseNotWebsocket) {
|
||
http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
|
||
} else if errors.Is(err, errNotImplemented) {
|
||
http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
|
||
} else if os.IsTimeout(err) {
|
||
http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
|
||
} else {
|
||
http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
|
||
}
|
||
lw.error(err)
|
||
return
|
||
}
|
||
defer func() {
|
||
if beconn != nil {
|
||
beconn.Close()
|
||
}
|
||
}()
|
||
|
||
// Hijack the client connection so we can write the response ourselves, and start
|
||
// copying the websocket frames.
|
||
var cconn net.Conn
|
||
cconn, cbr, err = hj.Hijack()
|
||
if err != nil {
|
||
log().Debugx("cannot turn http transaction into websocket connection", err)
|
||
http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
|
||
lw.error(err)
|
||
return
|
||
}
|
||
defer func() {
|
||
if cconn != nil {
|
||
cconn.Close()
|
||
}
|
||
}()
|
||
|
||
// Below this point, we can no longer write to the ResponseWriter.
|
||
|
||
// Mark as websocket response, for logging.
|
||
lw.WebsocketResponse = true
|
||
lw.setStatusCode(fresp.StatusCode)
|
||
|
||
for k, v := range h.ResponseHeaders {
|
||
fresp.Header.Add(k, v)
|
||
}
|
||
|
||
// Write the response to the client, completing its websocket handshake.
|
||
if err := fresp.Write(cconn); err != nil {
|
||
lw.error(fmt.Errorf("writing websocket response to client: %w", err))
|
||
return
|
||
}
|
||
|
||
errc := make(chan error, 1)
|
||
|
||
// Copy from client to backend.
|
||
go func() {
|
||
buf, err := cbr.Peek(cbr.Reader.Buffered())
|
||
if err != nil {
|
||
errc <- err
|
||
return
|
||
}
|
||
if len(buf) > 0 {
|
||
n, err := beconn.Write(buf)
|
||
if err != nil {
|
||
errc <- err
|
||
return
|
||
}
|
||
lw.SizeFromClient += int64(n)
|
||
}
|
||
n, err := io.Copy(beconn, cconn)
|
||
lw.SizeFromClient += n
|
||
errc <- err
|
||
}()
|
||
|
||
// Copy from backend to client.
|
||
go func() {
|
||
n, err := io.Copy(cconn, beconn)
|
||
lw.SizeToClient = n
|
||
errc <- err
|
||
}()
|
||
|
||
// Stop and close connection on first error from either size, typically a closed
|
||
// connection whose closing was already announced with a websocket frame.
|
||
lw.error(<-errc)
|
||
// Close connections so other goroutine stops as well.
|
||
cconn.Close()
|
||
beconn.Close()
|
||
// Wait for goroutine so it has updated the logWriter.Size*Client fields before we
|
||
// continue with logging.
|
||
<-errc
|
||
cconn = nil
|
||
return true
|
||
}
|
||
|
||
func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
|
||
log := func() mlog.Log {
|
||
return pkglog.WithContext(r.Context())
|
||
}
|
||
|
||
// Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
|
||
// unmodified.
|
||
transport := http.DefaultTransport.(*http.Transport)
|
||
|
||
// We haven't implemented using a proxy for websocket requests yet. If we need one,
|
||
// return an error instead of trying to connect directly, which would be a
|
||
// potential security issue.
|
||
treq := *r
|
||
treq.URL = targetURL
|
||
if purl, err := transport.Proxy(&treq); err != nil {
|
||
return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
|
||
} else if purl != nil {
|
||
return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
|
||
}
|
||
|
||
host, port, err := net.SplitHostPort(targetURL.Host)
|
||
if err != nil {
|
||
host = targetURL.Host
|
||
if targetURL.Scheme == "https" {
|
||
port = "443"
|
||
} else {
|
||
port = "80"
|
||
}
|
||
}
|
||
addr := net.JoinHostPort(host, port)
|
||
conn, err := transport.DialContext(r.Context(), "tcp", addr)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("dial: %w", err)
|
||
}
|
||
if targetURL.Scheme == "https" {
|
||
tlsconn := tls.Client(conn, transport.TLSClientConfig)
|
||
ctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
|
||
defer cancel()
|
||
if err := tlsconn.HandshakeContext(ctx); err != nil {
|
||
return nil, nil, fmt.Errorf("tls handshake: %w", err)
|
||
}
|
||
conn = tlsconn
|
||
}
|
||
defer func() {
|
||
if rerr != nil {
|
||
conn.Close()
|
||
}
|
||
}()
|
||
|
||
// todo: make timeout configurable?
|
||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||
log().Check(err, "set deadline for websocket request to backend")
|
||
}
|
||
|
||
// Set clean connection headers.
|
||
removeHopByHopHeaders(r.Header)
|
||
r.Header.Set("Connection", "Upgrade")
|
||
r.Header.Set("Upgrade", "websocket")
|
||
|
||
// Write the websocket request to the backend.
|
||
if err := r.Write(conn); err != nil {
|
||
return nil, nil, fmt.Errorf("writing request to backend: %w", err)
|
||
}
|
||
|
||
// Read response from backend.
|
||
br := bufio.NewReader(conn)
|
||
resp, err := http.ReadResponse(br, r)
|
||
if err != nil {
|
||
return nil, nil, fmt.Errorf("reading response from backend: %w", err)
|
||
}
|
||
defer func() {
|
||
if rerr != nil {
|
||
resp.Body.Close()
|
||
}
|
||
}()
|
||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||
log().Check(err, "clearing deadline on websocket connection to backend")
|
||
}
|
||
|
||
// Check that the response from the backend server indicates it is websocket. If
|
||
// not, don't pass the backend response, but an error that websocket is not
|
||
// appropriate.
|
||
if err := checkWebsocketResponse(resp, r); err != nil {
|
||
return resp, nil, err
|
||
}
|
||
|
||
// note: net/http.Response.Body documents that it implements io.Writer for a
|
||
// status: 101 response. But that's not the case when the response has been read
|
||
// with http.ReadResponse. We'll write to the connection directly.
|
||
|
||
buf, err := br.Peek(br.Buffered())
|
||
if err != nil {
|
||
return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
|
||
}
|
||
return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
|
||
}
|
||
|
||
// websocketConn is a net.Conn whose reads come from r instead of directly from
// the embedded connection. websocketTransact sets r to an io.MultiReader of the
// data that was already buffered while reading the backend's HTTP response,
// followed by the connection itself. All other methods (Write, Close, deadlines,
// addresses) go to the embedded net.Conn.
type websocketConn struct {
	r io.Reader
	net.Conn
}

// Read reads from c.r, shadowing the Read method of the embedded net.Conn.
func (c websocketConn) Read(p []byte) (n int, err error) {
	n, err = c.r.Read(p)
	return n, err
}
|
||
|
||
// Check that an HTTP response (from a backend) is a valid websocket response, i.e.
|
||
// that it accepts the WebSocket "upgrade".
|
||
// ../rfc/6455:1299
|
||
func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
|
||
if resp.StatusCode != 101 {
|
||
return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
|
||
}
|
||
if upgrade := resp.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") {
|
||
return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
|
||
}
|
||
if connection := resp.Header.Get("Connection"); !strings.EqualFold(connection, "upgrade") {
|
||
return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
|
||
}
|
||
accept, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
|
||
if err != nil {
|
||
return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
|
||
}
|
||
exp := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||
if !bytes.Equal(accept, exp[:]) {
|
||
return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
|
||
}
|
||
// We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
|
||
return nil
|
||
}
|
||
|
||
// hopHeaders lists the hop-by-hop headers, which are removed when sent to the
// backend. From Go 1.20.4 src/net/http/httputil/reverseproxy.go.
//
// As of RFC 7230, hop-by-hop headers are required to appear in the Connection
// header field. These are the headers defined by the obsoleted RFC 2616 (section
// 13.5.1), kept for backward compatibility. Notes on some entries:
//   - Proxy-Connection is non-standard, but still sent by libcurl and rejected by e.g. google.
//   - Te is the canonicalized form of "TE".
//   - Trailer, not Trailers as in the RFC text; https://www.rfc-editor.org/errata_search.php?eid=4522
//
// ../rfc/2616:5128
var hopHeaders = []string{
	"Connection", "Proxy-Connection", "Keep-Alive",
	"Proxy-Authenticate", "Proxy-Authorization",
	"Te", "Trailer", "Transfer-Encoding", "Upgrade",
}

// removeHopByHopHeaders deletes the hop-by-hop headers from h: first all headers
// named in the Connection header(s) (RFC 7230 section 6.1), then the fixed list
// in hopHeaders (RFC 2616 section 13.5.1, superseded by the Connection header
// rule but preserved for backwards compatibility).
// From Go 1.20.4 src/net/http/httputil/reverseproxy.go.
// ../rfc/7230:2817
// ../rfc/2616:5128
func removeHopByHopHeaders(h http.Header) {
	for _, value := range h["Connection"] {
		for _, token := range strings.Split(value, ",") {
			name := textproto.TrimString(token)
			if name == "" {
				continue
			}
			h.Del(name)
		}
	}
	for i := range hopHeaders {
		h.Del(hopHeaders[i])
	}
}
|