Skip to content

Commit b0c36b9

Browse files
committed
Add Grace to gracefully close WebSocket connections
Closes #199
1 parent fa720b9 commit b0c36b9

File tree

7 files changed

+202
-12
lines changed

7 files changed

+202
-12
lines changed

accept.go

+18-2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
6565
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
6666
defer errd.Wrap(&err, "failed to accept WebSocket connection")
6767

68+
g := graceFromRequest(r)
69+
if g != nil && g.isClosing() {
70+
err := errors.New("server closing")
71+
http.Error(w, err.Error(), http.StatusServiceUnavailable)
72+
return nil, err
73+
}
74+
6875
if opts == nil {
6976
opts = &AcceptOptions{}
7077
}
@@ -120,7 +127,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
120127
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
121128
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
122129

123-
return newConn(connConfig{
130+
c := newConn(connConfig{
124131
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
125132
rwc: netConn,
126133
client: false,
@@ -129,7 +136,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
129136

130137
br: brw.Reader,
131138
bw: brw.Writer,
132-
}), nil
139+
})
140+
141+
if g != nil {
142+
err = g.addConn(c)
143+
if err != nil {
144+
return nil, err
145+
}
146+
}
147+
148+
return c, nil
133149
}
134150

135151
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {

conn_notjs.go

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type Conn struct {
3333
flateThreshold int
3434
br *bufio.Reader
3535
bw *bufio.Writer
36+
g *Grace
3637

3738
readTimeout chan context.Context
3839
writeTimeout chan context.Context
@@ -138,6 +139,10 @@ func (c *Conn) close(err error) {
138139
// closeErr.
139140
c.rwc.Close()
140141

142+
if c.g != nil {
143+
c.g.delConn(c)
144+
}
145+
141146
go func() {
142147
c.msgWriterState.close()
143148

conn_test.go

+4-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"os"
1414
"os/exec"
1515
"strings"
16-
"sync"
1716
"testing"
1817
"time"
1918

@@ -272,11 +271,9 @@ func TestWasm(t *testing.T) {
272271
t.Skip("skipping on CI")
273272
}
274273

275-
var wg sync.WaitGroup
276-
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
277-
wg.Add(1)
278-
defer wg.Done()
279-
274+
var g websocket.Grace
275+
defer g.Close()
276+
s := httptest.NewServer(g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
280277
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
281278
Subprotocols: []string{"echo"},
282279
InsecureSkipVerify: true,
@@ -294,8 +291,7 @@ func TestWasm(t *testing.T) {
294291
t.Errorf("echo server failed: %v", err)
295292
return
296293
}
297-
}))
298-
defer wg.Wait()
294+
})))
299295
defer s.Close()
300296

301297
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)

example_echo_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@ func Example_echo() {
3131
}
3232
defer l.Close()
3333

34+
var g websocket.Grace
35+
defer g.Close()
3436
s := &http.Server{
35-
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
37+
Handler: g.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3638
err := echoServer(w, r)
3739
if err != nil {
3840
log.Printf("echo server: %v", err)
3941
}
40-
}),
42+
})),
4143
ReadTimeout: time.Second * 15,
4244
WriteTimeout: time.Second * 15,
4345
}

example_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"log"
88
"net/http"
99
"net/url"
10+
"os"
11+
"os/signal"
1012
"time"
1113

1214
"nhooyr.io/websocket"
@@ -143,3 +145,47 @@ func Example_crossOrigin() {
143145
err := http.ListenAndServe("localhost:8080", fn)
144146
log.Fatal(err)
145147
}
148+
149+
// This example demonstrates how to create a WebSocket server
150+
// that gracefully exits when sent a signal.
151+
//
152+
// It starts a WebSocket server that keeps every connection open
153+
// for 10 seconds.
154+
// If you CTRL+C while a connection is open, it will wait at most 30s
155+
// for all connections to terminate before shutting down.
156+
func ExampleGrace() {
157+
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
158+
c, err := websocket.Accept(w, r, nil)
159+
if err != nil {
160+
log.Println(err)
161+
return
162+
}
163+
defer c.Close(websocket.StatusInternalError, "the sky is falling")
164+
165+
ctx := c.CloseRead(r.Context())
166+
select {
167+
case <-ctx.Done():
168+
case <-time.After(time.Second * 10):
169+
}
170+
171+
c.Close(websocket.StatusNormalClosure, "")
172+
})
173+
174+
var g websocket.Grace
175+
s := &http.Server{
176+
Handler: g.Handler(fn),
177+
ReadTimeout: time.Second * 15,
178+
WriteTimeout: time.Second * 15,
179+
}
180+
go s.ListenAndServe()
181+
182+
sigs := make(chan os.Signal, 1)
183+
signal.Notify(sigs, os.Interrupt)
184+
sig := <-sigs
185+
log.Printf("recieved %v, shutting down", sig)
186+
187+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
188+
defer cancel()
189+
s.Shutdown(ctx)
190+
g.Shutdown(ctx)
191+
}

