Skip to content

net/http/httputil: make Switching Protocol requests (e.g. Websockets) cancelable #38021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/net/http/httputil/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,20 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
return
}
defer backConn.Close()

backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancelation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
}
backConn.Close()
}()

defer close(backConnCloseCh)

conn, brw, err := hj.Hijack()
if err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
Expand Down
122 changes: 122 additions & 0 deletions src/net/http/httputil/reverseproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,128 @@ func TestReverseProxyWebSocket(t *testing.T) {
}
}

func TestReverseProxyWebSocketCancelation(t *testing.T) {
n := 5
triggerCancelCh := make(chan bool, n)
nthResponse := func(i int) string {
return fmt.Sprintf("backend response #%d\n", i)
}
terminalMsg := "final message"

cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if g, ws := upgradeType(r.Header), "websocket"; g != ws {
t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
http.Error(w, "Unexpected request", 400)
return
}
conn, bufrw, err := w.(http.Hijacker).Hijack()
if err != nil {
t.Error(err)
return
}
defer conn.Close()

upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
if _, err := io.WriteString(conn, upgradeMsg); err != nil {
t.Error(err)
return
}
if _, _, err := bufrw.ReadLine(); err != nil {
t.Errorf("Failed to read line from client: %v", err)
return
}

for i := 0; i < n; i++ {
if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
t.Errorf("Writing response #%d failed: %v", i, err)
}
bufrw.Flush()
time.Sleep(time.Second)
}
if _, err := bufrw.WriteString(terminalMsg); err != nil {
t.Errorf("Failed to write terminal message: %v", err)
}
bufrw.Flush()
}))
defer cst.Close()

backendURL, _ := url.Parse(cst.URL)
rproxy := NewSingleHostReverseProxy(backendURL)
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
rproxy.ModifyResponse = func(res *http.Response) error {
res.Header.Add("X-Modified", "true")
return nil
}

handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("X-Header", "X-Value")
ctx, cancel := context.WithCancel(req.Context())
go func() {
<-triggerCancelCh
cancel()
}()
rproxy.ServeHTTP(rw, req.WithContext(ctx))
})

frontendProxy := httptest.NewServer(handler)
defer frontendProxy.Close()

req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")

res, err := frontendProxy.Client().Do(req)
if err != nil {
t.Fatalf("Dialing to frontend proxy: %v", err)
}
defer res.Body.Close()
if g, w := res.StatusCode, 101; g != w {
t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
}

if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
}

if g, w := upgradeType(res.Header), "websocket"; g != w {
t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
}

rwc, ok := res.Body.(io.ReadWriteCloser)
if !ok {
t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
}

if got, want := res.Header.Get("X-Modified"), "true"; got != want {
t.Errorf("response X-Modified header = %q; want %q", got, want)
}

if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
t.Fatalf("Failed to write first message: %v", err)
}

// Read loop.

br := bufio.NewReader(rwc)
for {
line, err := br.ReadString('\n')
switch {
case line == terminalMsg: // this case before "err == io.EOF"
t.Fatalf("The websocket request was not canceled, unfortunately!")

case err == io.EOF:
return

case err != nil:
t.Fatalf("Unexpected error: %v", err)

case line == nthResponse(0): // We've gotten the first response back
// Let's trigger a cancel.
close(triggerCancelCh)
}
}
}

func TestUnannouncedTrailer(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
Expand Down