Skip to content

Commit f6dda0a

Browse files
committed
net/http/httputil: TestReverseProxyWebSocketCancelation
1 parent eb88375 commit f6dda0a

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

src/net/http/httputil/reverseproxy_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,130 @@ func TestReverseProxyWebSocket(t *testing.T) {
11581158
}
11591159
}
11601160

1161+
func TestReverseProxyWebSocketCancelation(t *testing.T) {
1162+
n := 5
1163+
triggerCancelCh := make(chan interface{}, n)
1164+
progressCh := make(chan interface{})
1165+
nthResponse := func(i int) string {
1166+
return fmt.Sprintf("backend response #%d\n", i)
1167+
}
1168+
terminalMsg := "final message"
1169+
1170+
cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1171+
if g, ws := upgradeType(r.Header), "websocket"; g != ws {
1172+
t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
1173+
http.Error(w, "Unexpected request", 400)
1174+
return
1175+
}
1176+
conn, bufrw, err := w.(http.Hijacker).Hijack()
1177+
if err != nil {
1178+
t.Error(err)
1179+
return
1180+
}
1181+
defer conn.Close()
1182+
1183+
upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
1184+
if _, err := io.WriteString(conn, upgradeMsg); err != nil {
1185+
t.Error(err)
1186+
return
1187+
}
1188+
if _, _, err := bufrw.ReadLine(); err != nil {
1189+
t.Errorf("Failed to read line from client: %v", err)
1190+
return
1191+
}
1192+
1193+
for i := 0; i < n; i++ {
1194+
if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
1195+
t.Errorf("Writing response #%d failed: %v", i, err)
1196+
}
1197+
bufrw.Flush()
1198+
progressCh <- true
1199+
}
1200+
if _, err := bufrw.WriteString(terminalMsg); err != nil {
1201+
t.Errorf("Failed to write terminal message: %v", err)
1202+
}
1203+
bufrw.Flush()
1204+
}))
1205+
defer cst.Close()
1206+
1207+
backendURL, _ := url.Parse(cst.URL)
1208+
rproxy := NewSingleHostReverseProxy(backendURL)
1209+
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
1210+
rproxy.ModifyResponse = func(res *http.Response) error {
1211+
res.Header.Add("X-Modified", "true")
1212+
return nil
1213+
}
1214+
1215+
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
1216+
rw.Header().Set("X-Header", "X-Value")
1217+
ctx, cancel := context.WithCancel(req.Context())
1218+
go func() {
1219+
<-triggerCancelCh
1220+
cancel()
1221+
}()
1222+
rproxy.ServeHTTP(rw, req.WithContext(ctx))
1223+
})
1224+
1225+
frontendProxy := httptest.NewServer(handler)
1226+
defer frontendProxy.Close()
1227+
1228+
req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
1229+
req.Header.Set("Connection", "Upgrade")
1230+
req.Header.Set("Upgrade", "websocket")
1231+
1232+
res, err := frontendProxy.Client().Do(req)
1233+
if err != nil {
1234+
t.Fatalf("Dialing to frontend proxy: %v", err)
1235+
}
1236+
defer res.Body.Close()
1237+
if g, w := res.StatusCode, 101; g != w {
1238+
t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
1239+
}
1240+
1241+
if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
1242+
t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
1243+
}
1244+
1245+
if g, w := upgradeType(res.Header), "websocket"; g != w {
1246+
t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
1247+
}
1248+
1249+
rwc, ok := res.Body.(io.ReadWriteCloser)
1250+
if !ok {
1251+
t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
1252+
}
1253+
1254+
if got, want := res.Header.Get("X-Modified"), "true"; got != want {
1255+
t.Errorf("response X-Modified header = %q; want %q", got, want)
1256+
}
1257+
1258+
if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
1259+
t.Fatalf("Failed to write first message: %v", err)
1260+
}
1261+
1262+
// Read loop.
1263+
1264+
br := bufio.NewReader(rwc)
1265+
for {
1266+
line, err := br.ReadString('\n')
1267+
switch {
1268+
case line == terminalMsg: // this case before "err == io.EOF"
1269+
t.Fatalf("The websocket request was not canceled, unfortunately!")
1270+
1271+
case err == io.EOF:
1272+
return
1273+
1274+
case err != nil:
1275+
t.Fatalf("Unexpected error: %v", err)
1276+
1277+
case line == nthResponse(0): // We've gotten the first response back
1278+
// Let's trigger a cancel.
1279+
close(triggerCancelCh)
1280+
}
1281+
<-progressCh
1282+
}
1283+
}
1284+
11611285
func TestUnannouncedTrailer(t *testing.T) {
11621286
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
11631287
w.WriteHeader(http.StatusOK)

0 commit comments

Comments
 (0)