add reverse proxying websocket connections

if we recognize that a request for a WebForward is trying to turn the
connection into a websocket, we forward it to the backend and check if the
backend understands the websocket request. if so, we pass back the upgrade
response and get out of the way, copying bytes between the two. we do log the
total amount of bytes read from the client and written to the client. if the
backend doesn't respond with a websocke response, or an invalid one, we respond
with a regular non-websocket response. and we log details about the failed
connection, should help with debugging and any bug reports.

we don't try to parse the websocket framing, that's between the client and the
backend.  we could try to parse it, in part to protect the backend from bad
frames, but it would be a lot of work and could be brittle in the face of
extensions.

this doesn't yet handle websocket connections when a http proxy is configured.
we'll implement it when someone needs it. we do recognize it and fail the
connection.

for issue #25
This commit is contained in:
Mechiel Lukkien 2023-05-30 22:11:31 +02:00
parent aca64828bd
commit 259928ab62
No known key found for this signature in database
15 changed files with 1966 additions and 49 deletions

View file

@ -393,7 +393,7 @@ func (wr WebRedirect) equal(o WebRedirect) bool {
type WebForward struct { type WebForward struct {
StripPath bool `sconf:"optional" sconf-doc:"Strip the matching WebHandler path from the WebHandler before forwarding the request."` StripPath bool `sconf:"optional" sconf-doc:"Strip the matching WebHandler path from the WebHandler before forwarding the request."`
URL string `sconf-doc:"URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account."` URL string `sconf-doc:"URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account. Websocket connections are forwarded and data is copied between client and backend without looking at the framing. The websocket 'version' and 'key'/'accept' headers are verified during the handshake, but other websocket headers, including 'origin', 'protocol' and 'extensions' headers, are not inspected and the backend is responsible for verifying/interpreting them."`
ResponseHeaders map[string]string `sconf:"optional" sconf-doc:"Headers to add to the response. Useful for adding security- and cache-related headers."` ResponseHeaders map[string]string `sconf:"optional" sconf-doc:"Headers to add to the response. Useful for adding security- and cache-related headers."`
TargetURL *url.URL `sconf:"-" json:"-"` TargetURL *url.URL `sconf:"-" json:"-"`

View file

@ -715,6 +715,11 @@ describe-static" and "mox config describe-domains":
# unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string # unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string
# in the URL is ignored. Requests are made using Go's net/http.DefaultTransport # in the URL is ignored. Requests are made using Go's net/http.DefaultTransport
# that takes environment variables HTTP_PROXY and HTTPS_PROXY into account. # that takes environment variables HTTP_PROXY and HTTPS_PROXY into account.
# Websocket connections are forwarded and data is copied between client and
# backend without looking at the framing. The websocket 'version' and
# 'key'/'accept' headers are verified during the handshake, but other websocket
# headers, including 'origin', 'protocol' and 'extensions' headers, are not
# inspected and the backend is responsible for verifying/interpreting them.
URL: URL:
# Headers to add to the response. Useful for adding security- and cache-related # Headers to add to the response. Useful for adding security- and cache-related

View file

@ -1855,7 +1855,7 @@ const webserver = async () => {
), ),
dom.td( dom.td(
'URL', 'URL',
attr({title: "URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account."}), attr({title: "URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account. Websocket connections are forwarded and data is copied between client and backend without looking at the framing. The websocket 'version' and 'key'/'accept' headers are verified during the handshake, but other websocket headers, including 'origin', 'protocol' and 'extensions' headers, are not inspected and the backend is responsible for verifying/interpreting them."}),
), ),
dom.td( dom.td(
dom.span( dom.span(

View file

@ -42,7 +42,7 @@ var (
}, },
[]string{ []string{
"handler", // Name from webhandler, can be empty. "handler", // Name from webhandler, can be empty.
"proto", // "http" or "https" "proto", // "http", "https", "ws", "wss"
"method", // "(unknown)" and otherwise only common verbs "method", // "(unknown)" and otherwise only common verbs
"code", "code",
}, },
@ -58,7 +58,7 @@ var (
}, },
[]string{ []string{
"handler", // Name from webhandler, can be empty. "handler", // Name from webhandler, can be empty.
"proto", // "http" or "https" "proto", // "http", "https", "ws", "wss"
"method", // "(unknown)" and otherwise only common verbs "method", // "(unknown)" and otherwise only common verbs
"code", "code",
}, },
@ -72,19 +72,34 @@ type loggingWriter struct {
W http.ResponseWriter // Calls are forwarded. W http.ResponseWriter // Calls are forwarded.
Start time.Time Start time.Time
R *http.Request R *http.Request
WebsocketRequest bool // Whether request from was websocket.
Handler string // Set by router. Handler string // Set by router.
// Set by handlers. // Set by handlers.
StatusCode int StatusCode int
Size int64 Size int64 // Of data served, for non-websocket responses.
WriteErr error Err error
WebsocketResponse bool // If this was a successful websocket connection with backend.
SizeFromClient, SizeToClient int64 // Websocket data.
} }
func (w *loggingWriter) Header() http.Header { func (w *loggingWriter) Header() http.Header {
return w.W.Header() return w.W.Header()
} }
// protocol, for logging.
func (w *loggingWriter) proto(websocket bool) string {
proto := "http"
if websocket {
proto = "ws"
}
if w.R.TLS != nil {
proto += "s"
}
return proto
}
func (w *loggingWriter) setStatusCode(statusCode int) { func (w *loggingWriter) setStatusCode(statusCode int) {
if w.StatusCode != 0 { if w.StatusCode != 0 {
return return
@ -92,11 +107,7 @@ func (w *loggingWriter) setStatusCode(statusCode int) {
w.StatusCode = statusCode w.StatusCode = statusCode
method := metricHTTPMethod(w.R.Method) method := metricHTTPMethod(w.R.Method)
proto := "http" metricRequest.WithLabelValues(w.Handler, w.proto(w.WebsocketRequest), method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
if w.R.TLS != nil {
proto = "https"
}
metricRequest.WithLabelValues(w.Handler, proto, method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
} }
func (w *loggingWriter) Write(buf []byte) (int, error) { func (w *loggingWriter) Write(buf []byte) (int, error) {
@ -108,8 +119,8 @@ func (w *loggingWriter) Write(buf []byte) (int, error) {
if n > 0 { if n > 0 {
w.Size += int64(n) w.Size += int64(n)
} }
if err != nil && w.WriteErr == nil { if err != nil {
w.WriteErr = err w.error(err)
} }
return n, err return n, err
} }
@ -136,13 +147,15 @@ func metricHTTPMethod(method string) string {
return "(other)" return "(other)"
} }
func (w *loggingWriter) error(err error) {
if w.Err == nil {
w.Err = err
}
}
func (w *loggingWriter) Done() { func (w *loggingWriter) Done() {
method := metricHTTPMethod(w.R.Method) method := metricHTTPMethod(w.R.Method)
proto := "http" metricResponse.WithLabelValues(w.Handler, w.proto(w.WebsocketResponse), method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
if w.R.TLS != nil {
proto = "https"
}
metricResponse.WithLabelValues(w.Handler, proto, method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
tlsinfo := "plain" tlsinfo := "plain"
if w.R.TLS != nil { if w.R.TLS != nil {
@ -152,26 +165,42 @@ func (w *loggingWriter) Done() {
tlsinfo = "(other)" tlsinfo = "(other)"
} }
} }
err := w.WriteErr err := w.Err
if err == nil { if err == nil {
err = w.R.Context().Err() err = w.R.Context().Err()
} }
xlog.WithContext(w.R.Context()).Debugx("http request", err, fields := []mlog.Pair{
mlog.Field("httpaccess", ""), mlog.Field("httpaccess", ""),
mlog.Field("handler", w.Handler), mlog.Field("handler", w.Handler),
mlog.Field("method", method), mlog.Field("method", method),
mlog.Field("url", w.R.URL), mlog.Field("url", w.R.URL),
mlog.Field("host", w.R.Host), mlog.Field("host", w.R.Host),
mlog.Field("duration", time.Since(w.Start)), mlog.Field("duration", time.Since(w.Start)),
mlog.Field("size", w.Size),
mlog.Field("statuscode", w.StatusCode), mlog.Field("statuscode", w.StatusCode),
mlog.Field("proto", strings.ToLower(w.R.Proto)), mlog.Field("proto", strings.ToLower(w.R.Proto)),
mlog.Field("remoteaddr", w.R.RemoteAddr), mlog.Field("remoteaddr", w.R.RemoteAddr),
mlog.Field("tlsinfo", tlsinfo), mlog.Field("tlsinfo", tlsinfo),
mlog.Field("useragent", w.R.Header.Get("User-Agent")), mlog.Field("useragent", w.R.Header.Get("User-Agent")),
mlog.Field("referrr", w.R.Header.Get("Referrer")), mlog.Field("referrr", w.R.Header.Get("Referrer")),
}
if w.WebsocketRequest {
fields = append(fields,
mlog.Field("websocketrequest", true),
) )
} }
if w.WebsocketResponse {
fields = append(fields,
mlog.Field("websocket", true),
mlog.Field("sizetoclient", w.SizeToClient),
mlog.Field("sizefromclient", w.SizeFromClient),
)
} else {
fields = append(fields,
mlog.Field("size", w.Size),
)
}
xlog.WithContext(w.R.Context()).Debugx("http request", err, fields...)
}
// Set some http headers that should prevent potential abuse. Better safe than sorry. // Set some http headers that should prevent potential abuse. Better safe than sorry.
func safeHeaders(fn http.Handler) http.Handler { func safeHeaders(fn http.Handler) http.Handler {

View file

@ -1,14 +1,22 @@
package http package http
import ( import (
"bufio"
"bytes"
"context" "context"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
htmltemplate "html/template" htmltemplate "html/template"
"io" "io"
golog "log" golog "log"
"net"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"net/textproto"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@ -23,6 +31,14 @@ import (
"github.com/mjl-/mox/moxio" "github.com/mjl-/mox/moxio"
) )
func recvid(r *http.Request) string {
cid := mox.CidFromCtx(r.Context())
if cid <= 0 {
return ""
}
return " (id " + mox.ReceivedID(cid) + ")"
}
// WebHandle serves an HTTP request by going through the list of WebHandlers, // WebHandle serves an HTTP request by going through the list of WebHandlers,
// check if there is a domain+path match, and running the handler if so. // check if there is a domain+path match, and running the handler if so.
// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc. // WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
@ -131,14 +147,6 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
log := func() *mlog.Log { log := func() *mlog.Log {
return xlog.WithContext(r.Context()) return xlog.WithContext(r.Context())
} }
recvid := func() string {
cid := mox.CidFromCtx(r.Context())
if cid <= 0 {
return ""
}
return " (id " + mox.ReceivedID(cid) + ")"
}
if r.Method != "GET" && r.Method != "HEAD" { if r.Method != "GET" && r.Method != "HEAD" {
if h.ContinueNotFound { if h.ContinueNotFound {
// Give another handler that is presumbly configured, for the same path, a chance. // Give another handler that is presumbly configured, for the same path, a chance.
@ -194,7 +202,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
ifi, err = index.Stat() ifi, err = index.Stat()
if err != nil { if err != nil {
log().Errorx("stat index.html in directory we cannot list", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath)) log().Errorx("stat index.html in directory we cannot list", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError) http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true return true
} }
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
@ -205,7 +213,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
return true return true
} }
log().Errorx("open file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath)) log().Errorx("open file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError) http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true return true
} }
defer f.Close() defer f.Close()
@ -213,7 +221,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
fi, err := f.Stat() fi, err := f.Stat()
if err != nil { if err != nil {
log().Errorx("stat file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath)) log().Errorx("stat file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError) http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true return true
} }
// Redirect if the local path is a directory. // Redirect if the local path is a directory.
@ -251,7 +259,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
} }
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log().Errorx("stat for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath)) log().Errorx("stat for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError) http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true return true
} }
@ -292,7 +300,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
break break
} else if err != nil { } else if err != nil {
log().Errorx("reading directory for file listing", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath)) log().Errorx("reading directory for file listing", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError) http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
return true return true
} }
} }
@ -370,24 +378,22 @@ func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Reques
} }
// HandleForward handles a request by forwarding it to another webserver and // HandleForward handles a request by forwarding it to another webserver and
// passing the response on. I.e. a reverse proxy. // passing the response on. I.e. a reverse proxy. It handles websocket
// connections by monitoring the websocket handshake and then just passing along the
// websocket frames.
func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) { func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
log := func() *mlog.Log { log := func() *mlog.Log {
return xlog.WithContext(r.Context()) return xlog.WithContext(r.Context())
} }
recvid := func() string {
cid := mox.CidFromCtx(r.Context())
if cid <= 0 {
return ""
}
return " (id " + mox.ReceivedID(cid) + ")"
}
xr := *r xr := *r
r = &xr r = &xr
if h.StripPath { if h.StripPath {
u := *r.URL u := *r.URL
u.Path = r.URL.Path[len(path):] u.Path = r.URL.Path[len(path):]
if !strings.HasPrefix(u.Path, "/") {
u.Path = "/" + u.Path
}
u.RawPath = "" u.RawPath = ""
r.URL = &u r.URL = &u
} }
@ -409,8 +415,31 @@ func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request,
proto = "https" proto = "https"
} }
r.Header["X-Forwarded-Proto"] = []string{proto} r.Header["X-Forwarded-Proto"] = []string{proto}
// note: We are not using "ws" or "wss" for websocket. The request we are
// forwarding is http(s), and we don't yet know if the backend even supports
// websockets.
// todo: add Forwarded header? is anyone using it? // todo: add Forwarded header? is anyone using it?
// If we see an Upgrade: websocket, we're going to assume the client needs
// websocket and only attempt to talk websocket with the backend. If the backend
// doesn't do websocket, we'll send back a "bad request" response. For other values
// of Upgrade, we don't do anything special.
// https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
// Upgrade: ../rfc/9110:2798
// Upgrade headers are not for http/1.0, ../rfc/9110:2880
// Websocket client "handshake" is described at ../rfc/6455:1134
upgrade := r.Header.Get("Upgrade")
if upgrade != "" && !(r.ProtoMajor == 1 && r.ProtoMinor == 0) {
// Websockets have case-insensitive string "websocket".
for _, s := range strings.Split(upgrade, ",") {
if strings.EqualFold(textproto.TrimString(s), "websocket") {
forwardWebsocket(h, w, r, path)
return true
}
}
}
// ReverseProxy will append any remaining path to the configured target URL. // ReverseProxy will append any remaining path to the configured target URL.
proxy := httputil.NewSingleHostReverseProxy(h.TargetURL) proxy := httputil.NewSingleHostReverseProxy(h.TargetURL)
proxy.FlushInterval = time.Duration(-1) // Flush after each write. proxy.FlushInterval = time.Duration(-1) // Flush after each write.
@ -422,9 +451,9 @@ func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request,
} }
log().Errorx("forwarding request to backend webserver", err, mlog.Field("url", r.URL)) log().Errorx("forwarding request to backend webserver", err, mlog.Field("url", r.URL))
if os.IsTimeout(err) { if os.IsTimeout(err) {
http.Error(w, "504 - gateway timeout"+recvid(), http.StatusGatewayTimeout) http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
} else { } else {
http.Error(w, "502 - bad gateway"+recvid(), http.StatusBadGateway) http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
} }
} }
whdr := w.Header() whdr := w.Header()
@ -434,3 +463,353 @@ func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request,
proxy.ServeHTTP(w, r) proxy.ServeHTTP(w, r)
return true return true
} }
var errResponseNotWebsocket = errors.New("not a valid websocket response to request")
var errNotImplemented = errors.New("functionality not yet implemented")
// Request has an Upgrade: websocket header. Check more websocketiness about the
// request. If it looks good, we forward it to the backend. If the backend responds
// with a valid websocket response, indicating it is indeed a websocket server, we
// pass the response along and start copying data between the client and the
// backend. We don't look at the frames and payloads. The backend already needs to
// know enough websocket to handle the frames. It wouldn't necessarily hurt to
// monitor the frames too, and check if they are valid, but it's quite a bit of
// work for little benefit. Besides, the whole point of websockets is to exchange
// bytes without HTTP being in the way, so let's do that.
func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
log := func() *mlog.Log {
return xlog.WithContext(r.Context())
}
lw := w.(*loggingWriter)
lw.WebsocketRequest = true // For correct protocol in metrics.
// We check the requested websocket version first. A future websocket version may
// have different request requirements.
// ../rfc/6455:1160
wsversion := r.Header.Get("Sec-WebSocket-Version")
if wsversion != "13" {
// Indicate we only support version 13. Should get a client from the future to fall back to version 13.
// ../rfc/6455:1435
w.Header().Set("Sec-WebSocket-Version", "13")
http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
return true
}
// ../rfc/6455:1143
if r.Method != "GET" {
http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
lw.error(fmt.Errorf("websocket request only allowed with method GET"))
return true
}
// ../rfc/6455:1153
var connectionUpgrade bool
for _, s := range strings.Split(r.Header.Get("Connection"), ",") {
if strings.EqualFold(textproto.TrimString(s), "upgrade") {
connectionUpgrade = true
break
}
}
if !connectionUpgrade {
http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
return true
}
// ../rfc/6455:1156
wskey := r.Header.Get("Sec-WebSocket-Key")
key, err := base64.StdEncoding.DecodeString(wskey)
if err != nil || len(key) != 16 {
http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
return true
}
// ../rfc/6455:1162
// We don't look at the origin header. The backend needs to handle it, if it thinks
// that helps...
// We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
// backend can set them, but it doesn't influence our forwarding of the data.
// If this is not a hijacker, there is not point in connecting to the backend.
hj, ok := lw.W.(http.Hijacker)
var cbr *bufio.ReadWriter
if !ok {
log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
return
}
freq := *r
freq.Proto = "HTTP/1.1"
freq.ProtoMajor = 1
freq.ProtoMinor = 1
fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
if err != nil {
if errors.Is(err, errResponseNotWebsocket) {
http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
} else if errors.Is(err, errNotImplemented) {
http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
} else if os.IsTimeout(err) {
http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
} else {
http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
}
lw.error(err)
return
}
defer func() {
if beconn != nil {
beconn.Close()
}
}()
// Hijack the client connection so we can write the response ourselves, and start
// copying the websocket frames.
var cconn net.Conn
cconn, cbr, err = hj.Hijack()
if err != nil {
log().Debugx("cannot turn http transaction into websocket connection", err)
http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
lw.error(err)
return
}
defer func() {
if cconn != nil {
cconn.Close()
}
}()
// Below this point, we can no longer write to the ResponseWriter.
// Mark as websocket response, for logging.
lw.WebsocketResponse = true
lw.setStatusCode(fresp.StatusCode)
for k, v := range h.ResponseHeaders {
fresp.Header.Add(k, v)
}
// Write the response to the client, completing its websocket handshake.
if err := fresp.Write(cconn); err != nil {
lw.error(fmt.Errorf("writing websocket response to client: %w", err))
return
}
errc := make(chan error, 1)
// Copy from client to backend.
go func() {
buf, err := cbr.Peek(cbr.Reader.Buffered())
if err != nil {
errc <- err
return
}
if len(buf) > 0 {
n, err := beconn.Write(buf)
if err != nil {
errc <- err
return
}
lw.SizeFromClient += int64(n)
}
n, err := io.Copy(beconn, cconn)
lw.SizeFromClient += n
errc <- err
}()
// Copy from backend to client.
go func() {
n, err := io.Copy(cconn, beconn)
lw.SizeToClient = n
errc <- err
}()
// Stop and close connection on first error from either size, typically a closed
// connection whose closing was already announced with a websocket frame.
lw.error(<-errc)
// Close connections so other goroutine stops as well.
cconn.Close()
beconn.Close()
cconn = nil
// Wait for goroutine so it has updated the logWriter.Size*Client fields before we
// continue with logging.
<-errc
return true
}
func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
log := func() *mlog.Log {
return xlog.WithContext(r.Context())
}
// Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
// unmodified.
transport := http.DefaultTransport.(*http.Transport)
// We haven't implemented using a proxy for websocket requests yet. If we need one,
// return an error instead of trying to connect directly, which would be a
// potential security issue.
treq := *r
treq.URL = targetURL
if purl, err := transport.Proxy(&treq); err != nil {
return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
} else if purl != nil {
return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
}
host, port, err := net.SplitHostPort(targetURL.Host)
if err != nil {
host = targetURL.Host
if targetURL.Scheme == "https" {
port = "443"
} else {
port = "80"
}
}
addr := net.JoinHostPort(host, port)
conn, err := transport.DialContext(r.Context(), "tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("dial: %w", err)
}
if targetURL.Scheme == "https" {
tlsconn := tls.Client(conn, transport.TLSClientConfig)
ctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
defer cancel()
if err := tlsconn.HandshakeContext(ctx); err != nil {
return nil, nil, fmt.Errorf("tls handshake: %w", err)
}
conn = tlsconn
}
defer func() {
if rerr != nil {
conn.Close()
}
}()
// todo: make timeout configurable?
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
log().Check(err, "set deadline for websocket request to backend")
}
// Set clean connection headers.
removeHopByHopHeaders(r.Header)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
// Write the websocket request to the backend.
if err := r.Write(conn); err != nil {
return nil, nil, fmt.Errorf("writing request to backend: %w", err)
}
// Read response from backend.
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, r)
if err != nil {
return nil, nil, fmt.Errorf("reading response from backend: %w", err)
}
defer func() {
if rerr != nil {
resp.Body.Close()
}
}()
if err := conn.SetDeadline(time.Time{}); err != nil {
log().Check(err, "clearing deadline on websocket connection to backend")
}
// Check that the response from the backend server indicates it is websocket. If
// not, don't pass the backend response, but an error that websocket is not
// appropriate.
if err := checkWebsocketResponse(resp, r); err != nil {
return resp, nil, err
}
// note: net/http.Response.Body documents that it implements io.Writer for a
// status: 101 response. But that's not the case when the response has been read
// with http.ReadResponse. We'll write to the connection directly.
buf, err := br.Peek(br.Buffered())
if err != nil {
return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
}
return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
}
// A net.Conn but with reads coming from an io multireader (due to buffered reader
// needed for http.ReadResponse).
type websocketConn struct {
r io.Reader
net.Conn
}
func (c websocketConn) Read(buf []byte) (int, error) {
return c.r.Read(buf)
}
// Check that an HTTP response (from a backend) is a valid websocket response, i.e.
// that it accepts the WebSocket "upgrade".
// ../rfc/6455:1299
func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
if resp.StatusCode != 101 {
return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
}
if upgrade := resp.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") {
return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
}
if connection := resp.Header.Get("Connection"); !strings.EqualFold(connection, "upgrade") {
return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
}
accept, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
if err != nil {
return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
}
exp := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
if !bytes.Equal(accept, exp[:]) {
return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
}
// We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
return nil
}
// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
// ../rfc/2616:5128
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding",
"Upgrade",
}
// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
// removeHopByHopHeaders removes hop-by-hop headers.
func removeHopByHopHeaders(h http.Header) {
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
// ../rfc/7230:2817
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
// This behavior is superseded by the RFC 7230 Connection header, but
// preserve it for backwards compatibility.
// ../rfc/2616:5128
for _, f := range hopHeaders {
h.Del(f)
}
}

