@@ -2,7 +2,6 @@ package websocket
2
2
3
3
import (
4
4
"context"
5
- "errors"
6
5
"fmt"
7
6
"net/http"
8
7
"sync"
@@ -17,98 +16,92 @@ import (
17
16
// Grace is intended to be used in harmony with net/http.Server's Shutdown and Close methods.
18
17
// It's required as net/http's Shutdown and Close methods do not keep track of WebSocket
19
18
// connections.
19
+ //
20
+ // Make sure to Close or Shutdown the *http.Server first as you don't want to accept
21
+ // any new connections while the existing websockets are being shut down.
20
22
type Grace struct {
21
- mu sync.Mutex
22
- closed bool
23
- shuttingDown bool
24
- conns map [* Conn ]struct {}
23
+ handlersMu sync.Mutex
24
+ closing bool
25
+ handlers map [context.Context ]context.CancelFunc
25
26
}
26
27
27
28
// Handler returns a handler that wraps around h to record
28
29
// all WebSocket connections accepted.
29
30
//
30
31
// Use Close or Shutdown to gracefully close recorded connections.
32
+ // Make sure to Close or Shutdown the *http.Server first.
31
33
func (g * Grace ) Handler (h http.Handler ) http.Handler {
32
34
return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
33
- ctx := context .WithValue (r .Context (), gracefulContextKey {}, g )
35
+ ctx , cancel := context .WithCancel (r .Context ())
36
+ defer cancel ()
37
+
34
38
r = r .WithContext (ctx )
39
+
40
+ ok := g .add (w , ctx , cancel )
41
+ if ! ok {
42
+ return
43
+ }
44
+ defer g .del (ctx )
45
+
35
46
h .ServeHTTP (w , r )
36
47
})
37
48
}
38
49
39
- func (g * Grace ) isShuttingdown () bool {
40
- g .mu .Lock ()
41
- defer g .mu .Unlock ()
42
- return g .shuttingDown
43
- }
44
-
45
- func graceFromRequest (r * http.Request ) * Grace {
46
- g , _ := r .Context ().Value (gracefulContextKey {}).(* Grace )
47
- return g
48
- }
50
+ func (g * Grace ) add (w http.ResponseWriter , ctx context.Context , cancel context.CancelFunc ) bool {
51
+ g .handlersMu .Lock ()
52
+ defer g .handlersMu .Unlock ()
49
53
50
- func (g * Grace ) addConn (c * Conn ) error {
51
- g .mu .Lock ()
52
- defer g .mu .Unlock ()
53
- if g .closed {
54
- c .Close (StatusGoingAway , "server shutting down" )
55
- return errors .New ("server shutting down" )
54
+ if g .closing {
55
+ http .Error (w , "shutting down" , http .StatusServiceUnavailable )
56
+ return false
56
57
}
57
- if g .conns == nil {
58
- g .conns = make (map [* Conn ]struct {})
58
+
59
+ if g .handlers == nil {
60
+ g .handlers = make (map [context.Context ]context.CancelFunc )
59
61
}
60
- g .conns [c ] = struct {}{}
61
- c .g = g
62
- return nil
63
- }
62
+ g .handlers [ctx ] = cancel
64
63
65
- func (g * Grace ) delConn (c * Conn ) {
66
- g .mu .Lock ()
67
- defer g .mu .Unlock ()
68
- delete (g .conns , c )
64
+ return true
69
65
}
70
66
71
- type gracefulContextKey struct {}
67
+ func (g * Grace ) del (ctx context.Context ) {
68
+ g .handlersMu .Lock ()
69
+ defer g .handlersMu .Unlock ()
70
+
71
+ delete (g .handlers , ctx )
72
+ }
72
73
73
74
// Close prevents the acceptance of new connections with
74
75
// http.StatusServiceUnavailable and closes all accepted
75
76
// connections with StatusGoingAway.
77
+ //
78
+ // Make sure to Close or Shutdown the *http.Server first.
76
79
func (g * Grace ) Close () error {
77
- g .mu .Lock ()
78
- g .shuttingDown = true
79
- g .closed = true
80
- var wg sync.WaitGroup
81
- for c := range g .conns {
82
- wg .Add (1 )
83
- go func (c * Conn ) {
84
- defer wg .Done ()
85
- c .Close (StatusGoingAway , "server shutting down" )
86
- }(c )
87
-
88
- delete (g .conns , c )
80
+ g .handlersMu .Lock ()
81
+ for _ , cancel := range g .handlers {
82
+ cancel ()
89
83
}
90
- g .mu .Unlock ()
84
+ g .handlersMu .Unlock ()
91
85
92
- wg .Wait ()
86
+ // Wait for all goroutines to exit.
87
+ g .Shutdown (context .Background ())
93
88
94
89
return nil
95
90
}
96
91
97
92
// Shutdown prevents the acceptance of new connections and waits until
98
93
// all connections close. If the context is cancelled before that, it
99
94
// calls Close to close all connections immediately.
95
+ //
96
+ // Make sure to Close or Shutdown the *http.Server first.
100
97
func (g * Grace ) Shutdown (ctx context.Context ) error {
101
98
defer g .Close ()
102
99
103
- g .mu .Lock ()
104
- g .shuttingDown = true
105
- g .mu .Unlock ()
106
-
107
100
// Same poll period used by net/http.
108
101
t := time .NewTicker (500 * time .Millisecond )
109
102
defer t .Stop ()
110
103
for {
111
- if g .zeroConns () {
104
+ if g .zeroHandlers () {
112
105
return nil
113
106
}
114
107
@@ -120,8 +113,8 @@ func (g *Grace) Shutdown(ctx context.Context) error {
120
113
}
121
114
}
122
115
123
- func (g * Grace ) zeroConns () bool {
124
- g .mu .Lock ()
125
- defer g .mu .Unlock ()
126
- return len (g .conns ) == 0
116
+ func (g * Grace ) zeroHandlers () bool {
117
+ g .handlersMu .Lock ()
118
+ defer g .handlersMu .Unlock ()
119
+ return len (g .handlers ) == 0
127
120
}
0 commit comments