@@ -1158,6 +1158,130 @@ func TestReverseProxyWebSocket(t *testing.T) {
1158
1158
}
1159
1159
}
1160
1160
1161
+ func TestReverseProxyWebSocketCancelation (t * testing.T ) {
1162
+ n := 5
1163
+ triggerCancelCh := make (chan interface {}, n )
1164
+ progressCh := make (chan interface {})
1165
+ nthResponse := func (i int ) string {
1166
+ return fmt .Sprintf ("backend response #%d\n " , i )
1167
+ }
1168
+ terminalMsg := "final message"
1169
+
1170
+ cst := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1171
+ if g , ws := upgradeType (r .Header ), "websocket" ; g != ws {
1172
+ t .Errorf ("Unexpected upgrade type %q, want %q" , g , ws )
1173
+ http .Error (w , "Unexpected request" , 400 )
1174
+ return
1175
+ }
1176
+ conn , bufrw , err := w .(http.Hijacker ).Hijack ()
1177
+ if err != nil {
1178
+ t .Error (err )
1179
+ return
1180
+ }
1181
+ defer conn .Close ()
1182
+
1183
+ upgradeMsg := "HTTP/1.1 101 Switching Protocols\r \n Connection: upgrade\r \n Upgrade: WebSocket\r \n \r \n "
1184
+ if _ , err := io .WriteString (conn , upgradeMsg ); err != nil {
1185
+ t .Error (err )
1186
+ return
1187
+ }
1188
+ if _ , _ , err := bufrw .ReadLine (); err != nil {
1189
+ t .Errorf ("Failed to read line from client: %v" , err )
1190
+ return
1191
+ }
1192
+
1193
+ for i := 0 ; i < n ; i ++ {
1194
+ if _ , err := bufrw .WriteString (nthResponse (i )); err != nil {
1195
+ t .Errorf ("Writing response #%d failed: %v" , i , err )
1196
+ }
1197
+ bufrw .Flush ()
1198
+ progressCh <- true
1199
+ }
1200
+ if _ , err := bufrw .WriteString (terminalMsg ); err != nil {
1201
+ t .Errorf ("Failed to write terminal message: %v" , err )
1202
+ }
1203
+ bufrw .Flush ()
1204
+ }))
1205
+ defer cst .Close ()
1206
+
1207
+ backendURL , _ := url .Parse (cst .URL )
1208
+ rproxy := NewSingleHostReverseProxy (backendURL )
1209
+ rproxy .ErrorLog = log .New (ioutil .Discard , "" , 0 ) // quiet for tests
1210
+ rproxy .ModifyResponse = func (res * http.Response ) error {
1211
+ res .Header .Add ("X-Modified" , "true" )
1212
+ return nil
1213
+ }
1214
+
1215
+ handler := http .HandlerFunc (func (rw http.ResponseWriter , req * http.Request ) {
1216
+ rw .Header ().Set ("X-Header" , "X-Value" )
1217
+ ctx , cancel := context .WithCancel (req .Context ())
1218
+ go func () {
1219
+ <- triggerCancelCh
1220
+ cancel ()
1221
+ }()
1222
+ rproxy .ServeHTTP (rw , req .WithContext (ctx ))
1223
+ })
1224
+
1225
+ frontendProxy := httptest .NewServer (handler )
1226
+ defer frontendProxy .Close ()
1227
+
1228
+ req , _ := http .NewRequest ("GET" , frontendProxy .URL , nil )
1229
+ req .Header .Set ("Connection" , "Upgrade" )
1230
+ req .Header .Set ("Upgrade" , "websocket" )
1231
+
1232
+ res , err := frontendProxy .Client ().Do (req )
1233
+ if err != nil {
1234
+ t .Fatalf ("Dialing to frontend proxy: %v" , err )
1235
+ }
1236
+ defer res .Body .Close ()
1237
+ if g , w := res .StatusCode , 101 ; g != w {
1238
+ t .Fatalf ("Switching protocols failed, got: %d, want: %d" , g , w )
1239
+ }
1240
+
1241
+ if g , w := res .Header .Get ("X-Header" ), "X-Value" ; g != w {
1242
+ t .Errorf ("X-Header mismatch\n \t got: %q\n \t want: %q" , g , w )
1243
+ }
1244
+
1245
+ if g , w := upgradeType (res .Header ), "websocket" ; g != w {
1246
+ t .Fatalf ("Upgrade header mismatch\n \t got: %q\n \t want: %q" , g , w )
1247
+ }
1248
+
1249
+ rwc , ok := res .Body .(io.ReadWriteCloser )
1250
+ if ! ok {
1251
+ t .Fatalf ("Response body type mismatch, got %T, want io.ReadWriteCloser" , res .Body )
1252
+ }
1253
+
1254
+ if got , want := res .Header .Get ("X-Modified" ), "true" ; got != want {
1255
+ t .Errorf ("response X-Modified header = %q; want %q" , got , want )
1256
+ }
1257
+
1258
+ if _ , err := io .WriteString (rwc , "Hello\n " ); err != nil {
1259
+ t .Fatalf ("Failed to write first message: %v" , err )
1260
+ }
1261
+
1262
+ // Read loop.
1263
+
1264
+ br := bufio .NewReader (rwc )
1265
+ for {
1266
+ line , err := br .ReadString ('\n' )
1267
+ switch {
1268
+ case line == terminalMsg : // this case before "err == io.EOF"
1269
+ t .Fatalf ("The websocket request was not canceled, unfortunately!" )
1270
+
1271
+ case err == io .EOF :
1272
+ return
1273
+
1274
+ case err != nil :
1275
+ t .Fatalf ("Unexpected error: %v" , err )
1276
+
1277
+ case line == nthResponse (0 ): // We've gotten the first response back
1278
+ // Let's trigger a cancel.
1279
+ close (triggerCancelCh )
1280
+ }
1281
+ <- progressCh
1282
+ }
1283
+ }
1284
+
1161
1285
func TestUnannouncedTrailer (t * testing.T ) {
1162
1286
backend := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
1163
1287
w .WriteHeader (http .StatusOK )
0 commit comments