View file

@ -2,6 +2,9 @@ package http
import ( import (
"bytes" "bytes"
"fmt"
"io"
"net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -10,6 +13,8 @@ import (
"strings" "strings"
"testing" "testing"
"golang.org/x/net/websocket"
"github.com/mjl-/mox/mox-" "github.com/mjl-/mox/mox-"
) )
@ -119,3 +124,184 @@ func TestWebserver(t *testing.T) {
test("GET", "http://mox.example/bogus", nil, http.StatusNotFound, "", nil) // path not registered. test("GET", "http://mox.example/bogus", nil, http.StatusNotFound, "", nil) // path not registered.
test("GET", "http://bogus.mox.example/static/", nil, http.StatusNotFound, "", nil) // domain not registered. test("GET", "http://bogus.mox.example/static/", nil, http.StatusNotFound, "", nil) // domain not registered.
} }
func TestWebsocket(t *testing.T) {
os.RemoveAll("../testdata/websocket/data")
mox.ConfigStaticPath = "../testdata/websocket/mox.conf"
mox.ConfigDynamicPath = filepath.Join(filepath.Dir(mox.ConfigStaticPath), "domains.conf")
mox.MustLoadConfig(false)
srv := &serve{Webserver: true}
var handler http.Handler // Active handler during test.
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r)
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatalf("parsing backend url: %v", err)
}
backendURL.Path = "/"
// warning: it is not normally allowed to access the dynamic config without lock. don't propagate accesses like this!
mox.Conf.Dynamic.WebHandlers[len(mox.Conf.Dynamic.WebHandlers)-1].WebForward.TargetURL = backendURL
server := httptest.NewServer(srv)
defer server.Close()
serverURL, err := url.Parse(server.URL)
tcheck(t, err, "parsing server url")
_, port, err := net.SplitHostPort(serverURL.Host)
tcheck(t, err, "parsing host port in server url")
wsurl := fmt.Sprintf("ws://%s/ws/", net.JoinHostPort("localhost", port))
handler = websocket.Handler(func(c *websocket.Conn) {
io.Copy(c, c)
})
// Test a correct websocket connection.
wsconn, err := websocket.Dial(wsurl, "ignored", "http://ignored.example")
tcheck(t, err, "websocket dial")
_, err = fmt.Fprint(wsconn, "test")
tcheck(t, err, "write to websocket")
buf := make([]byte, 128)
n, err := wsconn.Read(buf)
tcheck(t, err, "read from websocket")
if string(buf[:n]) != "test" {
t.Fatalf(`got websocket data %q, expected "test"`, buf[:n])
}
err = wsconn.Close()
tcheck(t, err, "closing websocket connection")
// Test with server.ServeHTTP directly.
test := func(method string, reqhdrs map[string]string, expCode int, expHeaders map[string]string) {
t.Helper()
req := httptest.NewRequest(method, wsurl, nil)
for k, v := range reqhdrs {
req.Header.Add(k, v)
}
rw := httptest.NewRecorder()
rw.Body = &bytes.Buffer{}
srv.ServeHTTP(rw, req)
resp := rw.Result()
if resp.StatusCode != expCode {
t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
}
for k, v := range expHeaders {
if xv := resp.Header.Get(k); xv != v {
t.Fatalf("got %q for header %q, expected %q", xv, k, v)
}
}
}
wsreqhdrs := map[string]string{
"Upgrade": "keep-alive, websocket",
"Connection": "X, Upgrade",
"Sec-Websocket-Version": "13",
"Sec-Websocket-Key": "AAAAAAAAAAAAAAAAAAAAAA==",
}
test("POST", wsreqhdrs, http.StatusBadRequest, nil)
clone := func(m map[string]string) map[string]string {
r := map[string]string{}
for k, v := range m {
r[k] = v
}
return r
}
hdrs := clone(wsreqhdrs)
hdrs["Sec-Websocket-Version"] = "14"
test("GET", hdrs, http.StatusBadRequest, map[string]string{"Sec-Websocket-Version": "13"})
httpurl := fmt.Sprintf("http://%s/ws/", net.JoinHostPort("localhost", port))
// Must now do actual HTTP requests and read the HTTP response. Cannot call
// ServeHTTP because ResponseRecorder is not a http.Hijacker.
test = func(method string, reqhdrs map[string]string, expCode int, expHeaders map[string]string) {
t.Helper()
req, err := http.NewRequest(method, httpurl, nil)
for k, v := range reqhdrs {
req.Header.Add(k, v)
}
resp, err := http.DefaultClient.Do(req)
tcheck(t, err, "http transaction")
if resp.StatusCode != expCode {
t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
}
for k, v := range expHeaders {
if xv := resp.Header.Get(k); xv != v {
t.Fatalf("got %q for header %q, expected %q", xv, k, v)
}
}
}
hdrs = clone(wsreqhdrs)
hdrs["Sec-Websocket-Key"] = "malformed"
test("GET", hdrs, http.StatusBadRequest, nil)
hdrs = clone(wsreqhdrs)
hdrs["Sec-Websocket-Key"] = "c2hvcnQK" // "short"
test("GET", hdrs, http.StatusBadRequest, nil)
// Not responding with a 101, but with regular 200 OK response.
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad", http.StatusOK)
})
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
// Respond with 101, but other websocket response headers missing.
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusSwitchingProtocols)
})
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
// With Upgrade: websocket, without Connection: Upgrade
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Upgrade", "websocket")
w.WriteHeader(http.StatusSwitchingProtocols)
})
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
// With malformed Sec-WebSocket-Accept response header.
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Set("Upgrade", "websocket")
h.Set("Connection", "Upgrade")
h.Set("Sec-WebSocket-Accept", "malformed")
w.WriteHeader(http.StatusSwitchingProtocols)
})
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
// With malformed Sec-WebSocket-Accept response header.
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Set("Upgrade", "websocket")
h.Set("Connection", "Upgrade")
h.Set("Sec-WebSocket-Accept", "YmFk") // "bad"
w.WriteHeader(http.StatusSwitchingProtocols)
})
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
// All good.
wsresphdrs := map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Accept": "ICX+Yqv66kxgM0FcWaLWlFLwTAI=",
"X-Test": "mox",
}
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Set("Upgrade", "websocket")
h.Set("Connection", "Upgrade")
h.Set("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=")
w.WriteHeader(http.StatusSwitchingProtocols)
})
test("GET", wsreqhdrs, http.StatusSwitchingProtocols, wsresphdrs)
}

