|
1 | 1 | package gateway
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bufio" |
4 | 5 | "bytes"
|
| 6 | + "context" |
5 | 7 | "encoding/json"
|
6 | 8 | "fmt"
|
| 9 | + "io" |
7 | 10 | "io/ioutil"
|
8 | 11 | "net/http"
|
9 | 12 | "net/http/httptest"
|
@@ -781,3 +784,129 @@ func TestEnsureTransport(t *testing.T) {
|
781 | 784 | })
|
782 | 785 | }
|
783 | 786 | }
|
| 787 | + |
| 788 | +func TestReverseProxyWebSocketCancelation(t *testing.T) { |
| 789 | + c := config.Global() |
| 790 | + c.HttpServerOptions.EnableWebSockets = true |
| 791 | + config.SetGlobal(c) |
| 792 | + n := 5 |
| 793 | + triggerCancelCh := make(chan bool, n) |
| 794 | + nthResponse := func(i int) string { |
| 795 | + return fmt.Sprintf("backend response #%d\n", i) |
| 796 | + } |
| 797 | + terminalMsg := "final message" |
| 798 | + |
| 799 | + cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 800 | + if g, ws := upgradeType(r.Header), "websocket"; g != ws { |
| 801 | + t.Errorf("Unexpected upgrade type %q, want %q", g, ws) |
| 802 | + http.Error(w, "Unexpected request", 400) |
| 803 | + return |
| 804 | + } |
| 805 | + conn, bufrw, err := w.(http.Hijacker).Hijack() |
| 806 | + if err != nil { |
| 807 | + t.Error(err) |
| 808 | + return |
| 809 | + } |
| 810 | + defer conn.Close() |
| 811 | + |
| 812 | + upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n" |
| 813 | + if _, err := io.WriteString(conn, upgradeMsg); err != nil { |
| 814 | + t.Error(err) |
| 815 | + return |
| 816 | + } |
| 817 | + if _, _, err := bufrw.ReadLine(); err != nil { |
| 818 | + t.Errorf("Failed to read line from client: %v", err) |
| 819 | + return |
| 820 | + } |
| 821 | + |
| 822 | + for i := 0; i < n; i++ { |
| 823 | + if _, err := bufrw.WriteString(nthResponse(i)); err != nil { |
| 824 | + select { |
| 825 | + case <-triggerCancelCh: |
| 826 | + default: |
| 827 | + t.Errorf("Writing response #%d failed: %v", i, err) |
| 828 | + } |
| 829 | + return |
| 830 | + } |
| 831 | + bufrw.Flush() |
| 832 | + time.Sleep(time.Second) |
| 833 | + } |
| 834 | + if _, err := bufrw.WriteString(terminalMsg); err != nil { |
| 835 | + select { |
| 836 | + case <-triggerCancelCh: |
| 837 | + default: |
| 838 | + t.Errorf("Failed to write terminal message: %v", err) |
| 839 | + } |
| 840 | + } |
| 841 | + bufrw.Flush() |
| 842 | + })) |
| 843 | + defer cst.Close() |
| 844 | + |
| 845 | + backendURL, _ := url.Parse(cst.URL) |
| 846 | + spec := &APISpec{APIDefinition: &apidef.APIDefinition{}} |
| 847 | + rproxy := TykNewSingleHostReverseProxy(backendURL, spec, nil) |
| 848 | + |
| 849 | + handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { |
| 850 | + rw.Header().Set("X-Header", "X-Value") |
| 851 | + ctx, cancel := context.WithCancel(req.Context()) |
| 852 | + go func() { |
| 853 | + <-triggerCancelCh |
| 854 | + cancel() |
| 855 | + }() |
| 856 | + rproxy.ServeHTTP(rw, req.WithContext(ctx)) |
| 857 | + }) |
| 858 | + |
| 859 | + frontendProxy := httptest.NewServer(handler) |
| 860 | + defer frontendProxy.Close() |
| 861 | + |
| 862 | + req, _ := http.NewRequest("GET", frontendProxy.URL, nil) |
| 863 | + req.Header.Set("Connection", "Upgrade") |
| 864 | + req.Header.Set("Upgrade", "websocket") |
| 865 | + |
| 866 | + res, err := frontendProxy.Client().Do(req) |
| 867 | + if err != nil { |
| 868 | + t.Fatalf("Dialing to frontend proxy: %v", err) |
| 869 | + } |
| 870 | + defer res.Body.Close() |
| 871 | + if g, w := res.StatusCode, 101; g != w { |
| 872 | + t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w) |
| 873 | + } |
| 874 | + |
| 875 | + if g, w := res.Header.Get("X-Header"), "X-Value"; g != w { |
| 876 | + t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w) |
| 877 | + } |
| 878 | + |
| 879 | + if g, w := upgradeType(res.Header), "websocket"; g != w { |
| 880 | + t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w) |
| 881 | + } |
| 882 | + |
| 883 | + rwc, ok := res.Body.(io.ReadWriteCloser) |
| 884 | + if !ok { |
| 885 | + t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body) |
| 886 | + } |
| 887 | + |
| 888 | + if _, err := io.WriteString(rwc, "Hello\n"); err != nil { |
| 889 | + t.Fatalf("Failed to write first message: %v", err) |
| 890 | + } |
| 891 | + |
| 892 | + // Read loop. |
| 893 | + |
| 894 | + br := bufio.NewReader(rwc) |
| 895 | + for { |
| 896 | + line, err := br.ReadString('\n') |
| 897 | + switch { |
| 898 | + case line == terminalMsg: // this case before "err == io.EOF" |
| 899 | + t.Fatalf("The websocket request was not canceled, unfortunately!") |
| 900 | + |
| 901 | + case err == io.EOF: |
| 902 | + return |
| 903 | + |
| 904 | + case err != nil: |
| 905 | + t.Fatalf("Unexpected error: %v", err) |
| 906 | + |
| 907 | + case line == nthResponse(0): // We've gotten the first response back |
| 908 | + // Let's trigger a cancel. |
| 909 | + close(triggerCancelCh) |
| 910 | + } |
| 911 | + } |
| 912 | +} |
0 commit comments