@@ -97,82 +97,106 @@ func CloseStatus(err error) StatusCode {
97
97
//
98
98
// Close will unblock all goroutines interacting with the connection once
99
99
// complete.
100
- func (c * Conn ) Close (code StatusCode , reason string ) error {
101
- defer c .wg .Wait ()
102
- return c .closeHandshake (code , reason )
100
+ func (c * Conn ) Close (code StatusCode , reason string ) (err error ) {
101
+ defer errd .Wrap (& err , "failed to close WebSocket" )
102
+
103
+ if ! c .casClosing () {
104
+ err = c .waitGoroutines ()
105
+ if err != nil {
106
+ return err
107
+ }
108
+ return net .ErrClosed
109
+ }
110
+ defer func () {
111
+ if errors .Is (err , net .ErrClosed ) {
112
+ err = nil
113
+ }
114
+ }()
115
+
116
+ err = c .closeHandshake (code , reason )
117
+
118
+ err2 := c .close ()
119
+ if err == nil && err2 != nil {
120
+ err = err2
121
+ }
122
+
123
+ err2 = c .waitGoroutines ()
124
+ if err == nil && err2 != nil {
125
+ err = err2
126
+ }
127
+
128
+ return err
103
129
}
104
130
105
131
// CloseNow closes the WebSocket connection without attempting a close handshake.
106
132
// Use when you do not want the overhead of the close handshake.
107
133
func (c * Conn ) CloseNow () (err error ) {
108
- defer c .wg .Wait ()
109
134
defer errd .Wrap (& err , "failed to close WebSocket" )
110
135
111
- if c .isClosed () {
136
+ if ! c .casClosing () {
137
+ err = c .waitGoroutines ()
138
+ if err != nil {
139
+ return err
140
+ }
112
141
return net .ErrClosed
113
142
}
143
+ defer func () {
144
+ if errors .Is (err , net .ErrClosed ) {
145
+ err = nil
146
+ }
147
+ }()
114
148
115
- c .close (nil )
116
- c .closeMu .Lock ()
117
- defer c .closeMu .Unlock ()
118
- return c .closeErr
119
- }
120
-
121
- func (c * Conn ) closeHandshake (code StatusCode , reason string ) (err error ) {
122
- defer errd .Wrap (& err , "failed to close WebSocket" )
123
-
124
- writeErr := c .writeClose (code , reason )
125
- closeHandshakeErr := c .waitCloseHandshake ()
149
+ err = c .close ()
126
150
127
- if writeErr != nil {
128
- return writeErr
151
+ err2 := c .waitGoroutines ()
152
+ if err == nil && err2 != nil {
153
+ err = err2
129
154
}
155
+ return err
156
+ }
130
157
131
- if CloseStatus (closeHandshakeErr ) == - 1 && ! errors .Is (net .ErrClosed , closeHandshakeErr ) {
132
- return closeHandshakeErr
158
+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
159
+ err := c .writeClose (code , reason )
160
+ if err != nil {
161
+ return err
133
162
}
134
163
164
+ err = c .waitCloseHandshake ()
165
+ if CloseStatus (err ) != code {
166
+ return err
167
+ }
135
168
return nil
136
169
}
137
170
138
171
func (c * Conn ) writeClose (code StatusCode , reason string ) error {
139
- c .closeMu .Lock ()
140
- wroteClose := c .wroteClose
141
- c .wroteClose = true
142
- c .closeMu .Unlock ()
143
- if wroteClose {
144
- return net .ErrClosed
145
- }
146
-
147
172
ce := CloseError {
148
173
Code : code ,
149
174
Reason : reason ,
150
175
}
151
176
152
177
var p []byte
153
- var marshalErr error
178
+ var err error
154
179
if ce .Code != StatusNoStatusRcvd {
155
- p , marshalErr = ce .bytes ()
156
- }
157
-
158
- writeErr := c .writeControl (context .Background (), opClose , p )
159
- if CloseStatus (writeErr ) != - 1 {
160
- // Not a real error if it's due to a close frame being received.
161
- writeErr = nil
180
+ p , err = ce .bytes ()
181
+ if err != nil {
182
+ return err
183
+ }
162
184
}
163
185
164
- // We do this after in case there was an error writing the close frame.
165
- c . setCloseErr ( fmt . Errorf ( "sent close frame: %w" , ce ) )
186
+ ctx , cancel := context . WithTimeout ( context . Background (), time . Second * 5 )
187
+ defer cancel ( )
166
188
167
- if marshalErr != nil {
168
- return marshalErr
189
+ err = c .writeControl (ctx , opClose , p )
190
+ // If the connection closed as we're writing we ignore the error as we might
191
+ // have written the close frame, the peer responded and then someone else read it
192
+ // and closed the connection.
193
+ if err != nil && ! errors .Is (err , net .ErrClosed ) {
194
+ return err
169
195
}
170
- return writeErr
196
+ return nil
171
197
}
172
198
173
199
func (c * Conn ) waitCloseHandshake () error {
174
- defer c .close (nil )
175
-
176
200
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
177
201
defer cancel ()
178
202
@@ -208,6 +232,36 @@ func (c *Conn) waitCloseHandshake() error {
208
232
}
209
233
}
210
234
235
+ func (c * Conn ) waitGoroutines () error {
236
+ t := time .NewTimer (time .Second * 15 )
237
+ defer t .Stop ()
238
+
239
+ select {
240
+ case <- c .timeoutLoopDone :
241
+ case <- t .C :
242
+ return errors .New ("failed to wait for timeoutLoop goroutine to exit" )
243
+ }
244
+
245
+ c .closeReadMu .Lock ()
246
+ ctx := c .closeReadCtx
247
+ c .closeReadMu .Unlock ()
248
+ if ctx != nil {
249
+ select {
250
+ case <- ctx .Done ():
251
+ case <- t .C :
252
+ return errors .New ("failed to wait for close read goroutine to exit" )
253
+ }
254
+ }
255
+
256
+ select {
257
+ case <- c .closed :
258
+ case <- t .C :
259
+ return errors .New ("failed to wait for connection to be closed" )
260
+ }
261
+
262
+ return nil
263
+ }
264
+
211
265
func parseClosePayload (p []byte ) (CloseError , error ) {
212
266
if len (p ) == 0 {
213
267
return CloseError {
@@ -278,16 +332,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
278
332
return buf , nil
279
333
}
280
334
281
- func (c * Conn ) setCloseErr ( err error ) {
335
+ func (c * Conn ) casClosing () bool {
282
336
c .closeMu .Lock ()
283
- c .setCloseErrLocked (err )
284
- c .closeMu .Unlock ()
285
- }
286
-
287
- func (c * Conn ) setCloseErrLocked (err error ) {
288
- if c .closeErr == nil && err != nil {
289
- c .closeErr = fmt .Errorf ("WebSocket closed: %w" , err )
337
+ defer c .closeMu .Unlock ()
338
+ if ! c .closing {
339
+ c .closing = true
340
+ return true
290
341
}
342
+ return false
291
343
}
292
344
293
345
func (c * Conn ) isClosed () bool {
0 commit comments