View file

@ -297,6 +297,16 @@ and many more, see http://sieve.info/documents
9157 Revised IANA Considerations for DNSSEC 9157 Revised IANA Considerations for DNSSEC
9276 Guidance for NSEC3 Parameter Settings 9276 Guidance for NSEC3 Parameter Settings
# HTTP
2616 Hypertext Transfer Protocol -- HTTP/1.1
7230 Hypertext Transfer Protocol (HTTP/1.1): Message Syntax and Routing
9110 HTTP Semantics
# Websockets
6455 The WebSocket Protocol
# More # More
3339 Date and Time on the Internet: Timestamps 3339 Date and Time on the Internet: Timestamps

19
testdata/websocket/domains.conf vendored Normal file
View file

@ -0,0 +1,19 @@
Domains:
mox.example:
LocalpartCaseSensitive: false
Accounts:
mjl:
Domain: mox.example
Destinations:
mjl@mox.example: nil
WebHandlers:
-
LogName: websocket
Domain: localhost
PathRegexp: ^/ws/
DontRedirectPlainHTTP: true
WebForward:
# replaced while testing
URL: http://127.0.0.1:1/
ResponseHeaders:
X-Test: mox

13
testdata/websocket/mox.conf vendored Normal file
View file

@ -0,0 +1,13 @@
DataDir: data
User: 1000
LogLevel: trace
Hostname: mox.example
Listeners:
local:
IPs:
- 0.0.0.0
WebserverHTTP:
Enabled: true
Postmaster:
Account: mjl
Mailbox: postmaster

