Skip to content

Commit db18a31

Browse files
committed
close.go: Rewrite how the library handles closing
Far simpler now. Sorry this took a while. Closes #427 Closes #429 Closes #434 Closes #436 Closes #437
1 parent 0b3912f commit db18a31

File tree

6 files changed

+150
-136
lines changed

6 files changed

+150
-136
lines changed

close.go

+103-51
Original file line number | Diff line number | Diff line change
@@ -97,82 +97,106 @@ func CloseStatus(err error) StatusCode {
9797
//
9898
// Close will unblock all goroutines interacting with the connection once
9999
// complete.
100-
func (c *Conn) Close(code StatusCode, reason string) error {
101-
defer c.wg.Wait()
102-
return c.closeHandshake(code, reason)
100+
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101+
defer errd.Wrap(&err, "failed to close WebSocket")
102+
103+
if !c.casClosing() {
104+
err = c.waitGoroutines()
105+
if err != nil {
106+
return err
107+
}
108+
return net.ErrClosed
109+
}
110+
defer func() {
111+
if errors.Is(err, net.ErrClosed) {
112+
err = nil
113+
}
114+
}()
115+
116+
err = c.closeHandshake(code, reason)
117+
118+
err2 := c.close()
119+
if err == nil && err2 != nil {
120+
err = err2
121+
}
122+
123+
err2 = c.waitGoroutines()
124+
if err == nil && err2 != nil {
125+
err = err2
126+
}
127+
128+
return err
103129
}
104130

105131
// CloseNow closes the WebSocket connection without attempting a close handshake.
106132
// Use when you do not want the overhead of the close handshake.
107133
func (c *Conn) CloseNow() (err error) {
108-
defer c.wg.Wait()
109134
defer errd.Wrap(&err, "failed to close WebSocket")
110135

111-
if c.isClosed() {
136+
if !c.casClosing() {
137+
err = c.waitGoroutines()
138+
if err != nil {
139+
return err
140+
}
112141
return net.ErrClosed
113142
}
143+
defer func() {
144+
if errors.Is(err, net.ErrClosed) {
145+
err = nil
146+
}
147+
}()
114148

115-
c.close(nil)
116-
c.closeMu.Lock()
117-
defer c.closeMu.Unlock()
118-
return c.closeErr
119-
}
120-
121-
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
122-
defer errd.Wrap(&err, "failed to close WebSocket")
123-
124-
writeErr := c.writeClose(code, reason)
125-
closeHandshakeErr := c.waitCloseHandshake()
149+
err = c.close()
126150

127-
if writeErr != nil {
128-
return writeErr
151+
err2 := c.waitGoroutines()
152+
if err == nil && err2 != nil {
153+
err = err2
129154
}
155+
return err
156+
}
130157

131-
if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) {
132-
return closeHandshakeErr
158+
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
159+
err := c.writeClose(code, reason)
160+
if err != nil {
161+
return err
133162
}
134163

164+
err = c.waitCloseHandshake()
165+
if CloseStatus(err) != code {
166+
return err
167+
}
135168
return nil
136169
}
137170

138171
func (c *Conn) writeClose(code StatusCode, reason string) error {
139-
c.closeMu.Lock()
140-
wroteClose := c.wroteClose
141-
c.wroteClose = true
142-
c.closeMu.Unlock()
143-
if wroteClose {
144-
return net.ErrClosed
145-
}
146-
147172
ce := CloseError{
148173
Code: code,
149174
Reason: reason,
150175
}
151176

152177
var p []byte
153-
var marshalErr error
178+
var err error
154179
if ce.Code != StatusNoStatusRcvd {
155-
p, marshalErr = ce.bytes()
156-
}
157-
158-
writeErr := c.writeControl(context.Background(), opClose, p)
159-
if CloseStatus(writeErr) != -1 {
160-
// Not a real error if it's due to a close frame being received.
161-
writeErr = nil
180+
p, err = ce.bytes()
181+
if err != nil {
182+
return err
183+
}
162184
}
163185

164-
// We do this after in case there was an error writing the close frame.
165-
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
186+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
187+
defer cancel()
166188

167-
if marshalErr != nil {
168-
return marshalErr
189+
err = c.writeControl(ctx, opClose, p)
190+
// If the connection closed as we're writing we ignore the error as we might
191+
// have written the close frame, the peer responded and then someone else read it
192+
// and closed the connection.
193+
if err != nil && !errors.Is(err, net.ErrClosed) {
194+
return err
169195
}
170-
return writeErr
196+
return nil
171197
}
172198

173199
func (c *Conn) waitCloseHandshake() error {
174-
defer c.close(nil)
175-
176200
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
177201
defer cancel()
178202

@@ -208,6 +232,36 @@ func (c *Conn) waitCloseHandshake() error {
208232
}
209233
}
210234

235+
func (c *Conn) waitGoroutines() error {
236+
t := time.NewTimer(time.Second * 15)
237+
defer t.Stop()
238+
239+
select {
240+
case <-c.timeoutLoopDone:
241+
case <-t.C:
242+
return errors.New("failed to wait for timeoutLoop goroutine to exit")
243+
}
244+
245+
c.closeReadMu.Lock()
246+
ctx := c.closeReadCtx
247+
c.closeReadMu.Unlock()
248+
if ctx != nil {
249+
select {
250+
case <-ctx.Done():
251+
case <-t.C:
252+
return errors.New("failed to wait for close read goroutine to exit")
253+
}
254+
}
255+
256+
select {
257+
case <-c.closed:
258+
case <-t.C:
259+
return errors.New("failed to wait for connection to be closed")
260+
}
261+
262+
return nil
263+
}
264+
211265
func parseClosePayload(p []byte) (CloseError, error) {
212266
if len(p) == 0 {
213267
return CloseError{
@@ -278,16 +332,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
278332
return buf, nil
279333
}
280334

281-
func (c *Conn) setCloseErr(err error) {
335+
func (c *Conn) casClosing() bool {
282336
c.closeMu.Lock()
283-
c.setCloseErrLocked(err)
284-
c.closeMu.Unlock()
285-
}
286-
287-
func (c *Conn) setCloseErrLocked(err error) {
288-
if c.closeErr == nil && err != nil {
289-
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
337+
defer c.closeMu.Unlock()
338+
if !c.closing {
339+
c.closing = true
340+
return true
290341
}
342+
return false
291343
}
292344

293345
func (c *Conn) isClosed() bool {

conn.go

+29-45
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@ package websocket
66
import (
77
"bufio"
88
"context"
9-
"errors"
109
"fmt"
1110
"io"
1211
"net"
@@ -53,8 +52,9 @@ type Conn struct {
5352
br *bufio.Reader
5453
bw *bufio.Writer
5554

56-
readTimeout chan context.Context
57-
writeTimeout chan context.Context
55+
readTimeout chan context.Context
56+
writeTimeout chan context.Context
57+
timeoutLoopDone chan struct{}
5858

5959
// Read state.
6060
readMu *mu
@@ -70,11 +70,12 @@ type Conn struct {
7070
writeHeaderBuf [8]byte
7171
writeHeader header
7272

73-
wg sync.WaitGroup
74-
closed chan struct{}
75-
closeMu sync.Mutex
76-
closeErr error
77-
wroteClose bool
73+
closeReadMu sync.Mutex
74+
closeReadCtx context.Context
75+
76+
closed chan struct{}
77+
closeMu sync.Mutex
78+
closing bool
7879

7980
pingCounter int32
8081
activePingsMu sync.Mutex
@@ -103,8 +104,9 @@ func newConn(cfg connConfig) *Conn {
103104
br: cfg.br,
104105
bw: cfg.bw,
105106

106-
readTimeout: make(chan context.Context),
107-
writeTimeout: make(chan context.Context),
107+
readTimeout: make(chan context.Context),
108+
writeTimeout: make(chan context.Context),
109+
timeoutLoopDone: make(chan struct{}),
108110

109111
closed: make(chan struct{}),
110112
activePings: make(map[string]chan<- struct{}),
@@ -128,14 +130,10 @@ func newConn(cfg connConfig) *Conn {
128130
}
129131

130132
runtime.SetFinalizer(c, func(c *Conn) {
131-
c.close(errors.New("connection garbage collected"))
133+
c.close()
132134
})
133135

134-
c.wg.Add(1)
135-
go func() {
136-
defer c.wg.Done()
137-
c.timeoutLoop()
138-
}()
136+
go c.timeoutLoop()
139137

140138
return c
141139
}
@@ -146,35 +144,29 @@ func (c *Conn) Subprotocol() string {
146144
return c.subprotocol
147145
}
148146

149-
func (c *Conn) close(err error) {
147+
func (c *Conn) close() error {
150148
c.closeMu.Lock()
151149
defer c.closeMu.Unlock()
152150

153151
if c.isClosed() {
154-
return
155-
}
156-
if err == nil {
157-
err = c.rwc.Close()
152+
return net.ErrClosed
158153
}
159-
c.setCloseErrLocked(err)
160-
161-
close(c.closed)
162154
runtime.SetFinalizer(c, nil)
155+
close(c.closed)
163156

164157
// Have to close after c.closed is closed to ensure any goroutine that wakes up
165158
// from the connection being closed also sees that c.closed is closed and returns
166159
// net.ErrClosed (the closeErr field no longer exists).
167-
c.rwc.Close()
168-
169-
c.wg.Add(1)
170-
go func() {
171-
defer c.wg.Done()
172-
c.msgWriter.close()
173-
c.msgReader.close()
174-
}()
160+
err := c.rwc.Close()
161+
// With the close of rwc, these become safe to close.
162+
c.msgWriter.close()
163+
c.msgReader.close()
164+
return err
175165
}
176166

177167
func (c *Conn) timeoutLoop() {
168+
defer close(c.timeoutLoopDone)
169+
178170
readCtx := context.Background()
179171
writeCtx := context.Background()
180172

@@ -187,14 +179,10 @@ func (c *Conn) timeoutLoop() {
187179
case readCtx = <-c.readTimeout:
188180

189181
case <-readCtx.Done():
190-
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
191-
c.wg.Add(1)
192-
go func() {
193-
defer c.wg.Done()
194-
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
195-
}()
182+
c.close()
183+
return
196184
case <-writeCtx.Done():
197-
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
185+
c.close()
198186
return
199187
}
200188
}
@@ -243,9 +231,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
243231
case <-c.closed:
244232
return net.ErrClosed
245233
case <-ctx.Done():
246-
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
247-
c.close(err)
248-
return err
234+
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
249235
case <-pong:
250236
return nil
251237
}
@@ -281,9 +267,7 @@ func (m *mu) lock(ctx context.Context) error {
281267
case <-m.c.closed:
282268
return net.ErrClosed
283269
case <-ctx.Done():
284-
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
285-
m.c.close(err)
286-
return err
270+
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
287271
case m.ch <- struct{}{}:
288272
// To make sure the connection is certainly alive.
289273
// As it's possible the send on m.ch was selected

conn_test.go

+3
Original file line number | Diff line number | Diff line change
@@ -345,6 +345,9 @@ func TestConn(t *testing.T) {
345345

346346
func TestWasm(t *testing.T) {
347347
t.Parallel()
348+
if os.Getenv("CI") == "" {
349+
t.Skip()
350+
}
348351

349352
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
350353
err := echoServer(w, r, &websocket.AcceptOptions{

0 commit comments

Comments
 (0)