Skip to content

Commit f178ccf

Browse files
authored
Merge pull request #169 from nhooyr/race
Fix race with c.readerShouldLock
2 parents e36318f + 8b47056 commit f178ccf

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

conn.go

+40-17
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,10 @@ type Conn struct {
7878
readLock chan struct{}
7979

8080
// messageReader state.
81-
readerMsgCtx context.Context
82-
readerMsgHeader header
83-
readerFrameEOF bool
84-
readerMaskPos int
85-
readerShouldLock bool
81+
readerMsgCtx context.Context
82+
readerMsgHeader header
83+
readerFrameEOF bool
84+
readerMaskPos int
8685

8786
setReadTimeout chan context.Context
8887
setWriteTimeout chan context.Context
@@ -237,6 +236,10 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
237236
if h.opcode.controlOp() {
238237
err = c.handleControl(ctx, h)
239238
if err != nil {
239+
// Pass through CloseErrors when receiving a close frame.
240+
if h.opcode == opClose && CloseStatus(err) != -1 {
241+
return header{}, err
242+
}
240243
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
241244
}
242245
continue
@@ -445,7 +448,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
445448
c.readerFrameEOF = false
446449
c.readerMaskPos = 0
447450
c.readMsgLeft = c.msgReadLimit.Load()
448-
c.readerShouldLock = lock
449451

450452
r := &messageReader{
451453
c: c,
@@ -465,7 +467,11 @@ func (r *messageReader) eof() bool {
465467

466468
// Read reads as many bytes as possible into p.
467469
func (r *messageReader) Read(p []byte) (int, error) {
468-
n, err := r.read(p)
470+
return r.exportedRead(p, true)
471+
}
472+
473+
func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
474+
n, err := r.read(p, lock)
469475
if err != nil {
470476
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
471477
// isn't used widely yet.
@@ -477,17 +483,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
477483
return n, nil
478484
}
479485

480-
func (r *messageReader) read(p []byte) (int, error) {
481-
if r.c.readerShouldLock {
482-
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
483-
if err != nil {
484-
return 0, err
486+
func (r *messageReader) readUnlocked(p []byte) (int, error) {
487+
return r.exportedRead(p, false)
488+
}
489+
490+
func (r *messageReader) read(p []byte, lock bool) (int, error) {
491+
if lock {
492+
// If we cannot acquire the read lock, then
493+
// there is either a concurrent read or the close handshake
494+
// is proceeding.
495+
select {
496+
case r.c.readLock <- struct{}{}:
497+
defer r.c.releaseLock(r.c.readLock)
498+
default:
499+
if r.c.closing.Load() == 1 {
500+
<-r.c.closed
501+
return 0, r.c.closeErr
502+
}
503+
return 0, errors.New("concurrent read detected")
485504
}
486-
defer r.c.releaseLock(r.c.readLock)
487505
}
488506

489507
if r.eof() {
490-
return 0, fmt.Errorf("cannot use EOFed reader")
508+
return 0, errors.New("cannot use EOFed reader")
491509
}
492510

493511
if r.c.readMsgLeft <= 0 {
@@ -950,8 +968,6 @@ func (c *Conn) waitClose() error {
950968
return c.closeReceived
951969
}
952970

953-
c.readerShouldLock = false
954-
955971
b := bpool.Get()
956972
buf := b.Bytes()
957973
buf = buf[:cap(buf)]
@@ -965,7 +981,8 @@ func (c *Conn) waitClose() error {
965981
}
966982
}
967983

968-
_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
984+
r := readerFunc(c.activeReader.readUnlocked)
985+
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
969986
if err != nil {
970987
return err
971988
}
@@ -1019,6 +1036,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
10191036
}
10201037
}
10211038

1039+
type readerFunc func(p []byte) (int, error)
1040+
1041+
func (f readerFunc) Read(p []byte) (int, error) {
1042+
return f(p)
1043+
}
1044+
10221045
type writerFunc func(p []byte) (int, error)
10231046

10241047
func (f writerFunc) Write(p []byte) (int, error) {

0 commit comments

Comments
 (0)