mirror of
https://github.com/mjl-/mox.git
synced 2024-12-25 16:03:48 +03:00
add reverse proxying websocket connections
if we recognize that a request for a WebForward is trying to turn the connection into a websocket, we forward it to the backend and check if the backend understands the websocket request. if so, we pass back the upgrade response and get out of the way, copying bytes between the two. we do log the total amount of bytes read from the client and written to the client. if the backend doesn't respond with a websocke response, or an invalid one, we respond with a regular non-websocket response. and we log details about the failed connection, should help with debugging and any bug reports. we don't try to parse the websocket framing, that's between the client and the backend. we could try to parse it, in part to protect the backend from bad frames, but it would be a lot of work and could be brittle in the face of extensions. this doesn't yet handle websocket connections when a http proxy is configured. we'll implement it when someone needs it. we do recognize it and fail the connection. for issue #25
This commit is contained in:
parent
aca64828bd
commit
259928ab62
15 changed files with 1966 additions and 49 deletions
|
@ -393,7 +393,7 @@ func (wr WebRedirect) equal(o WebRedirect) bool {
|
|||
|
||||
type WebForward struct {
|
||||
StripPath bool `sconf:"optional" sconf-doc:"Strip the matching WebHandler path from the WebHandler before forwarding the request."`
|
||||
URL string `sconf-doc:"URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account."`
|
||||
URL string `sconf-doc:"URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account. Websocket connections are forwarded and data is copied between client and backend without looking at the framing. The websocket 'version' and 'key'/'accept' headers are verified during the handshake, but other websocket headers, including 'origin', 'protocol' and 'extensions' headers, are not inspected and the backend is responsible for verifying/interpreting them."`
|
||||
ResponseHeaders map[string]string `sconf:"optional" sconf-doc:"Headers to add to the response. Useful for adding security- and cache-related headers."`
|
||||
|
||||
TargetURL *url.URL `sconf:"-" json:"-"`
|
||||
|
|
|
@ -715,6 +715,11 @@ describe-static" and "mox config describe-domains":
|
|||
# unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string
|
||||
# in the URL is ignored. Requests are made using Go's net/http.DefaultTransport
|
||||
# that takes environment variables HTTP_PROXY and HTTPS_PROXY into account.
|
||||
# Websocket connections are forwarded and data is copied between client and
|
||||
# backend without looking at the framing. The websocket 'version' and
|
||||
# 'key'/'accept' headers are verified during the handshake, but other websocket
|
||||
# headers, including 'origin', 'protocol' and 'extensions' headers, are not
|
||||
# inspected and the backend is responsible for verifying/interpreting them.
|
||||
URL:
|
||||
|
||||
# Headers to add to the response. Useful for adding security- and cache-related
|
||||
|
|
|
@ -1855,7 +1855,7 @@ const webserver = async () => {
|
|||
),
|
||||
dom.td(
|
||||
'URL',
|
||||
attr({title: "URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account."}),
|
||||
attr({title: "URL to forward HTTP requests to, e.g. http://127.0.0.1:8123/base. If StripPath is false the full request path is added to the URL. Host headers are sent unmodified. New X-Forwarded-{For,Host,Proto} headers are set. Any query string in the URL is ignored. Requests are made using Go's net/http.DefaultTransport that takes environment variables HTTP_PROXY and HTTPS_PROXY into account. Websocket connections are forwarded and data is copied between client and backend without looking at the framing. The websocket 'version' and 'key'/'accept' headers are verified during the handshake, but other websocket headers, including 'origin', 'protocol' and 'extensions' headers, are not inspected and the backend is responsible for verifying/interpreting them."}),
|
||||
),
|
||||
dom.td(
|
||||
dom.span(
|
||||
|
|
77
http/web.go
77
http/web.go
|
@ -42,7 +42,7 @@ var (
|
|||
},
|
||||
[]string{
|
||||
"handler", // Name from webhandler, can be empty.
|
||||
"proto", // "http" or "https"
|
||||
"proto", // "http", "https", "ws", "wss"
|
||||
"method", // "(unknown)" and otherwise only common verbs
|
||||
"code",
|
||||
},
|
||||
|
@ -58,7 +58,7 @@ var (
|
|||
},
|
||||
[]string{
|
||||
"handler", // Name from webhandler, can be empty.
|
||||
"proto", // "http" or "https"
|
||||
"proto", // "http", "https", "ws", "wss"
|
||||
"method", // "(unknown)" and otherwise only common verbs
|
||||
"code",
|
||||
},
|
||||
|
@ -69,22 +69,37 @@ var (
|
|||
|
||||
// http.ResponseWriter that writes access log and tracks metrics at end of response.
|
||||
type loggingWriter struct {
|
||||
W http.ResponseWriter // Calls are forwarded.
|
||||
Start time.Time
|
||||
R *http.Request
|
||||
W http.ResponseWriter // Calls are forwarded.
|
||||
Start time.Time
|
||||
R *http.Request
|
||||
WebsocketRequest bool // Whether request from was websocket.
|
||||
|
||||
Handler string // Set by router.
|
||||
|
||||
// Set by handlers.
|
||||
StatusCode int
|
||||
Size int64
|
||||
WriteErr error
|
||||
StatusCode int
|
||||
Size int64 // Of data served, for non-websocket responses.
|
||||
Err error
|
||||
WebsocketResponse bool // If this was a successful websocket connection with backend.
|
||||
SizeFromClient, SizeToClient int64 // Websocket data.
|
||||
}
|
||||
|
||||
func (w *loggingWriter) Header() http.Header {
|
||||
return w.W.Header()
|
||||
}
|
||||
|
||||
// protocol, for logging.
|
||||
func (w *loggingWriter) proto(websocket bool) string {
|
||||
proto := "http"
|
||||
if websocket {
|
||||
proto = "ws"
|
||||
}
|
||||
if w.R.TLS != nil {
|
||||
proto += "s"
|
||||
}
|
||||
return proto
|
||||
}
|
||||
|
||||
func (w *loggingWriter) setStatusCode(statusCode int) {
|
||||
if w.StatusCode != 0 {
|
||||
return
|
||||
|
@ -92,11 +107,7 @@ func (w *loggingWriter) setStatusCode(statusCode int) {
|
|||
|
||||
w.StatusCode = statusCode
|
||||
method := metricHTTPMethod(w.R.Method)
|
||||
proto := "http"
|
||||
if w.R.TLS != nil {
|
||||
proto = "https"
|
||||
}
|
||||
metricRequest.WithLabelValues(w.Handler, proto, method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
|
||||
metricRequest.WithLabelValues(w.Handler, w.proto(w.WebsocketRequest), method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
|
||||
}
|
||||
|
||||
func (w *loggingWriter) Write(buf []byte) (int, error) {
|
||||
|
@ -108,8 +119,8 @@ func (w *loggingWriter) Write(buf []byte) (int, error) {
|
|||
if n > 0 {
|
||||
w.Size += int64(n)
|
||||
}
|
||||
if err != nil && w.WriteErr == nil {
|
||||
w.WriteErr = err
|
||||
if err != nil {
|
||||
w.error(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
@ -136,13 +147,15 @@ func metricHTTPMethod(method string) string {
|
|||
return "(other)"
|
||||
}
|
||||
|
||||
func (w *loggingWriter) error(err error) {
|
||||
if w.Err == nil {
|
||||
w.Err = err
|
||||
}
|
||||
}
|
||||
|
||||
func (w *loggingWriter) Done() {
|
||||
method := metricHTTPMethod(w.R.Method)
|
||||
proto := "http"
|
||||
if w.R.TLS != nil {
|
||||
proto = "https"
|
||||
}
|
||||
metricResponse.WithLabelValues(w.Handler, proto, method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
|
||||
metricResponse.WithLabelValues(w.Handler, w.proto(w.WebsocketResponse), method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
|
||||
|
||||
tlsinfo := "plain"
|
||||
if w.R.TLS != nil {
|
||||
|
@ -152,25 +165,41 @@ func (w *loggingWriter) Done() {
|
|||
tlsinfo = "(other)"
|
||||
}
|
||||
}
|
||||
err := w.WriteErr
|
||||
err := w.Err
|
||||
if err == nil {
|
||||
err = w.R.Context().Err()
|
||||
}
|
||||
xlog.WithContext(w.R.Context()).Debugx("http request", err,
|
||||
fields := []mlog.Pair{
|
||||
mlog.Field("httpaccess", ""),
|
||||
mlog.Field("handler", w.Handler),
|
||||
mlog.Field("method", method),
|
||||
mlog.Field("url", w.R.URL),
|
||||
mlog.Field("host", w.R.Host),
|
||||
mlog.Field("duration", time.Since(w.Start)),
|
||||
mlog.Field("size", w.Size),
|
||||
mlog.Field("statuscode", w.StatusCode),
|
||||
mlog.Field("proto", strings.ToLower(w.R.Proto)),
|
||||
mlog.Field("remoteaddr", w.R.RemoteAddr),
|
||||
mlog.Field("tlsinfo", tlsinfo),
|
||||
mlog.Field("useragent", w.R.Header.Get("User-Agent")),
|
||||
mlog.Field("referrr", w.R.Header.Get("Referrer")),
|
||||
)
|
||||
}
|
||||
if w.WebsocketRequest {
|
||||
fields = append(fields,
|
||||
mlog.Field("websocketrequest", true),
|
||||
)
|
||||
}
|
||||
if w.WebsocketResponse {
|
||||
fields = append(fields,
|
||||
mlog.Field("websocket", true),
|
||||
mlog.Field("sizetoclient", w.SizeToClient),
|
||||
mlog.Field("sizefromclient", w.SizeFromClient),
|
||||
)
|
||||
} else {
|
||||
fields = append(fields,
|
||||
mlog.Field("size", w.Size),
|
||||
)
|
||||
}
|
||||
xlog.WithContext(w.R.Context()).Debugx("http request", err, fields...)
|
||||
}
|
||||
|
||||
// Set some http headers that should prevent potential abuse. Better safe than sorry.
|
||||
|
|
|
@ -1,14 +1,22 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
htmltemplate "html/template"
|
||||
"io"
|
||||
golog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
@ -23,6 +31,14 @@ import (
|
|||
"github.com/mjl-/mox/moxio"
|
||||
)
|
||||
|
||||
func recvid(r *http.Request) string {
|
||||
cid := mox.CidFromCtx(r.Context())
|
||||
if cid <= 0 {
|
||||
return ""
|
||||
}
|
||||
return " (id " + mox.ReceivedID(cid) + ")"
|
||||
}
|
||||
|
||||
// WebHandle serves an HTTP request by going through the list of WebHandlers,
|
||||
// check if there is a domain+path match, and running the handler if so.
|
||||
// WebHandle runs after the built-in handlers for mta-sts, autoconfig, etc.
|
||||
|
@ -131,14 +147,6 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
|
|||
log := func() *mlog.Log {
|
||||
return xlog.WithContext(r.Context())
|
||||
}
|
||||
recvid := func() string {
|
||||
cid := mox.CidFromCtx(r.Context())
|
||||
if cid <= 0 {
|
||||
return ""
|
||||
}
|
||||
return " (id " + mox.ReceivedID(cid) + ")"
|
||||
}
|
||||
|
||||
if r.Method != "GET" && r.Method != "HEAD" {
|
||||
if h.ContinueNotFound {
|
||||
// Give another handler that is presumbly configured, for the same path, a chance.
|
||||
|
@ -194,7 +202,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
|
|||
ifi, err = index.Stat()
|
||||
if err != nil {
|
||||
log().Errorx("stat index.html in directory we cannot list", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
|
||||
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError)
|
||||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
@ -205,7 +213,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
|
|||
return true
|
||||
}
|
||||
log().Errorx("open file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
|
||||
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError)
|
||||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
defer f.Close()
|
||||
|
@ -213,7 +221,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
|
|||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
log().Errorx("stat file for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
|
||||
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError)
|
||||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
// Redirect if the local path is a directory.
|
||||
|
@ -251,7 +259,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
|
|||
}
|
||||
if !os.IsNotExist(err) {
|
||||
log().Errorx("stat for static file serving", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
|
||||
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError)
|
||||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -292,7 +300,7 @@ func HandleStatic(h *config.WebStatic, w http.ResponseWriter, r *http.Request) (
|
|||
break
|
||||
} else if err != nil {
|
||||
log().Errorx("reading directory for file listing", err, mlog.Field("url", r.URL), mlog.Field("fspath", fspath))
|
||||
http.Error(w, "500 - internal server error"+recvid(), http.StatusInternalServerError)
|
||||
http.Error(w, "500 - internal server error"+recvid(r), http.StatusInternalServerError)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@ -370,24 +378,22 @@ func HandleRedirect(h *config.WebRedirect, w http.ResponseWriter, r *http.Reques
|
|||
}
|
||||
|
||||
// HandleForward handles a request by forwarding it to another webserver and
|
||||
// passing the response on. I.e. a reverse proxy.
|
||||
// passing the response on. I.e. a reverse proxy. It handles websocket
|
||||
// connections by monitoring the websocket handshake and then just passing along the
|
||||
// websocket frames.
|
||||
func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
|
||||
log := func() *mlog.Log {
|
||||
return xlog.WithContext(r.Context())
|
||||
}
|
||||
recvid := func() string {
|
||||
cid := mox.CidFromCtx(r.Context())
|
||||
if cid <= 0 {
|
||||
return ""
|
||||
}
|
||||
return " (id " + mox.ReceivedID(cid) + ")"
|
||||
}
|
||||
|
||||
xr := *r
|
||||
r = &xr
|
||||
if h.StripPath {
|
||||
u := *r.URL
|
||||
u.Path = r.URL.Path[len(path):]
|
||||
if !strings.HasPrefix(u.Path, "/") {
|
||||
u.Path = "/" + u.Path
|
||||
}
|
||||
u.RawPath = ""
|
||||
r.URL = &u
|
||||
}
|
||||
|
@ -409,8 +415,31 @@ func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request,
|
|||
proto = "https"
|
||||
}
|
||||
r.Header["X-Forwarded-Proto"] = []string{proto}
|
||||
// note: We are not using "ws" or "wss" for websocket. The request we are
|
||||
// forwarding is http(s), and we don't yet know if the backend even supports
|
||||
// websockets.
|
||||
|
||||
// todo: add Forwarded header? is anyone using it?
|
||||
|
||||
// If we see an Upgrade: websocket, we're going to assume the client needs
|
||||
// websocket and only attempt to talk websocket with the backend. If the backend
|
||||
// doesn't do websocket, we'll send back a "bad request" response. For other values
|
||||
// of Upgrade, we don't do anything special.
|
||||
// https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml
|
||||
// Upgrade: ../rfc/9110:2798
|
||||
// Upgrade headers are not for http/1.0, ../rfc/9110:2880
|
||||
// Websocket client "handshake" is described at ../rfc/6455:1134
|
||||
upgrade := r.Header.Get("Upgrade")
|
||||
if upgrade != "" && !(r.ProtoMajor == 1 && r.ProtoMinor == 0) {
|
||||
// Websockets have case-insensitive string "websocket".
|
||||
for _, s := range strings.Split(upgrade, ",") {
|
||||
if strings.EqualFold(textproto.TrimString(s), "websocket") {
|
||||
forwardWebsocket(h, w, r, path)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReverseProxy will append any remaining path to the configured target URL.
|
||||
proxy := httputil.NewSingleHostReverseProxy(h.TargetURL)
|
||||
proxy.FlushInterval = time.Duration(-1) // Flush after each write.
|
||||
|
@ -422,9 +451,9 @@ func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request,
|
|||
}
|
||||
log().Errorx("forwarding request to backend webserver", err, mlog.Field("url", r.URL))
|
||||
if os.IsTimeout(err) {
|
||||
http.Error(w, "504 - gateway timeout"+recvid(), http.StatusGatewayTimeout)
|
||||
http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
|
||||
} else {
|
||||
http.Error(w, "502 - bad gateway"+recvid(), http.StatusBadGateway)
|
||||
http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
|
||||
}
|
||||
}
|
||||
whdr := w.Header()
|
||||
|
@ -434,3 +463,353 @@ func HandleForward(h *config.WebForward, w http.ResponseWriter, r *http.Request,
|
|||
proxy.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
var errResponseNotWebsocket = errors.New("not a valid websocket response to request")
|
||||
var errNotImplemented = errors.New("functionality not yet implemented")
|
||||
|
||||
// Request has an Upgrade: websocket header. Check more websocketiness about the
|
||||
// request. If it looks good, we forward it to the backend. If the backend responds
|
||||
// with a valid websocket response, indicating it is indeed a websocket server, we
|
||||
// pass the response along and start copying data between the client and the
|
||||
// backend. We don't look at the frames and payloads. The backend already needs to
|
||||
// know enough websocket to handle the frames. It wouldn't necessarily hurt to
|
||||
// monitor the frames too, and check if they are valid, but it's quite a bit of
|
||||
// work for little benefit. Besides, the whole point of websockets is to exchange
|
||||
// bytes without HTTP being in the way, so let's do that.
|
||||
func forwardWebsocket(h *config.WebForward, w http.ResponseWriter, r *http.Request, path string) (handled bool) {
|
||||
log := func() *mlog.Log {
|
||||
return xlog.WithContext(r.Context())
|
||||
}
|
||||
|
||||
lw := w.(*loggingWriter)
|
||||
lw.WebsocketRequest = true // For correct protocol in metrics.
|
||||
|
||||
// We check the requested websocket version first. A future websocket version may
|
||||
// have different request requirements.
|
||||
// ../rfc/6455:1160
|
||||
wsversion := r.Header.Get("Sec-WebSocket-Version")
|
||||
if wsversion != "13" {
|
||||
// Indicate we only support version 13. Should get a client from the future to fall back to version 13.
|
||||
// ../rfc/6455:1435
|
||||
w.Header().Set("Sec-WebSocket-Version", "13")
|
||||
http.Error(w, "400 - bad request - websockets only supported with version 13"+recvid(r), http.StatusBadRequest)
|
||||
lw.error(fmt.Errorf("Sec-WebSocket-Version %q not supported", wsversion))
|
||||
return true
|
||||
}
|
||||
|
||||
// ../rfc/6455:1143
|
||||
if r.Method != "GET" {
|
||||
http.Error(w, "400 - bad request - websockets only allowed with method GET"+recvid(r), http.StatusBadRequest)
|
||||
lw.error(fmt.Errorf("websocket request only allowed with method GET"))
|
||||
return true
|
||||
}
|
||||
|
||||
// ../rfc/6455:1153
|
||||
var connectionUpgrade bool
|
||||
for _, s := range strings.Split(r.Header.Get("Connection"), ",") {
|
||||
if strings.EqualFold(textproto.TrimString(s), "upgrade") {
|
||||
connectionUpgrade = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !connectionUpgrade {
|
||||
http.Error(w, "400 - bad request - connection header must be \"upgrade\""+recvid(r), http.StatusBadRequest)
|
||||
lw.error(fmt.Errorf(`connection header is %q, must be "upgrade"`, r.Header.Get("Connection")))
|
||||
return true
|
||||
}
|
||||
|
||||
// ../rfc/6455:1156
|
||||
wskey := r.Header.Get("Sec-WebSocket-Key")
|
||||
key, err := base64.StdEncoding.DecodeString(wskey)
|
||||
if err != nil || len(key) != 16 {
|
||||
http.Error(w, "400 - bad request - websockets requires Sec-WebSocket-Key with 16 bytes base64-encoded value"+recvid(r), http.StatusBadRequest)
|
||||
lw.error(fmt.Errorf("bad Sec-WebSocket-Key %q, must be 16 byte base64-encoded value", wskey))
|
||||
return true
|
||||
}
|
||||
|
||||
// ../rfc/6455:1162
|
||||
// We don't look at the origin header. The backend needs to handle it, if it thinks
|
||||
// that helps...
|
||||
// We also don't look at Sec-WebSocket-Protocol and Sec-WebSocket-Extensions. The
|
||||
// backend can set them, but it doesn't influence our forwarding of the data.
|
||||
|
||||
// If this is not a hijacker, there is not point in connecting to the backend.
|
||||
hj, ok := lw.W.(http.Hijacker)
|
||||
var cbr *bufio.ReadWriter
|
||||
if !ok {
|
||||
log().Info("cannot turn http connection into tcp connection (http.Hijacker)")
|
||||
http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
|
||||
lw.error(fmt.Errorf("connection not a http.Hijacker (%T)", lw.W))
|
||||
return
|
||||
}
|
||||
|
||||
freq := *r
|
||||
freq.Proto = "HTTP/1.1"
|
||||
freq.ProtoMajor = 1
|
||||
freq.ProtoMinor = 1
|
||||
fresp, beconn, err := websocketTransact(r.Context(), h.TargetURL, &freq)
|
||||
if err != nil {
|
||||
if errors.Is(err, errResponseNotWebsocket) {
|
||||
http.Error(w, "400 - bad request - websocket not supported"+recvid(r), http.StatusBadRequest)
|
||||
} else if errors.Is(err, errNotImplemented) {
|
||||
http.Error(w, "501 - not implemented - "+err.Error()+recvid(r), http.StatusNotImplemented)
|
||||
} else if os.IsTimeout(err) {
|
||||
http.Error(w, "504 - gateway timeout"+recvid(r), http.StatusGatewayTimeout)
|
||||
} else {
|
||||
http.Error(w, "502 - bad gateway"+recvid(r), http.StatusBadGateway)
|
||||
}
|
||||
lw.error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if beconn != nil {
|
||||
beconn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Hijack the client connection so we can write the response ourselves, and start
|
||||
// copying the websocket frames.
|
||||
var cconn net.Conn
|
||||
cconn, cbr, err = hj.Hijack()
|
||||
if err != nil {
|
||||
log().Debugx("cannot turn http transaction into websocket connection", err)
|
||||
http.Error(w, "501 - not implemented - cannot turn this connection into websocket"+recvid(r), http.StatusNotImplemented)
|
||||
lw.error(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if cconn != nil {
|
||||
cconn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Below this point, we can no longer write to the ResponseWriter.
|
||||
|
||||
// Mark as websocket response, for logging.
|
||||
lw.WebsocketResponse = true
|
||||
lw.setStatusCode(fresp.StatusCode)
|
||||
|
||||
for k, v := range h.ResponseHeaders {
|
||||
fresp.Header.Add(k, v)
|
||||
}
|
||||
|
||||
// Write the response to the client, completing its websocket handshake.
|
||||
if err := fresp.Write(cconn); err != nil {
|
||||
lw.error(fmt.Errorf("writing websocket response to client: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
|
||||
// Copy from client to backend.
|
||||
go func() {
|
||||
buf, err := cbr.Peek(cbr.Reader.Buffered())
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
if len(buf) > 0 {
|
||||
n, err := beconn.Write(buf)
|
||||
if err != nil {
|
||||
errc <- err
|
||||
return
|
||||
}
|
||||
lw.SizeFromClient += int64(n)
|
||||
}
|
||||
n, err := io.Copy(beconn, cconn)
|
||||
lw.SizeFromClient += n
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
// Copy from backend to client.
|
||||
go func() {
|
||||
n, err := io.Copy(cconn, beconn)
|
||||
lw.SizeToClient = n
|
||||
errc <- err
|
||||
}()
|
||||
|
||||
// Stop and close connection on first error from either size, typically a closed
|
||||
// connection whose closing was already announced with a websocket frame.
|
||||
lw.error(<-errc)
|
||||
// Close connections so other goroutine stops as well.
|
||||
cconn.Close()
|
||||
beconn.Close()
|
||||
cconn = nil
|
||||
// Wait for goroutine so it has updated the logWriter.Size*Client fields before we
|
||||
// continue with logging.
|
||||
<-errc
|
||||
return true
|
||||
}
|
||||
|
||||
func websocketTransact(ctx context.Context, targetURL *url.URL, r *http.Request) (rresp *http.Response, rconn net.Conn, rerr error) {
|
||||
log := func() *mlog.Log {
|
||||
return xlog.WithContext(r.Context())
|
||||
}
|
||||
|
||||
// Dial the backend, possibly doing TLS. We assume the net/http DefaultTransport is
|
||||
// unmodified.
|
||||
transport := http.DefaultTransport.(*http.Transport)
|
||||
|
||||
// We haven't implemented using a proxy for websocket requests yet. If we need one,
|
||||
// return an error instead of trying to connect directly, which would be a
|
||||
// potential security issue.
|
||||
treq := *r
|
||||
treq.URL = targetURL
|
||||
if purl, err := transport.Proxy(&treq); err != nil {
|
||||
return nil, nil, fmt.Errorf("determining proxy for websocket backend connection: %w", err)
|
||||
} else if purl != nil {
|
||||
return nil, nil, fmt.Errorf("%w: proxy required for websocket connection to backend", errNotImplemented) // todo: implement?
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(targetURL.Host)
|
||||
if err != nil {
|
||||
host = targetURL.Host
|
||||
if targetURL.Scheme == "https" {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
addr := net.JoinHostPort(host, port)
|
||||
conn, err := transport.DialContext(r.Context(), "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("dial: %w", err)
|
||||
}
|
||||
if targetURL.Scheme == "https" {
|
||||
tlsconn := tls.Client(conn, transport.TLSClientConfig)
|
||||
ctx, cancel := context.WithTimeout(r.Context(), transport.TLSHandshakeTimeout)
|
||||
defer cancel()
|
||||
if err := tlsconn.HandshakeContext(ctx); err != nil {
|
||||
return nil, nil, fmt.Errorf("tls handshake: %w", err)
|
||||
}
|
||||
conn = tlsconn
|
||||
}
|
||||
defer func() {
|
||||
if rerr != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// todo: make timeout configurable?
|
||||
if err := conn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
log().Check(err, "set deadline for websocket request to backend")
|
||||
}
|
||||
|
||||
// Set clean connection headers.
|
||||
removeHopByHopHeaders(r.Header)
|
||||
r.Header.Set("Connection", "Upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
|
||||
// Write the websocket request to the backend.
|
||||
if err := r.Write(conn); err != nil {
|
||||
return nil, nil, fmt.Errorf("writing request to backend: %w", err)
|
||||
}
|
||||
|
||||
// Read response from backend.
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, r)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading response from backend: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if rerr != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
}()
|
||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||
log().Check(err, "clearing deadline on websocket connection to backend")
|
||||
}
|
||||
|
||||
// Check that the response from the backend server indicates it is websocket. If
|
||||
// not, don't pass the backend response, but an error that websocket is not
|
||||
// appropriate.
|
||||
if err := checkWebsocketResponse(resp, r); err != nil {
|
||||
return resp, nil, err
|
||||
}
|
||||
|
||||
// note: net/http.Response.Body documents that it implements io.Writer for a
|
||||
// status: 101 response. But that's not the case when the response has been read
|
||||
// with http.ReadResponse. We'll write to the connection directly.
|
||||
|
||||
buf, err := br.Peek(br.Buffered())
|
||||
if err != nil {
|
||||
return resp, nil, fmt.Errorf("peek at buffered data written by backend: %w", err)
|
||||
}
|
||||
return resp, websocketConn{io.MultiReader(bytes.NewReader(buf), conn), conn}, nil
|
||||
}
|
||||
|
||||
// A net.Conn but with reads coming from an io multireader (due to buffered reader
|
||||
// needed for http.ReadResponse).
|
||||
type websocketConn struct {
|
||||
r io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c websocketConn) Read(buf []byte) (int, error) {
|
||||
return c.r.Read(buf)
|
||||
}
|
||||
|
||||
// Check that an HTTP response (from a backend) is a valid websocket response, i.e.
|
||||
// that it accepts the WebSocket "upgrade".
|
||||
// ../rfc/6455:1299
|
||||
func checkWebsocketResponse(resp *http.Response, req *http.Request) error {
|
||||
if resp.StatusCode != 101 {
|
||||
return fmt.Errorf("%w: response http status not 101 but %s", errResponseNotWebsocket, resp.Status)
|
||||
}
|
||||
if upgrade := resp.Header.Get("Upgrade"); !strings.EqualFold(upgrade, "websocket") {
|
||||
return fmt.Errorf(`%w: response http status is 101, but Upgrade header is %q, should be "websocket"`, errResponseNotWebsocket, upgrade)
|
||||
}
|
||||
if connection := resp.Header.Get("Connection"); !strings.EqualFold(connection, "upgrade") {
|
||||
return fmt.Errorf(`%w: response http status is 101, Upgrade is websocket, but Connection header is %q, should be "Upgrade"`, errResponseNotWebsocket, connection)
|
||||
}
|
||||
accept, err := base64.StdEncoding.DecodeString(resp.Header.Get("Sec-WebSocket-Accept"))
|
||||
if err != nil {
|
||||
return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but Sec-WebSocket-Accept header is not valid base64: %v`, errResponseNotWebsocket, err)
|
||||
}
|
||||
exp := sha1.Sum([]byte(req.Header.Get("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
|
||||
if !bytes.Equal(accept, exp[:]) {
|
||||
return fmt.Errorf(`%w: response http status, Upgrade and Connection header are websocket, but backend Sec-WebSocket-Accept value does not match`, errResponseNotWebsocket)
|
||||
}
|
||||
// We don't have requirements for the other Sec-WebSocket headers. ../rfc/6455:1340
|
||||
return nil
|
||||
}
|
||||
|
||||
// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
// ../rfc/2616:5128
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
// From Go 1.20.4 src/net/http/httputil/reverseproxy.go:
|
||||
// removeHopByHopHeaders removes hop-by-hop headers.
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
// RFC 7230, section 6.1: Remove headers listed in the "Connection" header.
|
||||
// ../rfc/7230:2817
|
||||
for _, f := range h["Connection"] {
|
||||
for _, sf := range strings.Split(f, ",") {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
// RFC 2616, section 13.5.1: Remove a set of known hop-by-hop headers.
|
||||
// This behavior is superseded by the RFC 7230 Connection header, but
|
||||
// preserve it for backwards compatibility.
|
||||
// ../rfc/2616:5128
|
||||
for _, f := range hopHeaders {
|
||||
h.Del(f)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,9 @@ package http
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
|
@ -10,6 +13,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
|
||||
"github.com/mjl-/mox/mox-"
|
||||
)
|
||||
|
||||
|
@ -119,3 +124,184 @@ func TestWebserver(t *testing.T) {
|
|||
test("GET", "http://mox.example/bogus", nil, http.StatusNotFound, "", nil) // path not registered.
|
||||
test("GET", "http://bogus.mox.example/static/", nil, http.StatusNotFound, "", nil) // domain not registered.
|
||||
}
|
||||
|
||||
func TestWebsocket(t *testing.T) {
|
||||
os.RemoveAll("../testdata/websocket/data")
|
||||
mox.ConfigStaticPath = "../testdata/websocket/mox.conf"
|
||||
mox.ConfigDynamicPath = filepath.Join(filepath.Dir(mox.ConfigStaticPath), "domains.conf")
|
||||
mox.MustLoadConfig(false)
|
||||
|
||||
srv := &serve{Webserver: true}
|
||||
|
||||
var handler http.Handler // Active handler during test.
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler.ServeHTTP(w, r)
|
||||
}))
|
||||
|
||||
defer backend.Close()
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parsing backend url: %v", err)
|
||||
}
|
||||
backendURL.Path = "/"
|
||||
|
||||
// warning: it is not normally allowed to access the dynamic config without lock. don't propagate accesses like this!
|
||||
mox.Conf.Dynamic.WebHandlers[len(mox.Conf.Dynamic.WebHandlers)-1].WebForward.TargetURL = backendURL
|
||||
|
||||
server := httptest.NewServer(srv)
|
||||
defer server.Close()
|
||||
|
||||
serverURL, err := url.Parse(server.URL)
|
||||
tcheck(t, err, "parsing server url")
|
||||
_, port, err := net.SplitHostPort(serverURL.Host)
|
||||
tcheck(t, err, "parsing host port in server url")
|
||||
wsurl := fmt.Sprintf("ws://%s/ws/", net.JoinHostPort("localhost", port))
|
||||
|
||||
handler = websocket.Handler(func(c *websocket.Conn) {
|
||||
io.Copy(c, c)
|
||||
})
|
||||
|
||||
// Test a correct websocket connection.
|
||||
wsconn, err := websocket.Dial(wsurl, "ignored", "http://ignored.example")
|
||||
tcheck(t, err, "websocket dial")
|
||||
_, err = fmt.Fprint(wsconn, "test")
|
||||
tcheck(t, err, "write to websocket")
|
||||
buf := make([]byte, 128)
|
||||
n, err := wsconn.Read(buf)
|
||||
tcheck(t, err, "read from websocket")
|
||||
if string(buf[:n]) != "test" {
|
||||
t.Fatalf(`got websocket data %q, expected "test"`, buf[:n])
|
||||
}
|
||||
err = wsconn.Close()
|
||||
tcheck(t, err, "closing websocket connection")
|
||||
|
||||
// Test with server.ServeHTTP directly.
|
||||
test := func(method string, reqhdrs map[string]string, expCode int, expHeaders map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
req := httptest.NewRequest(method, wsurl, nil)
|
||||
for k, v := range reqhdrs {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
rw.Body = &bytes.Buffer{}
|
||||
srv.ServeHTTP(rw, req)
|
||||
resp := rw.Result()
|
||||
if resp.StatusCode != expCode {
|
||||
t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
|
||||
}
|
||||
for k, v := range expHeaders {
|
||||
if xv := resp.Header.Get(k); xv != v {
|
||||
t.Fatalf("got %q for header %q, expected %q", xv, k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsreqhdrs := map[string]string{
|
||||
"Upgrade": "keep-alive, websocket",
|
||||
"Connection": "X, Upgrade",
|
||||
"Sec-Websocket-Version": "13",
|
||||
"Sec-Websocket-Key": "AAAAAAAAAAAAAAAAAAAAAA==",
|
||||
}
|
||||
|
||||
test("POST", wsreqhdrs, http.StatusBadRequest, nil)
|
||||
|
||||
clone := func(m map[string]string) map[string]string {
|
||||
r := map[string]string{}
|
||||
for k, v := range m {
|
||||
r[k] = v
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
hdrs := clone(wsreqhdrs)
|
||||
hdrs["Sec-Websocket-Version"] = "14"
|
||||
test("GET", hdrs, http.StatusBadRequest, map[string]string{"Sec-Websocket-Version": "13"})
|
||||
|
||||
httpurl := fmt.Sprintf("http://%s/ws/", net.JoinHostPort("localhost", port))
|
||||
|
||||
// Must now do actual HTTP requests and read the HTTP response. Cannot call
|
||||
// ServeHTTP because ResponseRecorder is not a http.Hijacker.
|
||||
test = func(method string, reqhdrs map[string]string, expCode int, expHeaders map[string]string) {
|
||||
t.Helper()
|
||||
|
||||
req, err := http.NewRequest(method, httpurl, nil)
|
||||
for k, v := range reqhdrs {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
tcheck(t, err, "http transaction")
|
||||
if resp.StatusCode != expCode {
|
||||
t.Fatalf("got statuscode %d, expected %d", resp.StatusCode, expCode)
|
||||
}
|
||||
for k, v := range expHeaders {
|
||||
if xv := resp.Header.Get(k); xv != v {
|
||||
t.Fatalf("got %q for header %q, expected %q", xv, k, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hdrs = clone(wsreqhdrs)
|
||||
hdrs["Sec-Websocket-Key"] = "malformed"
|
||||
test("GET", hdrs, http.StatusBadRequest, nil)
|
||||
|
||||
hdrs = clone(wsreqhdrs)
|
||||
hdrs["Sec-Websocket-Key"] = "c2hvcnQK" // "short"
|
||||
test("GET", hdrs, http.StatusBadRequest, nil)
|
||||
|
||||
// Not responding with a 101, but with regular 200 OK response.
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "bad", http.StatusOK)
|
||||
})
|
||||
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
|
||||
|
||||
// Respond with 101, but other websocket response headers missing.
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
})
|
||||
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
|
||||
|
||||
// With Upgrade: websocket, without Connection: Upgrade
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Upgrade", "websocket")
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
})
|
||||
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
|
||||
|
||||
// With malformed Sec-WebSocket-Accept response header.
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h.Set("Upgrade", "websocket")
|
||||
h.Set("Connection", "Upgrade")
|
||||
h.Set("Sec-WebSocket-Accept", "malformed")
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
})
|
||||
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
|
||||
|
||||
// With malformed Sec-WebSocket-Accept response header.
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h.Set("Upgrade", "websocket")
|
||||
h.Set("Connection", "Upgrade")
|
||||
h.Set("Sec-WebSocket-Accept", "YmFk") // "bad"
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
})
|
||||
test("GET", wsreqhdrs, http.StatusBadRequest, nil)
|
||||
|
||||
// All good.
|
||||
wsresphdrs := map[string]string{
|
||||
"Connection": "Upgrade",
|
||||
"Upgrade": "websocket",
|
||||
"Sec-Websocket-Accept": "ICX+Yqv66kxgM0FcWaLWlFLwTAI=",
|
||||
"X-Test": "mox",
|
||||
}
|
||||
handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h.Set("Upgrade", "websocket")
|
||||
h.Set("Connection", "Upgrade")
|
||||
h.Set("Sec-WebSocket-Accept", "ICX+Yqv66kxgM0FcWaLWlFLwTAI=")
|
||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||
})
|
||||
test("GET", wsreqhdrs, http.StatusSwitchingProtocols, wsresphdrs)
|
||||
|
||||
}
|
||||
|
|
10
rfc/index.md
10
rfc/index.md
|
@ -297,6 +297,16 @@ and many more, see http://sieve.info/documents
|
|||
9157 Revised IANA Considerations for DNSSEC
|
||||
9276 Guidance for NSEC3 Parameter Settings
|
||||
|
||||
# HTTP
|
||||
|
||||
2616 Hypertext Transfer Protocol -- HTTP/1.1
|
||||
7230 Hypertext Transfer Protocol (HTTP/1.1): Message Syntax and Routing
|
||||
9110 HTTP Semantics
|
||||
|
||||
# Websockets
|
||||
|
||||
6455 The WebSocket Protocol
|
||||
|
||||
# More
|
||||
|
||||
3339 Date and Time on the Internet: Timestamps
|
||||
|
|
19
testdata/websocket/domains.conf
vendored
Normal file
19
testdata/websocket/domains.conf
vendored
Normal file
|
@ -0,0 +1,19 @@
|
|||
Domains:
|
||||
mox.example:
|
||||
LocalpartCaseSensitive: false
|
||||
Accounts:
|
||||
mjl:
|
||||
Domain: mox.example
|
||||
Destinations:
|
||||
mjl@mox.example: nil
|
||||
WebHandlers:
|
||||
-
|
||||
LogName: websocket
|
||||
Domain: localhost
|
||||
PathRegexp: ^/ws/
|
||||
DontRedirectPlainHTTP: true
|
||||
WebForward:
|
||||
# replaced while testing
|
||||
URL: http://127.0.0.1:1/
|
||||
ResponseHeaders:
|
||||
X-Test: mox
|
13
testdata/websocket/mox.conf
vendored
Normal file
13
testdata/websocket/mox.conf
vendored
Normal file
|
@ -0,0 +1,13 @@
|
|||
DataDir: data
|
||||
User: 1000
|
||||
LogLevel: trace
|
||||
Hostname: mox.example
|
||||
Listeners:
|
||||
local:
|
||||
IPs:
|
||||
- 0.0.0.0
|
||||
WebserverHTTP:
|
||||
Enabled: true
|
||||
Postmaster:
|
||||
Account: mjl
|
||||
Mailbox: postmaster
|
106
vendor/golang.org/x/net/websocket/client.go
generated
vendored
Normal file
106
vendor/golang.org/x/net/websocket/client.go
generated
vendored
Normal file
|
@ -0,0 +1,106 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// DialError is an error that occurs while dialling a websocket server.
|
||||
type DialError struct {
|
||||
*Config
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *DialError) Error() string {
|
||||
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
// NewConfig creates a new WebSocket config for client connection.
|
||||
func NewConfig(server, origin string) (config *Config, err error) {
|
||||
config = new(Config)
|
||||
config.Version = ProtocolVersionHybi13
|
||||
config.Location, err = url.ParseRequestURI(server)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
config.Origin, err = url.ParseRequestURI(origin)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
config.Header = http.Header(make(map[string][]string))
|
||||
return
|
||||
}
|
||||
|
||||
// NewClient creates a new WebSocket client connection over rwc.
|
||||
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
|
||||
br := bufio.NewReader(rwc)
|
||||
bw := bufio.NewWriter(rwc)
|
||||
err = hybiClientHandshake(config, br, bw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
buf := bufio.NewReadWriter(br, bw)
|
||||
ws = newHybiClientConn(config, buf, rwc)
|
||||
return
|
||||
}
|
||||
|
||||
// Dial opens a new client connection to a WebSocket.
|
||||
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
|
||||
config, err := NewConfig(url_, origin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if protocol != "" {
|
||||
config.Protocol = []string{protocol}
|
||||
}
|
||||
return DialConfig(config)
|
||||
}
|
||||
|
||||
var portMap = map[string]string{
|
||||
"ws": "80",
|
||||
"wss": "443",
|
||||
}
|
||||
|
||||
func parseAuthority(location *url.URL) string {
|
||||
if _, ok := portMap[location.Scheme]; ok {
|
||||
if _, _, err := net.SplitHostPort(location.Host); err != nil {
|
||||
return net.JoinHostPort(location.Host, portMap[location.Scheme])
|
||||
}
|
||||
}
|
||||
return location.Host
|
||||
}
|
||||
|
||||
// DialConfig opens a new client connection to a WebSocket with a config.
|
||||
func DialConfig(config *Config) (ws *Conn, err error) {
|
||||
var client net.Conn
|
||||
if config.Location == nil {
|
||||
return nil, &DialError{config, ErrBadWebSocketLocation}
|
||||
}
|
||||
if config.Origin == nil {
|
||||
return nil, &DialError{config, ErrBadWebSocketOrigin}
|
||||
}
|
||||
dialer := config.Dialer
|
||||
if dialer == nil {
|
||||
dialer = &net.Dialer{}
|
||||
}
|
||||
client, err = dialWithDialer(dialer, config)
|
||||
if err != nil {
|
||||
goto Error
|
||||
}
|
||||
ws, err = NewClient(config, client)
|
||||
if err != nil {
|
||||
client.Close()
|
||||
goto Error
|
||||
}
|
||||
return
|
||||
|
||||
Error:
|
||||
return nil, &DialError{config, err}
|
||||
}
|
24
vendor/golang.org/x/net/websocket/dial.go
generated
vendored
Normal file
24
vendor/golang.org/x/net/websocket/dial.go
generated
vendored
Normal file
|
@ -0,0 +1,24 @@
|
|||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
func dialWithDialer(dialer *net.Dialer, config *Config) (conn net.Conn, err error) {
|
||||
switch config.Location.Scheme {
|
||||
case "ws":
|
||||
conn, err = dialer.Dial("tcp", parseAuthority(config.Location))
|
||||
|
||||
case "wss":
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", parseAuthority(config.Location), config.TlsConfig)
|
||||
|
||||
default:
|
||||
err = ErrBadScheme
|
||||
}
|
||||
return
|
||||
}
|
583
vendor/golang.org/x/net/websocket/hybi.go
generated
vendored
Normal file
583
vendor/golang.org/x/net/websocket/hybi.go
generated
vendored
Normal file
|
@ -0,0 +1,583 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
// This file implements a protocol of hybi draft.
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
closeStatusNormal = 1000
|
||||
closeStatusGoingAway = 1001
|
||||
closeStatusProtocolError = 1002
|
||||
closeStatusUnsupportedData = 1003
|
||||
closeStatusFrameTooLarge = 1004
|
||||
closeStatusNoStatusRcvd = 1005
|
||||
closeStatusAbnormalClosure = 1006
|
||||
closeStatusBadMessageData = 1007
|
||||
closeStatusPolicyViolation = 1008
|
||||
closeStatusTooBigData = 1009
|
||||
closeStatusExtensionMismatch = 1010
|
||||
|
||||
maxControlFramePayloadLength = 125
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
|
||||
ErrBadPongMessage = &ProtocolError{"bad pong message"}
|
||||
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
|
||||
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
|
||||
ErrNotImplemented = &ProtocolError{"not implemented"}
|
||||
|
||||
handshakeHeader = map[string]bool{
|
||||
"Host": true,
|
||||
"Upgrade": true,
|
||||
"Connection": true,
|
||||
"Sec-Websocket-Key": true,
|
||||
"Sec-Websocket-Origin": true,
|
||||
"Sec-Websocket-Version": true,
|
||||
"Sec-Websocket-Protocol": true,
|
||||
"Sec-Websocket-Accept": true,
|
||||
}
|
||||
)
|
||||
|
||||
// A hybiFrameHeader is a frame header as defined in hybi draft.
|
||||
type hybiFrameHeader struct {
|
||||
Fin bool
|
||||
Rsv [3]bool
|
||||
OpCode byte
|
||||
Length int64
|
||||
MaskingKey []byte
|
||||
|
||||
data *bytes.Buffer
|
||||
}
|
||||
|
||||
// A hybiFrameReader is a reader for hybi frame.
|
||||
type hybiFrameReader struct {
|
||||
reader io.Reader
|
||||
|
||||
header hybiFrameHeader
|
||||
pos int64
|
||||
length int
|
||||
}
|
||||
|
||||
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
|
||||
n, err = frame.reader.Read(msg)
|
||||
if frame.header.MaskingKey != nil {
|
||||
for i := 0; i < n; i++ {
|
||||
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
|
||||
frame.pos++
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
|
||||
|
||||
func (frame *hybiFrameReader) HeaderReader() io.Reader {
|
||||
if frame.header.data == nil {
|
||||
return nil
|
||||
}
|
||||
if frame.header.data.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
return frame.header.data
|
||||
}
|
||||
|
||||
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
|
||||
|
||||
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
|
||||
|
||||
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
|
||||
type hybiFrameReaderFactory struct {
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
|
||||
// See Section 5.2 Base Framing protocol for detail.
|
||||
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
|
||||
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
|
||||
hybiFrame := new(hybiFrameReader)
|
||||
frame = hybiFrame
|
||||
var header []byte
|
||||
var b byte
|
||||
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
header = append(header, b)
|
||||
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
|
||||
for i := 0; i < 3; i++ {
|
||||
j := uint(6 - i)
|
||||
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
|
||||
}
|
||||
hybiFrame.header.OpCode = header[0] & 0x0f
|
||||
|
||||
// Second byte. Mask/Payload len(7bits)
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
header = append(header, b)
|
||||
mask := (b & 0x80) != 0
|
||||
b &= 0x7f
|
||||
lengthFields := 0
|
||||
switch {
|
||||
case b <= 125: // Payload length 7bits.
|
||||
hybiFrame.header.Length = int64(b)
|
||||
case b == 126: // Payload length 7+16bits
|
||||
lengthFields = 2
|
||||
case b == 127: // Payload length 7+64bits
|
||||
lengthFields = 8
|
||||
}
|
||||
for i := 0; i < lengthFields; i++ {
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
|
||||
b &= 0x7f
|
||||
}
|
||||
header = append(header, b)
|
||||
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
|
||||
}
|
||||
if mask {
|
||||
// Masking key. 4 bytes.
|
||||
for i := 0; i < 4; i++ {
|
||||
b, err = buf.ReadByte()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
header = append(header, b)
|
||||
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
|
||||
}
|
||||
}
|
||||
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
|
||||
hybiFrame.header.data = bytes.NewBuffer(header)
|
||||
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
|
||||
return
|
||||
}
|
||||
|
||||
// A HybiFrameWriter is a writer for hybi frame.
|
||||
type hybiFrameWriter struct {
|
||||
writer *bufio.Writer
|
||||
|
||||
header *hybiFrameHeader
|
||||
}
|
||||
|
||||
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
|
||||
var header []byte
|
||||
var b byte
|
||||
if frame.header.Fin {
|
||||
b |= 0x80
|
||||
}
|
||||
for i := 0; i < 3; i++ {
|
||||
if frame.header.Rsv[i] {
|
||||
j := uint(6 - i)
|
||||
b |= 1 << j
|
||||
}
|
||||
}
|
||||
b |= frame.header.OpCode
|
||||
header = append(header, b)
|
||||
if frame.header.MaskingKey != nil {
|
||||
b = 0x80
|
||||
} else {
|
||||
b = 0
|
||||
}
|
||||
lengthFields := 0
|
||||
length := len(msg)
|
||||
switch {
|
||||
case length <= 125:
|
||||
b |= byte(length)
|
||||
case length < 65536:
|
||||
b |= 126
|
||||
lengthFields = 2
|
||||
default:
|
||||
b |= 127
|
||||
lengthFields = 8
|
||||
}
|
||||
header = append(header, b)
|
||||
for i := 0; i < lengthFields; i++ {
|
||||
j := uint((lengthFields - i - 1) * 8)
|
||||
b = byte((length >> j) & 0xff)
|
||||
header = append(header, b)
|
||||
}
|
||||
if frame.header.MaskingKey != nil {
|
||||
if len(frame.header.MaskingKey) != 4 {
|
||||
return 0, ErrBadMaskingKey
|
||||
}
|
||||
header = append(header, frame.header.MaskingKey...)
|
||||
frame.writer.Write(header)
|
||||
data := make([]byte, length)
|
||||
for i := range data {
|
||||
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
|
||||
}
|
||||
frame.writer.Write(data)
|
||||
err = frame.writer.Flush()
|
||||
return length, err
|
||||
}
|
||||
frame.writer.Write(header)
|
||||
frame.writer.Write(msg)
|
||||
err = frame.writer.Flush()
|
||||
return length, err
|
||||
}
|
||||
|
||||
func (frame *hybiFrameWriter) Close() error { return nil }
|
||||
|
||||
type hybiFrameWriterFactory struct {
|
||||
*bufio.Writer
|
||||
needMaskingKey bool
|
||||
}
|
||||
|
||||
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
|
||||
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
|
||||
if buf.needMaskingKey {
|
||||
frameHeader.MaskingKey, err = generateMaskingKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
|
||||
}
|
||||
|
||||
type hybiFrameHandler struct {
|
||||
conn *Conn
|
||||
payloadType byte
|
||||
}
|
||||
|
||||
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
|
||||
if handler.conn.IsServerConn() {
|
||||
// The client MUST mask all frames sent to the server.
|
||||
if frame.(*hybiFrameReader).header.MaskingKey == nil {
|
||||
handler.WriteClose(closeStatusProtocolError)
|
||||
return nil, io.EOF
|
||||
}
|
||||
} else {
|
||||
// The server MUST NOT mask all frames.
|
||||
if frame.(*hybiFrameReader).header.MaskingKey != nil {
|
||||
handler.WriteClose(closeStatusProtocolError)
|
||||
return nil, io.EOF
|
||||
}
|
||||
}
|
||||
if header := frame.HeaderReader(); header != nil {
|
||||
io.Copy(ioutil.Discard, header)
|
||||
}
|
||||
switch frame.PayloadType() {
|
||||
case ContinuationFrame:
|
||||
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
|
||||
case TextFrame, BinaryFrame:
|
||||
handler.payloadType = frame.PayloadType()
|
||||
case CloseFrame:
|
||||
return nil, io.EOF
|
||||
case PingFrame, PongFrame:
|
||||
b := make([]byte, maxControlFramePayloadLength)
|
||||
n, err := io.ReadFull(frame, b)
|
||||
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
|
||||
return nil, err
|
||||
}
|
||||
io.Copy(ioutil.Discard, frame)
|
||||
if frame.PayloadType() == PingFrame {
|
||||
if _, err := handler.WritePong(b[:n]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
|
||||
handler.conn.wio.Lock()
|
||||
defer handler.conn.wio.Unlock()
|
||||
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(msg, uint16(status))
|
||||
_, err = w.Write(msg)
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
|
||||
handler.conn.wio.Lock()
|
||||
defer handler.conn.wio.Unlock()
|
||||
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err = w.Write(msg)
|
||||
w.Close()
|
||||
return n, err
|
||||
}
|
||||
|
||||
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
|
||||
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||
if buf == nil {
|
||||
br := bufio.NewReader(rwc)
|
||||
bw := bufio.NewWriter(rwc)
|
||||
buf = bufio.NewReadWriter(br, bw)
|
||||
}
|
||||
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
|
||||
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
|
||||
frameWriterFactory: hybiFrameWriterFactory{
|
||||
buf.Writer, request == nil},
|
||||
PayloadType: TextFrame,
|
||||
defaultCloseStatus: closeStatusNormal}
|
||||
ws.frameHandler = &hybiFrameHandler{conn: ws}
|
||||
return ws
|
||||
}
|
||||
|
||||
// generateMaskingKey generates a masking key for a frame.
|
||||
func generateMaskingKey() (maskingKey []byte, err error) {
|
||||
maskingKey = make([]byte, 4)
|
||||
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// generateNonce generates a nonce consisting of a randomly selected 16-byte
|
||||
// value that has been base64-encoded.
|
||||
func generateNonce() (nonce []byte) {
|
||||
key := make([]byte, 16)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
nonce = make([]byte, 24)
|
||||
base64.StdEncoding.Encode(nonce, key)
|
||||
return
|
||||
}
|
||||
|
||||
// removeZone removes IPv6 zone identifier from host.
|
||||
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
|
||||
func removeZone(host string) string {
|
||||
if !strings.HasPrefix(host, "[") {
|
||||
return host
|
||||
}
|
||||
i := strings.LastIndex(host, "]")
|
||||
if i < 0 {
|
||||
return host
|
||||
}
|
||||
j := strings.LastIndex(host[:i], "%")
|
||||
if j < 0 {
|
||||
return host
|
||||
}
|
||||
return host[:j] + host[i:]
|
||||
}
|
||||
|
||||
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
|
||||
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
|
||||
func getNonceAccept(nonce []byte) (expected []byte, err error) {
|
||||
h := sha1.New()
|
||||
if _, err = h.Write(nonce); err != nil {
|
||||
return
|
||||
}
|
||||
if _, err = h.Write([]byte(websocketGUID)); err != nil {
|
||||
return
|
||||
}
|
||||
expected = make([]byte, 28)
|
||||
base64.StdEncoding.Encode(expected, h.Sum(nil))
|
||||
return
|
||||
}
|
||||
|
||||
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
|
||||
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
|
||||
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
|
||||
|
||||
// According to RFC 6874, an HTTP client, proxy, or other
|
||||
// intermediary must remove any IPv6 zone identifier attached
|
||||
// to an outgoing URI.
|
||||
bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
|
||||
bw.WriteString("Upgrade: websocket\r\n")
|
||||
bw.WriteString("Connection: Upgrade\r\n")
|
||||
nonce := generateNonce()
|
||||
if config.handshakeData != nil {
|
||||
nonce = []byte(config.handshakeData["key"])
|
||||
}
|
||||
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
|
||||
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
|
||||
|
||||
if config.Version != ProtocolVersionHybi13 {
|
||||
return ErrBadProtocolVersion
|
||||
}
|
||||
|
||||
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
|
||||
if len(config.Protocol) > 0 {
|
||||
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
|
||||
}
|
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
err = config.Header.WriteSubset(bw, handshakeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bw.WriteString("\r\n")
|
||||
if err = bw.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != 101 {
|
||||
return ErrBadStatus
|
||||
}
|
||||
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
|
||||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
|
||||
return ErrBadUpgrade
|
||||
}
|
||||
expectedAccept, err := getNonceAccept(nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
|
||||
return ErrChallengeResponse
|
||||
}
|
||||
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
|
||||
return ErrUnsupportedExtensions
|
||||
}
|
||||
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
|
||||
if offeredProtocol != "" {
|
||||
protocolMatched := false
|
||||
for i := 0; i < len(config.Protocol); i++ {
|
||||
if config.Protocol[i] == offeredProtocol {
|
||||
protocolMatched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !protocolMatched {
|
||||
return ErrBadWebSocketProtocol
|
||||
}
|
||||
config.Protocol = []string{offeredProtocol}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// newHybiClientConn creates a client WebSocket connection after handshake.
|
||||
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
|
||||
return newHybiConn(config, buf, rwc, nil)
|
||||
}
|
||||
|
||||
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
|
||||
type hybiServerHandshaker struct {
|
||||
*Config
|
||||
accept []byte
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
|
||||
c.Version = ProtocolVersionHybi13
|
||||
if req.Method != "GET" {
|
||||
return http.StatusMethodNotAllowed, ErrBadRequestMethod
|
||||
}
|
||||
// HTTP version can be safely ignored.
|
||||
|
||||
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
|
||||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
||||
return http.StatusBadRequest, ErrNotWebSocket
|
||||
}
|
||||
|
||||
key := req.Header.Get("Sec-Websocket-Key")
|
||||
if key == "" {
|
||||
return http.StatusBadRequest, ErrChallengeResponse
|
||||
}
|
||||
version := req.Header.Get("Sec-Websocket-Version")
|
||||
switch version {
|
||||
case "13":
|
||||
c.Version = ProtocolVersionHybi13
|
||||
default:
|
||||
return http.StatusBadRequest, ErrBadWebSocketVersion
|
||||
}
|
||||
var scheme string
|
||||
if req.TLS != nil {
|
||||
scheme = "wss"
|
||||
} else {
|
||||
scheme = "ws"
|
||||
}
|
||||
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
|
||||
if err != nil {
|
||||
return http.StatusBadRequest, err
|
||||
}
|
||||
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
|
||||
if protocol != "" {
|
||||
protocols := strings.Split(protocol, ",")
|
||||
for i := 0; i < len(protocols); i++ {
|
||||
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
|
||||
}
|
||||
}
|
||||
c.accept, err = getNonceAccept([]byte(key))
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, err
|
||||
}
|
||||
return http.StatusSwitchingProtocols, nil
|
||||
}
|
||||
|
||||
// Origin parses the Origin header in req.
|
||||
// If the Origin header is not set, it returns nil and nil.
|
||||
func Origin(config *Config, req *http.Request) (*url.URL, error) {
|
||||
var origin string
|
||||
switch config.Version {
|
||||
case ProtocolVersionHybi13:
|
||||
origin = req.Header.Get("Origin")
|
||||
}
|
||||
if origin == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return url.ParseRequestURI(origin)
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
|
||||
if len(c.Protocol) > 0 {
|
||||
if len(c.Protocol) != 1 {
|
||||
// You need choose a Protocol in Handshake func in Server.
|
||||
return ErrBadWebSocketProtocol
|
||||
}
|
||||
}
|
||||
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
|
||||
buf.WriteString("Upgrade: websocket\r\n")
|
||||
buf.WriteString("Connection: Upgrade\r\n")
|
||||
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
|
||||
if len(c.Protocol) > 0 {
|
||||
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
|
||||
}
|
||||
// TODO(ukai): send Sec-WebSocket-Extensions.
|
||||
if c.Header != nil {
|
||||
err := c.Header.WriteSubset(buf, handshakeHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Flush()
|
||||
}
|
||||
|
||||
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||
return newHybiServerConn(c.Config, buf, rwc, request)
|
||||
}
|
||||
|
||||
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
|
||||
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
||||
return newHybiConn(config, buf, rwc, request)
|
||||
}
|
113
vendor/golang.org/x/net/websocket/server.go
generated
vendored
Normal file
113
vendor/golang.org/x/net/websocket/server.go
generated
vendored
Normal file
|
@ -0,0 +1,113 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package websocket
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
|
||||
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
|
||||
code, err := hs.ReadHandshake(buf.Reader, req)
|
||||
if err == ErrBadWebSocketVersion {
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString(err.Error())
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString(err.Error())
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
if handshake != nil {
|
||||
err = handshake(config, req)
|
||||
if err != nil {
|
||||
code = http.StatusForbidden
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
}
|
||||
err = hs.AcceptHandshake(buf.Writer)
|
||||
if err != nil {
|
||||
code = http.StatusBadRequest
|
||||
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
|
||||
buf.WriteString("\r\n")
|
||||
buf.Flush()
|
||||
return
|
||||
}
|
||||
conn = hs.NewServerConn(buf, rwc, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Server represents a server of a WebSocket.
|
||||
type Server struct {
|
||||
// Config is a WebSocket configuration for new WebSocket connection.
|
||||
Config
|
||||
|
||||
// Handshake is an optional function in WebSocket handshake.
|
||||
// For example, you can check, or don't check Origin header.
|
||||
// Another example, you can select config.Protocol.
|
||||
Handshake func(*Config, *http.Request) error
|
||||
|
||||
// Handler handles a WebSocket connection.
|
||||
Handler
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s.serveWebSocket(w, req)
|
||||
}
|
||||
|
||||
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
|
||||
rwc, buf, err := w.(http.Hijacker).Hijack()
|
||||
if err != nil {
|
||||
panic("Hijack failed: " + err.Error())
|
||||
}
|
||||
// The server should abort the WebSocket connection if it finds
|
||||
// the client did not send a handshake that matches with protocol
|
||||
// specification.
|
||||
defer rwc.Close()
|
||||
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if conn == nil {
|
||||
panic("unexpected nil conn")
|
||||
}
|
||||
s.Handler(conn)
|
||||
}
|
||||
|
||||
// Handler is a simple interface to a WebSocket browser client.
|
||||
// It checks if Origin header is valid URL by default.
|
||||
// You might want to verify websocket.Conn.Config().Origin in the func.
|
||||
// If you use Server instead of Handler, you could call websocket.Origin and
|
||||
// check the origin in your Handshake func. So, if you want to accept
|
||||
// non-browser clients, which do not send an Origin header, set a
|
||||
// Server.Handshake that does not check the origin.
|
||||
type Handler func(*Conn)
|
||||
|
||||
func checkOrigin(config *Config, req *http.Request) (err error) {
|
||||
config.Origin, err = Origin(config, req)
|
||||
if err == nil && config.Origin == nil {
|
||||
return fmt.Errorf("null origin")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ServeHTTP implements the http.Handler interface for a WebSocket
|
||||
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
s := Server{Handler: h, Handshake: checkOrigin}
|
||||
s.serveWebSocket(w, req)
|
||||
}
|
449
vendor/golang.org/x/net/websocket/websocket.go
generated
vendored
Normal file
449
vendor/golang.org/x/net/websocket/websocket.go
generated
vendored
Normal file
|
@ -0,0 +1,449 @@
|
|||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package websocket implements a client and server for the WebSocket protocol
|
||||
// as specified in RFC 6455.
|
||||
//
|
||||
// This package currently lacks some features found in an alternative
|
||||
// and more actively maintained WebSocket package:
|
||||
//
|
||||
// https://pkg.go.dev/nhooyr.io/websocket
|
||||
package websocket // import "golang.org/x/net/websocket"
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolVersionHybi13 = 13
|
||||
ProtocolVersionHybi = ProtocolVersionHybi13
|
||||
SupportedProtocolVersion = "13"
|
||||
|
||||
ContinuationFrame = 0
|
||||
TextFrame = 1
|
||||
BinaryFrame = 2
|
||||
CloseFrame = 8
|
||||
PingFrame = 9
|
||||
PongFrame = 10
|
||||
UnknownFrame = 255
|
||||
|
||||
DefaultMaxPayloadBytes = 32 << 20 // 32MB
|
||||
)
|
||||
|
||||
// ProtocolError represents WebSocket protocol errors.
|
||||
type ProtocolError struct {
|
||||
ErrorString string
|
||||
}
|
||||
|
||||
func (err *ProtocolError) Error() string { return err.ErrorString }
|
||||
|
||||
var (
|
||||
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
|
||||
ErrBadScheme = &ProtocolError{"bad scheme"}
|
||||
ErrBadStatus = &ProtocolError{"bad status"}
|
||||
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
|
||||
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
|
||||
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
|
||||
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
|
||||
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
|
||||
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
|
||||
ErrBadFrame = &ProtocolError{"bad frame"}
|
||||
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
|
||||
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
|
||||
ErrBadRequestMethod = &ProtocolError{"bad method"}
|
||||
ErrNotSupported = &ProtocolError{"not supported"}
|
||||
)
|
||||
|
||||
// ErrFrameTooLarge is returned by Codec's Receive method if payload size
|
||||
// exceeds limit set by Conn.MaxPayloadBytes
|
||||
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
|
||||
|
||||
// Addr is an implementation of net.Addr for WebSocket.
|
||||
type Addr struct {
|
||||
*url.URL
|
||||
}
|
||||
|
||||
// Network returns the network type for a WebSocket, "websocket".
|
||||
func (addr *Addr) Network() string { return "websocket" }
|
||||
|
||||
// Config is a WebSocket configuration
|
||||
type Config struct {
|
||||
// A WebSocket server address.
|
||||
Location *url.URL
|
||||
|
||||
// A Websocket client origin.
|
||||
Origin *url.URL
|
||||
|
||||
// WebSocket subprotocols.
|
||||
Protocol []string
|
||||
|
||||
// WebSocket protocol version.
|
||||
Version int
|
||||
|
||||
// TLS config for secure WebSocket (wss).
|
||||
TlsConfig *tls.Config
|
||||
|
||||
// Additional header fields to be sent in WebSocket opening handshake.
|
||||
Header http.Header
|
||||
|
||||
// Dialer used when opening websocket connections.
|
||||
Dialer *net.Dialer
|
||||
|
||||
handshakeData map[string]string
|
||||
}
|
||||
|
||||
// serverHandshaker is an interface to handle WebSocket server side handshake.
|
||||
type serverHandshaker interface {
|
||||
// ReadHandshake reads handshake request message from client.
|
||||
// Returns http response code and error if any.
|
||||
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
|
||||
|
||||
// AcceptHandshake accepts the client handshake request and sends
|
||||
// handshake response back to client.
|
||||
AcceptHandshake(buf *bufio.Writer) (err error)
|
||||
|
||||
// NewServerConn creates a new WebSocket connection.
|
||||
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
|
||||
}
|
||||
|
||||
// frameReader is an interface to read a WebSocket frame.
|
||||
type frameReader interface {
|
||||
// Reader is to read payload of the frame.
|
||||
io.Reader
|
||||
|
||||
// PayloadType returns payload type.
|
||||
PayloadType() byte
|
||||
|
||||
// HeaderReader returns a reader to read header of the frame.
|
||||
HeaderReader() io.Reader
|
||||
|
||||
// TrailerReader returns a reader to read trailer of the frame.
|
||||
// If it returns nil, there is no trailer in the frame.
|
||||
TrailerReader() io.Reader
|
||||
|
||||
// Len returns total length of the frame, including header and trailer.
|
||||
Len() int
|
||||
}
|
||||
|
||||
// frameReaderFactory is an interface to creates new frame reader.
|
||||
type frameReaderFactory interface {
|
||||
NewFrameReader() (r frameReader, err error)
|
||||
}
|
||||
|
||||
// frameWriter is an interface to write a WebSocket frame.
|
||||
type frameWriter interface {
|
||||
// Writer is to write payload of the frame.
|
||||
io.WriteCloser
|
||||
}
|
||||
|
||||
// frameWriterFactory is an interface to create new frame writer.
|
||||
type frameWriterFactory interface {
|
||||
NewFrameWriter(payloadType byte) (w frameWriter, err error)
|
||||
}
|
||||
|
||||
type frameHandler interface {
|
||||
HandleFrame(frame frameReader) (r frameReader, err error)
|
||||
WriteClose(status int) (err error)
|
||||
}
|
||||
|
||||
// Conn represents a WebSocket connection.
|
||||
//
|
||||
// Multiple goroutines may invoke methods on a Conn simultaneously.
|
||||
type Conn struct {
|
||||
config *Config
|
||||
request *http.Request
|
||||
|
||||
buf *bufio.ReadWriter
|
||||
rwc io.ReadWriteCloser
|
||||
|
||||
rio sync.Mutex
|
||||
frameReaderFactory
|
||||
frameReader
|
||||
|
||||
wio sync.Mutex
|
||||
frameWriterFactory
|
||||
|
||||
frameHandler
|
||||
PayloadType byte
|
||||
defaultCloseStatus int
|
||||
|
||||
// MaxPayloadBytes limits the size of frame payload received over Conn
|
||||
// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
|
||||
MaxPayloadBytes int
|
||||
}
|
||||
|
||||
// Read implements the io.Reader interface:
|
||||
// it reads data of a frame from the WebSocket connection.
|
||||
// if msg is not large enough for the frame data, it fills the msg and next Read
|
||||
// will read the rest of the frame data.
|
||||
// it reads Text frame or Binary frame.
|
||||
func (ws *Conn) Read(msg []byte) (n int, err error) {
|
||||
ws.rio.Lock()
|
||||
defer ws.rio.Unlock()
|
||||
again:
|
||||
if ws.frameReader == nil {
|
||||
frame, err := ws.frameReaderFactory.NewFrameReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if ws.frameReader == nil {
|
||||
goto again
|
||||
}
|
||||
}
|
||||
n, err = ws.frameReader.Read(msg)
|
||||
if err == io.EOF {
|
||||
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
|
||||
io.Copy(ioutil.Discard, trailer)
|
||||
}
|
||||
ws.frameReader = nil
|
||||
goto again
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface:
|
||||
// it writes data as a frame to the WebSocket connection.
|
||||
func (ws *Conn) Write(msg []byte) (n int, err error) {
|
||||
ws.wio.Lock()
|
||||
defer ws.wio.Unlock()
|
||||
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err = w.Write(msg)
|
||||
w.Close()
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close implements the io.Closer interface.
|
||||
func (ws *Conn) Close() error {
|
||||
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
|
||||
err1 := ws.rwc.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err1
|
||||
}
|
||||
|
||||
// IsClientConn reports whether ws is a client-side connection.
|
||||
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
|
||||
|
||||
// IsServerConn reports whether ws is a server-side connection.
|
||||
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
|
||||
|
||||
// LocalAddr returns the WebSocket Origin for the connection for client, or
|
||||
// the WebSocket location for server.
|
||||
func (ws *Conn) LocalAddr() net.Addr {
|
||||
if ws.IsClientConn() {
|
||||
return &Addr{ws.config.Origin}
|
||||
}
|
||||
return &Addr{ws.config.Location}
|
||||
}
|
||||
|
||||
// RemoteAddr returns the WebSocket location for the connection for client, or
|
||||
// the Websocket Origin for server.
|
||||
func (ws *Conn) RemoteAddr() net.Addr {
|
||||
if ws.IsClientConn() {
|
||||
return &Addr{ws.config.Location}
|
||||
}
|
||||
return &Addr{ws.config.Origin}
|
||||
}
|
||||
|
||||
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
|
||||
|
||||
// SetDeadline sets the connection's network read & write deadlines.
|
||||
func (ws *Conn) SetDeadline(t time.Time) error {
|
||||
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||
return conn.SetDeadline(t)
|
||||
}
|
||||
return errSetDeadline
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the connection's network read deadline.
|
||||
func (ws *Conn) SetReadDeadline(t time.Time) error {
|
||||
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||
return conn.SetReadDeadline(t)
|
||||
}
|
||||
return errSetDeadline
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the connection's network write deadline.
|
||||
func (ws *Conn) SetWriteDeadline(t time.Time) error {
|
||||
if conn, ok := ws.rwc.(net.Conn); ok {
|
||||
return conn.SetWriteDeadline(t)
|
||||
}
|
||||
return errSetDeadline
|
||||
}
|
||||
|
||||
// Config returns the WebSocket config.
|
||||
func (ws *Conn) Config() *Config { return ws.config }
|
||||
|
||||
// Request returns the http request upgraded to the WebSocket.
|
||||
// It is nil for client side.
|
||||
func (ws *Conn) Request() *http.Request { return ws.request }
|
||||
|
||||
// Codec represents a symmetric pair of functions that implement a codec.
|
||||
type Codec struct {
|
||||
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
|
||||
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
|
||||
}
|
||||
|
||||
// Send sends v marshaled by cd.Marshal as single frame to ws.
|
||||
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
|
||||
data, payloadType, err := cd.Marshal(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws.wio.Lock()
|
||||
defer ws.wio.Unlock()
|
||||
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(data)
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
|
||||
// in v. The whole frame payload is read to an in-memory buffer; max size of
|
||||
// payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
|
||||
// limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
|
||||
// completely. The next call to Receive would read and discard leftover data of
|
||||
// previous oversized frame before processing next frame.
|
||||
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
|
||||
ws.rio.Lock()
|
||||
defer ws.rio.Unlock()
|
||||
if ws.frameReader != nil {
|
||||
_, err = io.Copy(ioutil.Discard, ws.frameReader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ws.frameReader = nil
|
||||
}
|
||||
again:
|
||||
frame, err := ws.frameReaderFactory.NewFrameReader()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
frame, err = ws.frameHandler.HandleFrame(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frame == nil {
|
||||
goto again
|
||||
}
|
||||
maxPayloadBytes := ws.MaxPayloadBytes
|
||||
if maxPayloadBytes == 0 {
|
||||
maxPayloadBytes = DefaultMaxPayloadBytes
|
||||
}
|
||||
if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
|
||||
// payload size exceeds limit, no need to call Unmarshal
|
||||
//
|
||||
// set frameReader to current oversized frame so that
|
||||
// the next call to this function can drain leftover
|
||||
// data before processing the next frame
|
||||
ws.frameReader = frame
|
||||
return ErrFrameTooLarge
|
||||
}
|
||||
payloadType := frame.PayloadType()
|
||||
data, err := ioutil.ReadAll(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cd.Unmarshal(data, payloadType, v)
|
||||
}
|
||||
|
||||
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
|
||||
switch data := v.(type) {
|
||||
case string:
|
||||
return []byte(data), TextFrame, nil
|
||||
case []byte:
|
||||
return data, BinaryFrame, nil
|
||||
}
|
||||
return nil, UnknownFrame, ErrNotSupported
|
||||
}
|
||||
|
||||
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
|
||||
switch data := v.(type) {
|
||||
case *string:
|
||||
*data = string(msg)
|
||||
return nil
|
||||
case *[]byte:
|
||||
*data = msg
|
||||
return nil
|
||||
}
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
/*
|
||||
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
|
||||
To send/receive text frame, use string type.
|
||||
To send/receive binary frame, use []byte type.
|
||||
|
||||
Trivial usage:
|
||||
|
||||
import "websocket"
|
||||
|
||||
// receive text frame
|
||||
var message string
|
||||
websocket.Message.Receive(ws, &message)
|
||||
|
||||
// send text frame
|
||||
message = "hello"
|
||||
websocket.Message.Send(ws, message)
|
||||
|
||||
// receive binary frame
|
||||
var data []byte
|
||||
websocket.Message.Receive(ws, &data)
|
||||
|
||||
// send binary frame
|
||||
data = []byte{0, 1, 2}
|
||||
websocket.Message.Send(ws, data)
|
||||
*/
|
||||
var Message = Codec{marshal, unmarshal}
|
||||
|
||||
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
|
||||
msg, err = json.Marshal(v)
|
||||
return msg, TextFrame, err
|
||||
}
|
||||
|
||||
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
|
||||
return json.Unmarshal(msg, v)
|
||||
}
|
||||
|
||||
/*
|
||||
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
|
||||
|
||||
Trivial usage:
|
||||
|
||||
import "websocket"
|
||||
|
||||
type T struct {
|
||||
Msg string
|
||||
Count int
|
||||
}
|
||||
|
||||
// receive JSON type T
|
||||
var data T
|
||||
websocket.JSON.Receive(ws, &data)
|
||||
|
||||
// send JSON type T
|
||||
websocket.JSON.Send(ws, data)
|
||||
*/
|
||||
var JSON = Codec{jsonMarshal, jsonUnmarshal}
|
1
vendor/modules.txt
vendored
1
vendor/modules.txt
vendored
|
@ -71,6 +71,7 @@ golang.org/x/mod/semver
|
|||
golang.org/x/net/html
|
||||
golang.org/x/net/html/atom
|
||||
golang.org/x/net/idna
|
||||
golang.org/x/net/websocket
|
||||
# golang.org/x/sys v0.7.0
|
||||
## explicit; go 1.17
|
||||
golang.org/x/sys/cpu
|
||||
|
|
Loading…
Reference in a new issue