fixed data races in websockets

This commit is contained in:
Austin 2015-11-02 14:19:38 -08:00
parent b6078eded1
commit abc7c6a148

View file

@ -4,9 +4,11 @@
package websocket
import (
"bufio"
"io"
"net"
"net/http"
"os"
"os/exec"
"strings"
"time"
@ -88,15 +90,18 @@ func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error
defer conn.Close()
cmd := exec.Command(config.Command, config.Arguments...)
stdout, err := cmd.StdoutPipe()
if err != nil {
return http.StatusBadGateway, err
}
defer stdout.Close()
stdin, err := cmd.StdinPipe()
if err != nil {
return http.StatusBadGateway, err
}
defer stdin.Close()
metavars, err := buildEnv(cmd.Path, r)
if err != nil {
@ -109,7 +114,31 @@ func serveWS(w http.ResponseWriter, r *http.Request, config *Config) (int, error
return http.StatusBadGateway, err
}
reader(conn, stdout, stdin)
done := make(chan struct{})
go pumpStdout(conn, stdout, done)
pumpStdin(conn, stdin)
stdin.Close() // close stdin to end the process
if err := cmd.Process.Signal(os.Interrupt); err != nil { // signal an interrupt to kill the process
return http.StatusInternalServerError, err
}
select {
case <-done:
case <-time.After(time.Second):
// terminate with extreme prejudice.
if err := cmd.Process.Signal(os.Kill); err != nil {
return http.StatusInternalServerError, err
}
<-done
}
// not sure what we want to do here.
// status for an "exited" process is greater
// than 0, but isn't really an error per se.
// just going to ignore it for now.
cmd.Wait()
return 0, nil
}
@ -163,63 +192,60 @@ func buildEnv(cmdPath string, r *http.Request) (metavars []string, err error) {
return
}
// reader is the guts of this package. It takes the stdin and stdout pipes
// of the cmd we created in ServeWS and pipes them between the client and server
// over websockets.
func reader(conn *websocket.Conn, stdout io.ReadCloser, stdin io.WriteCloser) {
// pumpStdin handles reading data from the websocket connection and writing
// it to stdin of the process.
func pumpStdin(conn *websocket.Conn, stdin io.WriteCloser) {
// Setup our connection's websocket ping/pong handlers from our const values.
defer conn.Close()
conn.SetReadLimit(maxMessageSize)
conn.SetReadDeadline(time.Now().Add(pongWait))
conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
tickerChan := make(chan bool)
defer close(tickerChan) // make sure to close the ticker when we are done.
go ticker(conn, tickerChan)
for {
msgType, r, err := conn.NextReader()
_, message, err := conn.ReadMessage()
if err != nil {
if msgType == -1 {
return // we got a disconnect from the client. We are good to close.
}
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
break
}
w, err := conn.NextWriter(msgType)
if err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
message = append(message, '\n')
if _, err := stdin.Write(message); err != nil {
break
}
if _, err := io.Copy(stdin, r); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
}
go func() {
if _, err := io.Copy(w, stdout); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
}
if err := w.Close(); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""), time.Time{})
return
}
}()
}
}
// ticker is start by the reader. Basically it is the method that simulates the websocket
// between the server and client to keep it alive with ping messages.
func ticker(conn *websocket.Conn, c chan bool) {
// pumpStdout handles reading data from stdout of the process and writing
// it to websocket connection.
func pumpStdout(conn *websocket.Conn, stdout io.Reader, done chan struct{}) {
go pinger(conn, done)
defer func() {
conn.Close()
close(done) // make sure to close the pinger when we are done.
}()
s := bufio.NewScanner(stdout)
for s.Scan() {
conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := conn.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil {
break
}
}
if s.Err() != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, s.Err().Error()), time.Time{})
}
}
// pinger simulates the websocket to keep it alive with ping messages.
func pinger(conn *websocket.Conn, done chan struct{}) {
ticker := time.NewTicker(pingPeriod)
defer ticker.Stop()
for { // blocking loop with select to wait for stimulation.
select {
case <-ticker.C:
conn.WriteMessage(websocket.PingMessage, nil)
case <-c:
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil {
conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, err.Error()), time.Time{})
return
}
case <-done:
return // clean up this routine.
}
}