Skip to content

Commit c2e8dc4

Browse files
committed
Close target connection when request is canceled
ackport changes added to the net/http/httputil ReverseProxy for go1.15 . See release notes here https://golang.org/doc/go1.15#net/http/httputil
1 parent 3434dc8 commit c2e8dc4

File tree

2 files changed

+149
-1
lines changed

2 files changed

+149
-1
lines changed

gateway/reverse_proxy.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import (
4242
"github.com/opentracing/opentracing-go/ext"
4343
cache "github.com/pmylund/go-cache"
4444
"github.com/sirupsen/logrus"
45+
"golang.org/x/net/http/httpguts"
4546
"golang.org/x/net/http2"
4647
)
4748

@@ -1108,6 +1109,13 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader) (int64, error) {
11081109
}
11091110
}
11101111

1112+
func upgradeType(h http.Header) string {
1113+
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
1114+
return ""
1115+
}
1116+
return strings.ToLower(h.Get("Upgrade"))
1117+
}
1118+
11111119
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) error {
11121120
copyHeader(res.Header, rw.Header())
11131121

@@ -1119,7 +1127,18 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R
11191127
if !ok {
11201128
return fmt.Errorf("internal error: 101 switching protocols response with non-writable body")
11211129
}
1122-
defer backConn.Close()
1130+
backConnCloseCh := make(chan bool)
1131+
go func() {
1132+
// Ensure that the cancelation of a request closes the backend.
1133+
// See issue https://golang.org/issue/35559.
1134+
select {
1135+
case <-req.Context().Done():
1136+
case <-backConnCloseCh:
1137+
}
1138+
backConn.Close()
1139+
}()
1140+
1141+
defer close(backConnCloseCh)
11231142
conn, brw, err := hj.Hijack()
11241143
if err != nil {
11251144
return fmt.Errorf("Hijack failed on protocol switch: %v", err)

gateway/reverse_proxy_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package gateway
22

33
import (
4+
"bufio"
45
"bytes"
6+
"context"
57
"encoding/json"
68
"fmt"
9+
"io"
710
"io/ioutil"
811
"net/http"
912
"net/http/httptest"
@@ -781,3 +784,129 @@ func TestEnsureTransport(t *testing.T) {
781784
})
782785
}
783786
}
787+
788+
func TestReverseProxyWebSocketCancelation(t *testing.T) {
789+
c := config.Global()
790+
c.HttpServerOptions.EnableWebSockets = true
791+
config.SetGlobal(c)
792+
n := 5
793+
triggerCancelCh := make(chan bool, n)
794+
nthResponse := func(i int) string {
795+
return fmt.Sprintf("backend response #%d\n", i)
796+
}
797+
terminalMsg := "final message"
798+
799+
cst := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
800+
if g, ws := upgradeType(r.Header), "websocket"; g != ws {
801+
t.Errorf("Unexpected upgrade type %q, want %q", g, ws)
802+
http.Error(w, "Unexpected request", 400)
803+
return
804+
}
805+
conn, bufrw, err := w.(http.Hijacker).Hijack()
806+
if err != nil {
807+
t.Error(err)
808+
return
809+
}
810+
defer conn.Close()
811+
812+
upgradeMsg := "HTTP/1.1 101 Switching Protocols\r\nConnection: upgrade\r\nUpgrade: WebSocket\r\n\r\n"
813+
if _, err := io.WriteString(conn, upgradeMsg); err != nil {
814+
t.Error(err)
815+
return
816+
}
817+
if _, _, err := bufrw.ReadLine(); err != nil {
818+
t.Errorf("Failed to read line from client: %v", err)
819+
return
820+
}
821+
822+
for i := 0; i < n; i++ {
823+
if _, err := bufrw.WriteString(nthResponse(i)); err != nil {
824+
select {
825+
case <-triggerCancelCh:
826+
default:
827+
t.Errorf("Writing response #%d failed: %v", i, err)
828+
}
829+
return
830+
}
831+
bufrw.Flush()
832+
time.Sleep(time.Second)
833+
}
834+
if _, err := bufrw.WriteString(terminalMsg); err != nil {
835+
select {
836+
case <-triggerCancelCh:
837+
default:
838+
t.Errorf("Failed to write terminal message: %v", err)
839+
}
840+
}
841+
bufrw.Flush()
842+
}))
843+
defer cst.Close()
844+
845+
backendURL, _ := url.Parse(cst.URL)
846+
spec := &APISpec{APIDefinition: &apidef.APIDefinition{}}
847+
rproxy := TykNewSingleHostReverseProxy(backendURL, spec, nil)
848+
849+
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
850+
rw.Header().Set("X-Header", "X-Value")
851+
ctx, cancel := context.WithCancel(req.Context())
852+
go func() {
853+
<-triggerCancelCh
854+
cancel()
855+
}()
856+
rproxy.ServeHTTP(rw, req.WithContext(ctx))
857+
})
858+
859+
frontendProxy := httptest.NewServer(handler)
860+
defer frontendProxy.Close()
861+
862+
req, _ := http.NewRequest("GET", frontendProxy.URL, nil)
863+
req.Header.Set("Connection", "Upgrade")
864+
req.Header.Set("Upgrade", "websocket")
865+
866+
res, err := frontendProxy.Client().Do(req)
867+
if err != nil {
868+
t.Fatalf("Dialing to frontend proxy: %v", err)
869+
}
870+
defer res.Body.Close()
871+
if g, w := res.StatusCode, 101; g != w {
872+
t.Fatalf("Switching protocols failed, got: %d, want: %d", g, w)
873+
}
874+
875+
if g, w := res.Header.Get("X-Header"), "X-Value"; g != w {
876+
t.Errorf("X-Header mismatch\n\tgot: %q\n\twant: %q", g, w)
877+
}
878+
879+
if g, w := upgradeType(res.Header), "websocket"; g != w {
880+
t.Fatalf("Upgrade header mismatch\n\tgot: %q\n\twant: %q", g, w)
881+
}
882+
883+
rwc, ok := res.Body.(io.ReadWriteCloser)
884+
if !ok {
885+
t.Fatalf("Response body type mismatch, got %T, want io.ReadWriteCloser", res.Body)
886+
}
887+
888+
if _, err := io.WriteString(rwc, "Hello\n"); err != nil {
889+
t.Fatalf("Failed to write first message: %v", err)
890+
}
891+
892+
// Read loop.
893+
894+
br := bufio.NewReader(rwc)
895+
for {
896+
line, err := br.ReadString('\n')
897+
switch {
898+
case line == terminalMsg: // this case before "err == io.EOF"
899+
t.Fatalf("The websocket request was not canceled, unfortunately!")
900+
901+
case err == io.EOF:
902+
return
903+
904+
case err != nil:
905+
t.Fatalf("Unexpected error: %v", err)
906+
907+
case line == nthResponse(0): // We've gotten the first response back
908+
// Let's trigger a cancel.
909+
close(triggerCancelCh)
910+
}
911+
}
912+
}

0 commit comments

Comments
 (0)