mirror of
https://github.com/mjl-/mox.git
synced 2025-01-16 02:16:26 +03:00
259928ab62
if we recognize that a request for a WebForward is trying to turn the connection into a websocket, we forward it to the backend and check if the backend understands the websocket request. if so, we pass back the upgrade response and get out of the way, copying bytes between the two. we do log the total amount of bytes read from the client and written to the client. if the backend doesn't respond with a websocke response, or an invalid one, we respond with a regular non-websocket response. and we log details about the failed connection, should help with debugging and any bug reports. we don't try to parse the websocket framing, that's between the client and the backend. we could try to parse it, in part to protect the backend from bad frames, but it would be a lot of work and could be brittle in the face of extensions. this doesn't yet handle websocket connections when a http proxy is configured. we'll implement it when someone needs it. we do recognize it and fail the connection. for issue #25
583 lines
16 KiB
Go
583 lines
16 KiB
Go
// Copyright 2011 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package websocket
|
|
|
|
// This file implements a protocol of hybi draft.
|
|
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"encoding/base64"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
)
|
|
|
|
const (
|
|
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
|
|
|
closeStatusNormal = 1000
|
|
closeStatusGoingAway = 1001
|
|
closeStatusProtocolError = 1002
|
|
closeStatusUnsupportedData = 1003
|
|
closeStatusFrameTooLarge = 1004
|
|
closeStatusNoStatusRcvd = 1005
|
|
closeStatusAbnormalClosure = 1006
|
|
closeStatusBadMessageData = 1007
|
|
closeStatusPolicyViolation = 1008
|
|
closeStatusTooBigData = 1009
|
|
closeStatusExtensionMismatch = 1010
|
|
|
|
maxControlFramePayloadLength = 125
|
|
)
|
|
|
|
var (
|
|
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
|
|
ErrBadPongMessage = &ProtocolError{"bad pong message"}
|
|
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
|
|
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
|
|
ErrNotImplemented = &ProtocolError{"not implemented"}
|
|
|
|
handshakeHeader = map[string]bool{
|
|
"Host": true,
|
|
"Upgrade": true,
|
|
"Connection": true,
|
|
"Sec-Websocket-Key": true,
|
|
"Sec-Websocket-Origin": true,
|
|
"Sec-Websocket-Version": true,
|
|
"Sec-Websocket-Protocol": true,
|
|
"Sec-Websocket-Accept": true,
|
|
}
|
|
)
|
|
|
|
// A hybiFrameHeader is a frame header as defined in hybi draft.
|
|
type hybiFrameHeader struct {
|
|
Fin bool
|
|
Rsv [3]bool
|
|
OpCode byte
|
|
Length int64
|
|
MaskingKey []byte
|
|
|
|
data *bytes.Buffer
|
|
}
|
|
|
|
// A hybiFrameReader is a reader for hybi frame.
|
|
type hybiFrameReader struct {
|
|
reader io.Reader
|
|
|
|
header hybiFrameHeader
|
|
pos int64
|
|
length int
|
|
}
|
|
|
|
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
|
|
n, err = frame.reader.Read(msg)
|
|
if frame.header.MaskingKey != nil {
|
|
for i := 0; i < n; i++ {
|
|
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
|
|
frame.pos++
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
|
|
|
|
func (frame *hybiFrameReader) HeaderReader() io.Reader {
|
|
if frame.header.data == nil {
|
|
return nil
|
|
}
|
|
if frame.header.data.Len() == 0 {
|
|
return nil
|
|
}
|
|
return frame.header.data
|
|
}
|
|
|
|
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
|
|
|
|
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
|
|
|
|
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
|
|
type hybiFrameReaderFactory struct {
|
|
*bufio.Reader
|
|
}
|
|
|
|
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
|
|
// See Section 5.2 Base Framing protocol for detail.
|
|
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
|
|
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
|
|
hybiFrame := new(hybiFrameReader)
|
|
frame = hybiFrame
|
|
var header []byte
|
|
var b byte
|
|
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
|
|
b, err = buf.ReadByte()
|
|
if err != nil {
|
|
return
|
|
}
|
|
header = append(header, b)
|
|
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
|
|
for i := 0; i < 3; i++ {
|
|
j := uint(6 - i)
|
|
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
|
|
}
|
|
hybiFrame.header.OpCode = header[0] & 0x0f
|
|
|
|
// Second byte. Mask/Payload len(7bits)
|
|
b, err = buf.ReadByte()
|
|
if err != nil {
|
|
return
|
|
}
|
|
header = append(header, b)
|
|
mask := (b & 0x80) != 0
|
|
b &= 0x7f
|
|
lengthFields := 0
|
|
switch {
|
|
case b <= 125: // Payload length 7bits.
|
|
hybiFrame.header.Length = int64(b)
|
|
case b == 126: // Payload length 7+16bits
|
|
lengthFields = 2
|
|
case b == 127: // Payload length 7+64bits
|
|
lengthFields = 8
|
|
}
|
|
for i := 0; i < lengthFields; i++ {
|
|
b, err = buf.ReadByte()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
|
|
b &= 0x7f
|
|
}
|
|
header = append(header, b)
|
|
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
|
|
}
|
|
if mask {
|
|
// Masking key. 4 bytes.
|
|
for i := 0; i < 4; i++ {
|
|
b, err = buf.ReadByte()
|
|
if err != nil {
|
|
return
|
|
}
|
|
header = append(header, b)
|
|
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
|
|
}
|
|
}
|
|
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
|
|
hybiFrame.header.data = bytes.NewBuffer(header)
|
|
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
|
|
return
|
|
}
|
|
|
|
// A HybiFrameWriter is a writer for hybi frame.
|
|
type hybiFrameWriter struct {
|
|
writer *bufio.Writer
|
|
|
|
header *hybiFrameHeader
|
|
}
|
|
|
|
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
|
|
var header []byte
|
|
var b byte
|
|
if frame.header.Fin {
|
|
b |= 0x80
|
|
}
|
|
for i := 0; i < 3; i++ {
|
|
if frame.header.Rsv[i] {
|
|
j := uint(6 - i)
|
|
b |= 1 << j
|
|
}
|
|
}
|
|
b |= frame.header.OpCode
|
|
header = append(header, b)
|
|
if frame.header.MaskingKey != nil {
|
|
b = 0x80
|
|
} else {
|
|
b = 0
|
|
}
|
|
lengthFields := 0
|
|
length := len(msg)
|
|
switch {
|
|
case length <= 125:
|
|
b |= byte(length)
|
|
case length < 65536:
|
|
b |= 126
|
|
lengthFields = 2
|
|
default:
|
|
b |= 127
|
|
lengthFields = 8
|
|
}
|
|
header = append(header, b)
|
|
for i := 0; i < lengthFields; i++ {
|
|
j := uint((lengthFields - i - 1) * 8)
|
|
b = byte((length >> j) & 0xff)
|
|
header = append(header, b)
|
|
}
|
|
if frame.header.MaskingKey != nil {
|
|
if len(frame.header.MaskingKey) != 4 {
|
|
return 0, ErrBadMaskingKey
|
|
}
|
|
header = append(header, frame.header.MaskingKey...)
|
|
frame.writer.Write(header)
|
|
data := make([]byte, length)
|
|
for i := range data {
|
|
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
|
|
}
|
|
frame.writer.Write(data)
|
|
err = frame.writer.Flush()
|
|
return length, err
|
|
}
|
|
frame.writer.Write(header)
|
|
frame.writer.Write(msg)
|
|
err = frame.writer.Flush()
|
|
return length, err
|
|
}
|
|
|
|
func (frame *hybiFrameWriter) Close() error { return nil }
|
|
|
|
type hybiFrameWriterFactory struct {
|
|
*bufio.Writer
|
|
needMaskingKey bool
|
|
}
|
|
|
|
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
|
|
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
|
|
if buf.needMaskingKey {
|
|
frameHeader.MaskingKey, err = generateMaskingKey()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
|
|
}
|
|
|
|
type hybiFrameHandler struct {
|
|
conn *Conn
|
|
payloadType byte
|
|
}
|
|
|
|
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
|
|
if handler.conn.IsServerConn() {
|
|
// The client MUST mask all frames sent to the server.
|
|
if frame.(*hybiFrameReader).header.MaskingKey == nil {
|
|
handler.WriteClose(closeStatusProtocolError)
|
|
return nil, io.EOF
|
|
}
|
|
} else {
|
|
// The server MUST NOT mask all frames.
|
|
if frame.(*hybiFrameReader).header.MaskingKey != nil {
|
|
handler.WriteClose(closeStatusProtocolError)
|
|
return nil, io.EOF
|
|
}
|
|
}
|
|
if header := frame.HeaderReader(); header != nil {
|
|
io.Copy(ioutil.Discard, header)
|
|
}
|
|
switch frame.PayloadType() {
|
|
case ContinuationFrame:
|
|
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
|
|
case TextFrame, BinaryFrame:
|
|
handler.payloadType = frame.PayloadType()
|
|
case CloseFrame:
|
|
return nil, io.EOF
|
|
case PingFrame, PongFrame:
|
|
b := make([]byte, maxControlFramePayloadLength)
|
|
n, err := io.ReadFull(frame, b)
|
|
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
|
|
return nil, err
|
|
}
|
|
io.Copy(ioutil.Discard, frame)
|
|
if frame.PayloadType() == PingFrame {
|
|
if _, err := handler.WritePong(b[:n]); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return nil, nil
|
|
}
|
|
return frame, nil
|
|
}
|
|
|
|
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
|
|
handler.conn.wio.Lock()
|
|
defer handler.conn.wio.Unlock()
|
|
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
msg := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(msg, uint16(status))
|
|
_, err = w.Write(msg)
|
|
w.Close()
|
|
return err
|
|
}
|
|
|
|
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
|
|
handler.conn.wio.Lock()
|
|
defer handler.conn.wio.Unlock()
|
|
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
n, err = w.Write(msg)
|
|
w.Close()
|
|
return n, err
|
|
}
|
|
|
|
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
|
|
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
|
if buf == nil {
|
|
br := bufio.NewReader(rwc)
|
|
bw := bufio.NewWriter(rwc)
|
|
buf = bufio.NewReadWriter(br, bw)
|
|
}
|
|
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
|
|
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
|
|
frameWriterFactory: hybiFrameWriterFactory{
|
|
buf.Writer, request == nil},
|
|
PayloadType: TextFrame,
|
|
defaultCloseStatus: closeStatusNormal}
|
|
ws.frameHandler = &hybiFrameHandler{conn: ws}
|
|
return ws
|
|
}
|
|
|
|
// generateMaskingKey generates a masking key for a frame.
|
|
func generateMaskingKey() (maskingKey []byte, err error) {
|
|
maskingKey = make([]byte, 4)
|
|
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
// generateNonce generates a nonce consisting of a randomly selected 16-byte
|
|
// value that has been base64-encoded.
|
|
func generateNonce() (nonce []byte) {
|
|
key := make([]byte, 16)
|
|
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
|
panic(err)
|
|
}
|
|
nonce = make([]byte, 24)
|
|
base64.StdEncoding.Encode(nonce, key)
|
|
return
|
|
}
|
|
|
|
// removeZone removes IPv6 zone identifier from host.
|
|
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080"
|
|
func removeZone(host string) string {
|
|
if !strings.HasPrefix(host, "[") {
|
|
return host
|
|
}
|
|
i := strings.LastIndex(host, "]")
|
|
if i < 0 {
|
|
return host
|
|
}
|
|
j := strings.LastIndex(host[:i], "%")
|
|
if j < 0 {
|
|
return host
|
|
}
|
|
return host[:j] + host[i:]
|
|
}
|
|
|
|
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
|
|
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
|
|
func getNonceAccept(nonce []byte) (expected []byte, err error) {
|
|
h := sha1.New()
|
|
if _, err = h.Write(nonce); err != nil {
|
|
return
|
|
}
|
|
if _, err = h.Write([]byte(websocketGUID)); err != nil {
|
|
return
|
|
}
|
|
expected = make([]byte, 28)
|
|
base64.StdEncoding.Encode(expected, h.Sum(nil))
|
|
return
|
|
}
|
|
|
|
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
|
|
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
|
|
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
|
|
|
|
// According to RFC 6874, an HTTP client, proxy, or other
|
|
// intermediary must remove any IPv6 zone identifier attached
|
|
// to an outgoing URI.
|
|
bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
|
|
bw.WriteString("Upgrade: websocket\r\n")
|
|
bw.WriteString("Connection: Upgrade\r\n")
|
|
nonce := generateNonce()
|
|
if config.handshakeData != nil {
|
|
nonce = []byte(config.handshakeData["key"])
|
|
}
|
|
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
|
|
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
|
|
|
|
if config.Version != ProtocolVersionHybi13 {
|
|
return ErrBadProtocolVersion
|
|
}
|
|
|
|
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
|
|
if len(config.Protocol) > 0 {
|
|
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
|
|
}
|
|
// TODO(ukai): send Sec-WebSocket-Extensions.
|
|
err = config.Header.WriteSubset(bw, handshakeHeader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
bw.WriteString("\r\n")
|
|
if err = bw.Flush(); err != nil {
|
|
return err
|
|
}
|
|
|
|
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.StatusCode != 101 {
|
|
return ErrBadStatus
|
|
}
|
|
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
|
|
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
|
|
return ErrBadUpgrade
|
|
}
|
|
expectedAccept, err := getNonceAccept(nonce)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
|
|
return ErrChallengeResponse
|
|
}
|
|
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
|
|
return ErrUnsupportedExtensions
|
|
}
|
|
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
|
|
if offeredProtocol != "" {
|
|
protocolMatched := false
|
|
for i := 0; i < len(config.Protocol); i++ {
|
|
if config.Protocol[i] == offeredProtocol {
|
|
protocolMatched = true
|
|
break
|
|
}
|
|
}
|
|
if !protocolMatched {
|
|
return ErrBadWebSocketProtocol
|
|
}
|
|
config.Protocol = []string{offeredProtocol}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// newHybiClientConn creates a client WebSocket connection after handshake.
|
|
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
|
|
return newHybiConn(config, buf, rwc, nil)
|
|
}
|
|
|
|
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
|
|
type hybiServerHandshaker struct {
|
|
*Config
|
|
accept []byte
|
|
}
|
|
|
|
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
|
|
c.Version = ProtocolVersionHybi13
|
|
if req.Method != "GET" {
|
|
return http.StatusMethodNotAllowed, ErrBadRequestMethod
|
|
}
|
|
// HTTP version can be safely ignored.
|
|
|
|
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
|
|
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
|
|
return http.StatusBadRequest, ErrNotWebSocket
|
|
}
|
|
|
|
key := req.Header.Get("Sec-Websocket-Key")
|
|
if key == "" {
|
|
return http.StatusBadRequest, ErrChallengeResponse
|
|
}
|
|
version := req.Header.Get("Sec-Websocket-Version")
|
|
switch version {
|
|
case "13":
|
|
c.Version = ProtocolVersionHybi13
|
|
default:
|
|
return http.StatusBadRequest, ErrBadWebSocketVersion
|
|
}
|
|
var scheme string
|
|
if req.TLS != nil {
|
|
scheme = "wss"
|
|
} else {
|
|
scheme = "ws"
|
|
}
|
|
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
|
|
if err != nil {
|
|
return http.StatusBadRequest, err
|
|
}
|
|
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
|
|
if protocol != "" {
|
|
protocols := strings.Split(protocol, ",")
|
|
for i := 0; i < len(protocols); i++ {
|
|
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
|
|
}
|
|
}
|
|
c.accept, err = getNonceAccept([]byte(key))
|
|
if err != nil {
|
|
return http.StatusInternalServerError, err
|
|
}
|
|
return http.StatusSwitchingProtocols, nil
|
|
}
|
|
|
|
// Origin parses the Origin header in req.
|
|
// If the Origin header is not set, it returns nil and nil.
|
|
func Origin(config *Config, req *http.Request) (*url.URL, error) {
|
|
var origin string
|
|
switch config.Version {
|
|
case ProtocolVersionHybi13:
|
|
origin = req.Header.Get("Origin")
|
|
}
|
|
if origin == "" {
|
|
return nil, nil
|
|
}
|
|
return url.ParseRequestURI(origin)
|
|
}
|
|
|
|
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
|
|
if len(c.Protocol) > 0 {
|
|
if len(c.Protocol) != 1 {
|
|
// You need choose a Protocol in Handshake func in Server.
|
|
return ErrBadWebSocketProtocol
|
|
}
|
|
}
|
|
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
|
|
buf.WriteString("Upgrade: websocket\r\n")
|
|
buf.WriteString("Connection: Upgrade\r\n")
|
|
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
|
|
if len(c.Protocol) > 0 {
|
|
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
|
|
}
|
|
// TODO(ukai): send Sec-WebSocket-Extensions.
|
|
if c.Header != nil {
|
|
err := c.Header.WriteSubset(buf, handshakeHeader)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
buf.WriteString("\r\n")
|
|
return buf.Flush()
|
|
}
|
|
|
|
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
|
return newHybiServerConn(c.Config, buf, rwc, request)
|
|
}
|
|
|
|
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
|
|
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
|
|
return newHybiConn(config, buf, rwc, request)
|
|
}
|