106
vendor/golang.org/x/net/websocket/client.go generated vendored Normal file
View file

@ -0,0 +1,106 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"io"
"net"
"net/http"
"net/url"
)
// DialError is an error that occurs while dialling a websocket server.
type DialError struct {
*Config
Err error
}
func (e *DialError) Error() string {
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
}
// NewConfig creates a new WebSocket config for client connection.
func NewConfig(server, origin string) (config *Config, err error) {
config = new(Config)
config.Version = ProtocolVersionHybi13
config.Location, err = url.ParseRequestURI(server)
if err != nil {
return
}
config.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
// NewClient creates a new WebSocket client connection over rwc.
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
err = hybiClientHandshake(config, br, bw)
if err != nil {
return
}
buf := bufio.NewReadWriter(br, bw)
ws = newHybiClientConn(config, buf, rwc)
return
}
// Dial opens a new client connection to a WebSocket.
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
config, err := NewConfig(url_, origin)
if err != nil {
return nil, err
}
if protocol != "" {
config.Protocol = []string{protocol}
}
return DialConfig(config)
}
var portMap = map[string]string{
"ws": "80",
"wss": "443",
}
func parseAuthority(location *url.URL) string {
if _, ok := portMap[location.Scheme]; ok {
if _, _, err := net.SplitHostPort(location.Host); err != nil {
return net.JoinHostPort(location.Host, portMap[location.Scheme])
}
}
return location.Host
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
dialer := config.Dialer
if dialer == nil {
dialer = &net.Dialer{}
}
client, err = dialWithDialer(dialer, config)
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
client.Close()
goto Error
}
return
Error:
return nil, &DialError{config, err}
}

