Merge pull request #306 from mholt/bug/websocket-races

fixed data races in websockets
This commit is contained in:
Matt Holt 2015-11-05 17:19:35 -07:00
commit a1481bc29e

View file

@ -4,9 +4,12 @@
package websocket package websocket
import ( import (
"bufio"
"bytes"
"io" "io"
"net" "net"
"net/http" "net/http"
"os"
"os/exec" "os/exec"
"strings" "strings"
"time" "time"
@ -88,15 +91,18 @@ func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error
defer conn.Close() defer conn.Close()
cmd := exec.Command(config.Command, config.Arguments...) cmd := exec.Command(config.Command, config.Arguments...)
stdout, err := cmd.StdoutPipe() stdout, err := cmd.StdoutPipe()
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
defer stdout.Close()
stdin, err := cmd.StdinPipe() stdin, err := cmd.StdinPipe()
if err != nil { if err != nil {
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
defer stdin.Close()
metavars, err := buildEnv(cmd.Path, r) metavars, err := buildEnv(cmd.Path, r)
if err != nil { if err != nil {
@ -109,7 +115,31 @@ func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error
return http.StatusBadGateway, err return http.StatusBadGateway, err
} }
reader(conn, stdout, stdin) done := make(chan struct{})
go pumpStdout(conn, stdout, done)
pumpStdin(conn, stdin)
stdin.Close() // close stdin to end the process
if err := cmd.Process.Signal(os.Interrupt); err != nil { // signal an interrupt to kill the process
return http.StatusInternalServerError, err
}
select {
case <-done:
case <-time.After(time.Second):
// terminate with extreme prejudice.
if err := cmd.Process.Signal(os.Kill); err != nil {
return http.StatusInternalServerError, err
}
<-done
}
// not sure what we want to do here.
// status for an "exited" process is greater
// than 0, but isn't really an error per se.
// just going to ignore it for now.
cmd.Wait()
return 0, nil return 0, nil
} }
@ -163,63 +193,60 @@ func buildEnv(cmdPath string, r *http.Request) (metavars []string, err error) {
return return
} }
// reader is the guts of this package. It takes the stdin and stdout pipes // pumpStdin handles reading data from the websocket connection and writing
// of the cmd we created in ServeWS and pipes them between the client and server // it to stdin of the process.
// over websockets. func pumpStdin(conn *websocket.Conn, stdin io.WriteCloser) {
func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) {
// Setup our connection's websocket ping/pong handlers from our const values. // Setup our connection's websocket ping/pong handlers from our const values.
defer conn.Close()
conn.SetReadLimit(maxMessageSize) conn.SetReadLimit(maxMessageSize)
conn.SetReadDeadline(time.Now().Add(pongWait)) conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
tickerChan := make(chan bool)
defer close(tickerChan) // make sure to close the ticker when we are done.
go ticker(conn, tickerChan)
for { for {
msgType, r, err := conn.NextReader() _, message, err := conn.ReadMessage()
if err != nil { if err != nil {
if msgType == -1 { break
return // we got a disconnect from the client. We are good to close.
}
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
} }
message = append(message, '\n')
w, err := conn.NextWriter(msgType) if _, err := stdin.Write(message); err != nil {
if err != nil { break
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
} }
if _, err := io.Copy(stdin, r); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
}
go func() {
if _, err := io.Copy(w, stdout); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
}
if err := w.Close(); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
}
}()
} }
} }
// ticker is start by the reader. Basically it is the method that simulates the websocket // pumpStdout handles reading data from stdout of the process and writing
// between the server and client to keep it alive with ping messages. // it to websocket connection.
func ticker(conn *websocket.Conn, c chan bool) { func pumpStdout(conn *websocket.Conn, stdout io.Reader, done chan struct{}) {
go pinger(conn, done)
defer func() {
conn.Close()
close(done) // make sure to close the pinger when we are done.
}()
s := bufio.NewScanner(stdout)
for s.Scan() {
conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := conn.WriteMessage(websocket.TextMessage, bytes.TrimSpace(s.Bytes())); err != nil {
break
}
}
if s.Err() != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, s.Err().Error()), time.Time{})
}
}
// pinger simulates the websocket to keep it alive with ping messages.
func pinger(conn *websocket.Conn, done chan struct{}) {
ticker := time.NewTicker(pingPeriod) ticker := time.NewTicker(pingPeriod)
defer ticker.Stop() defer ticker.Stop()
for { // blocking loop with select to wait for stimulation. for { // blocking loop with select to wait for stimulation.
select { select {
case <-ticker.C: case <-ticker.C:
conn.WriteMessage(websocket.PingMessage, nil) if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
case <-c: conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, err.Error()), time.Time{})
return
}
case <-done:
return // clean up this routine. return // clean up this routine.
} }
} }