|
1 | 1 | package gateway
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bufio" |
4 | 5 | "bytes"
|
| 6 | + "context" |
5 | 7 | "encoding/json"
|
6 | 8 | "fmt"
|
| 9 | + "io" |
7 | 10 | "io/ioutil"
|
8 | 11 | "net/http"
|
9 | 12 | "net/http/httptest"
|
@@ -1048,3 +1051,129 @@ func TestEnsureTransport(t *testing.T) {
|
1048 | 1051 | })
|
1049 | 1052 | }
|
1050 | 1053 | }
|
| 1054 | + |
| 1055 | +func TestReverseProxyWebSocketCancelation(t *testing.T) { |
| 1056 | + c := config.Global() |
| 1057 | + c.HttpServerOptions.EnableWebSockets = true |
| 1058 | + config.SetGlobal(c) |
| 1059 | + n := 5 |
| 1060 | + triggerCancelCh := make(chan bool, n) |
| 1061 | + nthResponse := func(i int) string { |
| 1062 | + return fmt.Sprintf("backend response #%d\n", i) |
| 1063 | + } |
| 1064 | + terminalMsg := "final message" |
| 1065 | + |
| 1066 | + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 1067 | + if g, ws := upgradeType(r.Header), "websocket"; g != ws { |
| 1068 | + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) |
| 1069 | + http.Error(w, "Unexpected request", 400) |
| 1070 | + return |
| 1071 | + } |
| 1072 | + conn, bufrw, err := w.(http.Hijacker).Hijack() |
| 1073 | + if err != nil { |
| 1074 | + t.Error(err) |
| 1075 | + return |
| 1076 | + } |
| 1077 | + defer conn.Close() |
| 1078 | + |
| 1079 | + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" |
| 1080 | + if _, err := io.WriteString(conn, upgradeMsg); err != nil { |
| 1081 | + t.Error(err) |
| 1082 | + return |
| 1083 | + } |
| 1084 | + if _, _, err := bufrw.ReadLine(); err != nil { |
| 1085 | + t.Errorf("Failed to read line from client: %v", err) |
| 1086 | + return |
| 1087 | + } |
| 1088 | + |
| 1089 | + for i := 0; i < n; i++ { |
| 1090 | + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { |
| 1091 | + select { |
| 1092 | + case <-triggerCancelCh: |
| 1093 | + default: |
| 1094 | + t.Errorf("Writing response #%d failed: %v", i, err) |
| 1095 | + } |
| 1096 | + return |
| 1097 | + } |
| 1098 | + bufrw.Flush() |
| 1099 | + time.Sleep(time.Second) |
| 1100 | + } |
| 1101 | + if _, err := bufrw.WriteString(terminalMsg); err != nil { |
| 1102 | + select { |
| 1103 | + case <-triggerCancelCh: |
| 1104 | + default: |
| 1105 | + t.Errorf("Failed to write terminal message: %v", err) |
| 1106 | + } |
| 1107 | + } |
| 1108 | + bufrw.Flush() |
| 1109 | + })) |
| 1110 | + defer cst.Close() |
| 1111 | + |
| 1112 | + backendURL, _ := url.Parse(cst.URL) |
| 1113 | + spec := &APISpec{APIDefinition: &apidef.APIDefinition{}} |
| 1114 | + rproxy := TykNewSingleHostReverseProxy(backendURL, spec, nil) |
| 1115 | + |
| 1116 | + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { |
| 1117 | + rw.Header().Set("X-Header", "X-Value") |
| 1118 | + ctx, cancel := context.WithCancel(req.Context()) |
| 1119 | + go func() { |
| 1120 | + <-triggerCancelCh |
| 1121 | + cancel() |
| 1122 | + }() |
| 1123 | + rproxy.ServeHTTP(rw, req.WithContext(ctx)) |
| 1124 | + }) |
| 1125 | + |
| 1126 | + frontendProxy := httptest.NewServer(handler) |
| 1127 | + defer frontendProxy.Close() |
| 1128 | + |
| 1129 | + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) |
| 1130 | + req.Header.Set("Connection", "Upgrade") |
| 1131 | + req.Header.Set("Upgrade", "websocket") |
| 1132 | + |
| 1133 | + res, err := frontendProxy.Client().Do(req) |
| 1134 | + if err != nil { |
| 1135 | + t.Fatalf("Dialing to frontend proxy: %v", err) |
| 1136 | + } |
| 1137 | + defer res.Body.Close() |
| 1138 | + if g, w := res.StatusCode, 101; g != w { |
| 1139 | + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) |
| 1140 | + } |
| 1141 | + |
| 1142 | + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { |
| 1143 | + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) |
| 1144 | + } |
| 1145 | + |
| 1146 | + if g, w := upgradeType(res.Header), "websocket"; g != w { |
| 1147 | + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) |
| 1148 | + } |
| 1149 | + |
| 1150 | + rwc, ok := res.Body.(io.ReadWriteCloser) |
| 1151 | + if !ok { |
| 1152 | + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) |
| 1153 | + } |
| 1154 | + |
| 1155 | + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { |
| 1156 | + t.Fatalf("Failed to write first message: %v", err) |
| 1157 | + } |
| 1158 | + |
| 1159 | + // Read loop. |
| 1160 | + |
| 1161 | + br := bufio.NewReader(rwc) |
| 1162 | + for { |
| 1163 | + line, err := br.ReadString('\n') |
| 1164 | + switch { |
| 1165 | + case line == terminalMsg: // this case before "err == io.EOF" |
| 1166 | + t.Fatalf("The websocket request was not canceled, unfortunately!") |
| 1167 | + |
| 1168 | + case err == io.EOF: |
| 1169 | + return |
| 1170 | + |
| 1171 | + case err != nil: |
| 1172 | + t.Fatalf("Unexpected error: %v", err) |
| 1173 | + |
| 1174 | + case line == nthResponse(0): // We've gotten the first response back |
| 1175 | + // Let's trigger a cancel. |
| 1176 | + close(triggerCancelCh) |
| 1177 | + } |
| 1178 | + } |
| 1179 | +} |
0 commit comments