24
vendor/golang.org/x/net/websocket/dial.go generated vendored Normal file
View file

@ -0,0 +1,24 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/tls"
"net"
)
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
switch config.Location.Scheme {
case "ws":
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
case "wss":
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
}
return
}

583
vendor/golang.org/x/net/websocket/hybi.go generated vendored Normal file
View file

@ -0,0 +1,583 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
// This file implements a protocol of hybi draft.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
const (
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
closeStatusNormal = 1000
closeStatusGoingAway = 1001
closeStatusProtocolError = 1002
closeStatusUnsupportedData = 1003
closeStatusFrameTooLarge = 1004
closeStatusNoStatusRcvd = 1005
closeStatusAbnormalClosure = 1006
closeStatusBadMessageData = 1007
closeStatusPolicyViolation = 1008
closeStatusTooBigData = 1009
closeStatusExtensionMismatch = 1010
maxControlFramePayloadLength = 125
)
var (
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
ErrBadPongMessage = &ProtocolError{"bad pong message"}
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
handshakeHeader = map[string]bool{
"Host": true,
"Upgrade": true,
"Connection": true,
"Sec-Websocket-Key": true,
"Sec-Websocket-Origin": true,
"Sec-Websocket-Version": true,
"Sec-Websocket-Protocol": true,
"Sec-Websocket-Accept": true,
}
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
type hybiFrameHeader struct {
Fin bool
Rsv [3]bool
OpCode byte
Length int64
MaskingKey []byte
data *bytes.Buffer
}
// A hybiFrameReader is a reader for hybi frame.
type hybiFrameReader struct {
reader io.Reader
header hybiFrameHeader
pos int64
length int
}
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
n, err = frame.reader.Read(msg)
if frame.header.MaskingKey != nil {
for i := 0; i < n; i++ {
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
frame.pos++
}
}
return n, err
}
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
func (frame *hybiFrameReader) HeaderReader() io.Reader {
if frame.header.data == nil {
return nil
}
if frame.header.data.Len() == 0 {
return nil
}
return frame.header.data
}
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
type hybiFrameReaderFactory struct {
*bufio.Reader
}
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
// See Section 5.2 Base Framing protocol for detail.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
hybiFrame := new(hybiFrameReader)
frame = hybiFrame
var header []byte
var b byte
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
for i := 0; i < 3; i++ {
j := uint(6 - i)
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
}
hybiFrame.header.OpCode = header[0] & 0x0f
// Second byte. Mask/Payload len(7bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
mask := (b & 0x80) != 0
b &= 0x7f
lengthFields := 0
switch {
case b <= 125: // Payload length 7bits.
hybiFrame.header.Length = int64(b)
case b == 126: // Payload length 7+16bits
lengthFields = 2
case b == 127: // Payload length 7+64bits
lengthFields = 8
}
for i := 0; i < lengthFields; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
b &= 0x7f
}
header = append(header, b)
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
}
if mask {
// Masking key. 4 bytes.
for i := 0; i < 4; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
}
}
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
hybiFrame.header.data = bytes.NewBuffer(header)
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
return
}
// A HybiFrameWriter is a writer for hybi frame.
type hybiFrameWriter struct {
writer *bufio.Writer
header *hybiFrameHeader
}
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
var header []byte
var b byte
if frame.header.Fin {
b |= 0x80
}
for i := 0; i < 3; i++ {
if frame.header.Rsv[i] {
j := uint(6 - i)
b |= 1 << j
}
}
b |= frame.header.OpCode
header = append(header, b)
if frame.header.MaskingKey != nil {
b = 0x80
} else {
b = 0
}
lengthFields := 0
length := len(msg)
switch {
case length <= 125:
b |= byte(length)
case length < 65536:
b |= 126
lengthFields = 2
default:
b |= 127
lengthFields = 8
}
header = append(header, b)
for i := 0; i < lengthFields; i++ {
j := uint((lengthFields - i - 1) * 8)
b = byte((length >> j) & 0xff)
header = append(header, b)
}
if frame.header.MaskingKey != nil {
if len(frame.header.MaskingKey) != 4 {
return 0, ErrBadMaskingKey
}
header = append(header, frame.header.MaskingKey...)
frame.writer.Write(header)
data := make([]byte, length)
for i := range data {
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
}
frame.writer.Write(data)
err = frame.writer.Flush()
return length, err
}
frame.writer.Write(header)
frame.writer.Write(msg)
err = frame.writer.Flush()
return length, err
}
func (frame *hybiFrameWriter) Close() error { return nil }
type hybiFrameWriterFactory struct {
*bufio.Writer
needMaskingKey bool
}
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
if buf.needMaskingKey {
frameHeader.MaskingKey, err = generateMaskingKey()
if err != nil {
return nil, err
}
}
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
}
type hybiFrameHandler struct {
conn *Conn
payloadType byte
}
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
if handler.conn.IsServerConn() {
// The client MUST mask all frames sent to the server.
if frame.(*hybiFrameReader).header.MaskingKey == nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
} else {
// The server MUST NOT mask all frames.
if frame.(*hybiFrameReader).header.MaskingKey != nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
}
if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
case TextFrame, BinaryFrame:
handler.payloadType = frame.PayloadType()
case CloseFrame:
return nil, io.EOF
case PingFrame, PongFrame:
b := make([]byte, maxControlFramePayloadLength)
n, err := io.ReadFull(frame, b)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
if frame.PayloadType() == PingFrame {
if _, err := handler.WritePong(b[:n]); err != nil {
return nil, err
}
}
return nil, nil
}
return frame, nil
}
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
if err != nil {
return err
}
msg := make([]byte, 2)
binary.BigEndian.PutUint16(msg, uint16(status))
_, err = w.Write(msg)
w.Close()
return err
}
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
if buf == nil {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
frameWriterFactory: hybiFrameWriterFactory{
buf.Writer, request == nil},
PayloadType: TextFrame,
defaultCloseStatus: closeStatusNormal}
ws.frameHandler = &hybiFrameHandler{conn: ws}
return ws
}
// generateMaskingKey generates a masking key for a frame.
func generateMaskingKey() (maskingKey []byte, err error) {
maskingKey = make([]byte, 4)
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
return
}
return
}
// generateNonce generates a nonce consisting of a randomly selected 16-byte
// value that has been base64-encoded.
func generateNonce() (nonce []byte) {
key := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
}
nonce = make([]byte, 24)
base64.StdEncoding.Encode(nonce, key)
return
}
// removeZone removes IPv6 zone identifier from host.
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
func removeZone(host string) string {
if !strings.HasPrefix(host, "[") {
return host
}
i := strings.LastIndex(host, "]")
if i < 0 {
return host
}
j := strings.LastIndex(host[:i], "%")
if j < 0 {
return host
}
return host[:j] + host[i:]
}
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
func getNonceAccept(nonce []byte) (expected []byte, err error) {
h := sha1.New()
if _, err = h.Write(nonce); err != nil {
return
}
if _, err = h.Write([]byte(websocketGUID)); err != nil {
return
}
expected = make([]byte, 28)
base64.StdEncoding.Encode(expected, h.Sum(nil))
return
}
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
// According to RFC 6874, an HTTP client, proxy, or other
// intermediary must remove any IPv6 zone identifier attached
// to an outgoing URI.
bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
bw.WriteString("Upgrade: websocket\r\n")
bw.WriteString("Connection: Upgrade\r\n")
nonce := generateNonce()
if config.handshakeData != nil {
nonce = []byte(config.handshakeData["key"])
}
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
if config.Version != ProtocolVersionHybi13 {
return ErrBadProtocolVersion
}
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
err = config.Header.WriteSubset(bw, handshakeHeader)
if err != nil {
return err
}
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
return err
}
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
if err != nil {
return err
}
if resp.StatusCode != 101 {
return ErrBadStatus
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade
}
expectedAccept, err := getNonceAccept(nonce)
if err != nil {
return err
}
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
return ErrChallengeResponse
}
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
return ErrUnsupportedExtensions
}
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
if offeredProtocol != "" {
protocolMatched := false
for i := 0; i < len(config.Protocol); i++ {
if config.Protocol[i] == offeredProtocol {
protocolMatched = true
break
}
}
if !protocolMatched {
return ErrBadWebSocketProtocol
}
config.Protocol = []string{offeredProtocol}
}
return nil
}
// newHybiClientConn creates a client WebSocket connection after handshake.
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
return newHybiConn(config, buf, rwc, nil)
}
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
type hybiServerHandshaker struct {
*Config
accept []byte
}
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
c.Version = ProtocolVersionHybi13
if req.Method != "GET" {
return http.StatusMethodNotAllowed, ErrBadRequestMethod
}
// HTTP version can be safely ignored.
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
return http.StatusBadRequest, ErrNotWebSocket
}
key := req.Header.Get("Sec-Websocket-Key")
if key == "" {
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
switch version {
case "13":
c.Version = ProtocolVersionHybi13
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
var scheme string
if req.TLS != nil {
scheme = "wss"
} else {
scheme = "ws"
}
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
if err != nil {
return http.StatusBadRequest, err
}
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if protocol != "" {
protocols := strings.Split(protocol, ",")
for i := 0; i < len(protocols); i++ {
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
}
}
c.accept, err = getNonceAccept([]byte(key))
if err != nil {
return http.StatusInternalServerError, err
}
return http.StatusSwitchingProtocols, nil
}
// Origin parses the Origin header in req.
// If the Origin header is not set, it returns nil and nil.
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
}
if origin == "" {
return nil, nil
}
return url.ParseRequestURI(origin)
}
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
// You need choose a Protocol in Handshake func in Server.
return ErrBadWebSocketProtocol
}
}
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
if c.Header != nil {
err := c.Header.WriteSubset(buf, handshakeHeader)
if err != nil {
return err
}
}
buf.WriteString("\r\n")
return buf.Flush()
}
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiServerConn(c.Config, buf, rwc, request)
}
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiConn(config, buf, rwc, request)
}