grace.go

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
package websocket
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"sync"
9+
"time"
10+
)
11+
12+
// Grace enables graceful shutdown of accepted WebSocket connections.
13+
//
14+
// Use Handler to wrap WebSocket handlers to record accepted connections
15+
// and then use Close or Shutdown to gracefully close these connections.
16+
//
17+
// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
18+
type Grace struct {
19+
mu sync.Mutex
20+
closing bool
21+
conns map[*Conn]struct{}
22+
}
23+
24+
// Handler returns a handler that wraps around h to record
25+
// all WebSocket connections accepted.
26+
//
27+
// Use Close or Shutdown to gracefully close recorded connections.
28+
func (g *Grace) Handler(h http.Handler) http.Handler {
29+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30+
ctx := context.WithValue(r.Context(), gracefulContextKey{}, g)
31+
r = r.WithContext(ctx)
32+
h.ServeHTTP(w, r)
33+
})
34+
}
35+
36+
func (g *Grace) isClosing() bool {
37+
g.mu.Lock()
38+
defer g.mu.Unlock()
39+
return g.closing
40+
}
41+
42+
func graceFromRequest(r *http.Request) *Grace {
43+
g, _ := r.Context().Value(gracefulContextKey{}).(*Grace)
44+
return g
45+
}
46+
47+
func (g *Grace) addConn(c *Conn) error {
48+
g.mu.Lock()
49+
defer g.mu.Unlock()
50+
if g.closing {
51+
c.Close(StatusGoingAway, "server shutting down")
52+
return errors.New("server shutting down")
53+
}
54+
if g.conns == nil {
55+
g.conns = make(map[*Conn]struct{})
56+
}
57+
g.conns[c] = struct{}{}
58+
c.g = g
59+
return nil
60+
}
61+
62+
func (g *Grace) delConn(c *Conn) {
63+
g.mu.Lock()
64+
defer g.mu.Unlock()
65+
delete(g.conns, c)
66+
}
67+
68+
type gracefulContextKey struct{}
69+
70+
// Close prevents the acceptance of new connections with
71+
// http.StatusServiceUnavailable and closes all accepted
72+
// connections with StatusGoingAway.
73+
func (g *Grace) Close() error {
74+
g.mu.Lock()
75+
g.closing = true
76+
var wg sync.WaitGroup
77+
for c := range g.conns {
78+
wg.Add(1)
79+
go func(c *Conn) {
80+
defer wg.Done()
81+
c.Close(StatusGoingAway, "server shutting down")
82+
}(c)
83+
84+
delete(g.conns, c)
85+
}
86+
g.mu.Unlock()
87+
88+
wg.Wait()
89+
90+
return nil
91+
}
92+
93+
// Shutdown prevents the acceptance of new connections and waits until
94+
// all connections close. If the context is cancelled before that, it
95+
// calls Close to close all connections immediately.
96+
func (g *Grace) Shutdown(ctx context.Context) error {
97+
defer g.Close()
98+
99+
g.mu.Lock()
100+
g.closing = true
101+
g.mu.Unlock()
102+
103+
// Same poll period used by net/http.
104+
t := time.NewTicker(500 * time.Millisecond)
105+
defer t.Stop()
106+
for {
107+
if g.zeroConns() {
108+
return nil
109+
}
110+
111+
select {
112+
case <-t.C:
113+
case <-ctx.Done():
114+
return fmt.Errorf("failed to shutdown WebSockets: %w", ctx.Err())
115+
}
116+
}
117+
}
118+
119+
func (g *Grace) zeroConns() bool {
120+
g.mu.Lock()
121+
defer g.mu.Unlock()
122+
return len(g.conns) == 0
123+
}

ws_js.go

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ type Conn struct {
3838
readSignal chan struct{}
3939
readBufMu sync.Mutex
4040
readBuf []wsjs.MessageEvent
41+
42+
g *Grace
4143
}
4244

4345
func (c *Conn) close(err error, wasClean bool) {

0 commit comments

Comments
 (0)