113
vendor/golang.org/x/net/websocket/server.go generated vendored Normal file
View file

@ -0,0 +1,113 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"fmt"
"io"
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if err != nil {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
conn = hs.NewServerConn(buf, rwc, req)
return
}
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser clients, which do not send an Origin header, set a
// Server.Handshake that does not check the origin.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}

449
vendor/golang.org/x/net/websocket/websocket.go generated vendored Normal file
View file

@ -0,0 +1,449 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements a client and server for the WebSocket protocol
// as specified in RFC 6455.
//
// This package currently lacks some features found in an alternative
// and more actively maintained WebSocket package:
//
// https://pkg.go.dev/nhooyr.io/websocket
package websocket // import "golang.org/x/net/websocket"
import (
"bufio"
"crypto/tls"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
const (
ProtocolVersionHybi13 = 13
ProtocolVersionHybi = ProtocolVersionHybi13
SupportedProtocolVersion = "13"
ContinuationFrame = 0
TextFrame = 1
BinaryFrame = 2
CloseFrame = 8
PingFrame = 9
PongFrame = 10
UnknownFrame = 255
DefaultMaxPayloadBytes = 32 << 20 // 32MB
)
// ProtocolError represents WebSocket protocol errors.
type ProtocolError struct {
ErrorString string
}
func (err *ProtocolError) Error() string { return err.ErrorString }
var (
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
ErrBadScheme = &ProtocolError{"bad scheme"}
ErrBadStatus = &ProtocolError{"bad status"}
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
ErrBadFrame = &ProtocolError{"bad frame"}
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
ErrBadRequestMethod = &ProtocolError{"bad method"}
ErrNotSupported = &ProtocolError{"not supported"}
)
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
// exceeds limit set by Conn.MaxPayloadBytes
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
}
// Network returns the network type for a WebSocket, "websocket".
func (addr *Addr) Network() string { return "websocket" }
// Config is a WebSocket configuration
type Config struct {
// A WebSocket server address.
Location *url.URL
// A Websocket client origin.
Origin *url.URL
// WebSocket subprotocols.
Protocol []string
// WebSocket protocol version.
Version int
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
// Dialer used when opening websocket connections.
Dialer *net.Dialer
handshakeData map[string]string
}
// serverHandshaker is an interface to handle WebSocket server side handshake.
type serverHandshaker interface {
// ReadHandshake reads handshake request message from client.
// Returns http response code and error if any.
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
// AcceptHandshake accepts the client handshake request and sends
// handshake response back to client.
AcceptHandshake(buf *bufio.Writer) (err error)
// NewServerConn creates a new WebSocket connection.
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
// frameReader is an interface to read a WebSocket frame.
type frameReader interface {
// Reader is to read payload of the frame.
io.Reader
// PayloadType returns payload type.
PayloadType() byte
// HeaderReader returns a reader to read header of the frame.
HeaderReader() io.Reader
// TrailerReader returns a reader to read trailer of the frame.
// If it returns nil, there is no trailer in the frame.
TrailerReader() io.Reader
// Len returns total length of the frame, including header and trailer.
Len() int
}
// frameReaderFactory is an interface to creates new frame reader.
type frameReaderFactory interface {
NewFrameReader() (r frameReader, err error)
}
// frameWriter is an interface to write a WebSocket frame.
type frameWriter interface {
// Writer is to write payload of the frame.
io.WriteCloser
}
// frameWriterFactory is an interface to create new frame writer.
type frameWriterFactory interface {
NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
type frameHandler interface {
HandleFrame(frame frameReader) (r frameReader, err error)
WriteClose(status int) (err error)
}
// Conn represents a WebSocket connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
config *Config
request *http.Request
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
rio sync.Mutex
frameReaderFactory
frameReader
wio sync.Mutex
frameWriterFactory
frameHandler
PayloadType byte
defaultCloseStatus int
// MaxPayloadBytes limits the size of frame payload received over Conn
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
MaxPayloadBytes int
}
// Read implements the io.Reader interface:
// it reads data of a frame from the WebSocket connection.
// if msg is not large enough for the frame data, it fills the msg and next Read
// will read the rest of the frame data.
// it reads Text frame or Binary frame.
func (ws *Conn) Read(msg []byte) (n int, err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
again:
if ws.frameReader == nil {
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return 0, err
}
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return 0, err
}
if ws.frameReader == nil {
goto again
}
}
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer)
}
ws.frameReader = nil
goto again
}
return n, err
}
// Write implements the io.Writer interface:
// it writes data as a frame to the WebSocket connection.
func (ws *Conn) Write(msg []byte) (n int, err error) {
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// Close implements the io.Closer interface.
func (ws *Conn) Close() error {
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
err1 := ws.rwc.Close()
if err != nil {
return err
}
return err1
}
// IsClientConn reports whether ws is a client-side connection.
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
// IsServerConn reports whether ws is a server-side connection.
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
// LocalAddr returns the WebSocket Origin for the connection for client, or
// the WebSocket location for server.
func (ws *Conn) LocalAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Origin}
}
return &Addr{ws.config.Location}
}
// RemoteAddr returns the WebSocket location for the connection for client, or
// the Websocket Origin for server.
func (ws *Conn) RemoteAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Location}
}
return &Addr{ws.config.Origin}
}
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
// SetDeadline sets the connection's network read & write deadlines.
func (ws *Conn) SetDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetDeadline(t)
}
return errSetDeadline
}
// SetReadDeadline sets the connection's network read deadline.
func (ws *Conn) SetReadDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetReadDeadline(t)
}
return errSetDeadline
}
// SetWriteDeadline sets the connection's network write deadline.
func (ws *Conn) SetWriteDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetWriteDeadline(t)
}
return errSetDeadline
}
// Config returns the WebSocket config.
func (ws *Conn) Config() *Config { return ws.config }
// Request returns the http request upgraded to the WebSocket.
// It is nil for client side.
func (ws *Conn) Request() *http.Request { return ws.request }
// Codec represents a symmetric pair of functions that implement a codec.
type Codec struct {
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
// Send sends v marshaled by cd.Marshal as single frame to ws.
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
data, payloadType, err := cd.Marshal(v)
if err != nil {
return err
}
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
if err != nil {
return err
}
_, err = w.Write(data)
w.Close()
return err
}
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
// in v. The whole frame payload is read to an in-memory buffer; max size of
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
// completely. The next call to Receive would read and discard leftover data of
// previous oversized frame before processing next frame.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader)
if err != nil {
return err
}
ws.frameReader = nil
}
again:
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return err
}
frame, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return err
}
if frame == nil {
goto again
}
maxPayloadBytes := ws.MaxPayloadBytes
if maxPayloadBytes == 0 {
maxPayloadBytes = DefaultMaxPayloadBytes
}
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
// payload size exceeds limit, no need to call Unmarshal
//
// set frameReader to current oversized frame so that
// the next call to this function can drain leftover
// data before processing the next frame
ws.frameReader = frame
return ErrFrameTooLarge
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
if err != nil {
return err
}
return cd.Unmarshal(data, payloadType, v)
}
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
switch data := v.(type) {
case string:
return []byte(data), TextFrame, nil
case []byte:
return data, BinaryFrame, nil
}
return nil, UnknownFrame, ErrNotSupported
}
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
switch data := v.(type) {
case *string:
*data = string(msg)
return nil
case *[]byte:
*data = msg
return nil
}
return ErrNotSupported
}
/*
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
To send/receive text frame, use string type.
To send/receive binary frame, use []byte type.
Trivial usage:
import "websocket"
// receive text frame
var message string
websocket.Message.Receive(ws, &message)
// send text frame
message = "hello"
websocket.Message.Send(ws, message)
// receive binary frame
var data []byte
websocket.Message.Receive(ws, &data)
// send binary frame
data = []byte{0, 1, 2}
websocket.Message.Send(ws, data)
*/
var Message = Codec{marshal, unmarshal}
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
msg, err = json.Marshal(v)
return msg, TextFrame, err
}
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
return json.Unmarshal(msg, v)
}
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Trivial usage:
import "websocket"
type T struct {
Msg string
Count int
}
// receive JSON type T
var data T
websocket.JSON.Receive(ws, &data)
// send JSON type T
websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}

1
vendor/modules.txt vendored
View file

@ -71,6 +71,7 @@ golang.org/x/mod/semver
golang.org/x/net/html golang.org/x/net/html
golang.org/x/net/html/atom golang.org/x/net/html/atom
golang.org/x/net/idna golang.org/x/net/idna
golang.org/x/net/websocket
# golang.org/x/sys v0.7.0 # golang.org/x/sys v0.7.0
## explicit; go 1.17 ## explicit; go 1.17
golang.org/x/sys/cpu golang.org/x/sys/cpu