@@ -78,11 +78,10 @@ type Conn struct {
78
78
readLock chan struct {}
79
79
80
80
// messageReader state.
81
- readerMsgCtx context.Context
82
- readerMsgHeader header
83
- readerFrameEOF bool
84
- readerMaskPos int
85
- readerShouldLock bool
81
+ readerMsgCtx context.Context
82
+ readerMsgHeader header
83
+ readerFrameEOF bool
84
+ readerMaskPos int
86
85
87
86
setReadTimeout chan context.Context
88
87
setWriteTimeout chan context.Context
@@ -237,6 +236,10 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
237
236
if h .opcode .controlOp () {
238
237
err = c .handleControl (ctx , h )
239
238
if err != nil {
239
+ // Pass through CloseErrors when receiving a close frame.
240
+ if h .opcode == opClose && CloseStatus (err ) != - 1 {
241
+ return header {}, err
242
+ }
240
243
return header {}, fmt .Errorf ("failed to handle control frame %v: %w" , h .opcode , err )
241
244
}
242
245
continue
@@ -445,7 +448,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
445
448
c .readerFrameEOF = false
446
449
c .readerMaskPos = 0
447
450
c .readMsgLeft = c .msgReadLimit .Load ()
448
- c .readerShouldLock = lock
449
451
450
452
r := & messageReader {
451
453
c : c ,
@@ -465,7 +467,11 @@ func (r *messageReader) eof() bool {
465
467
466
468
// Read reads as many bytes as possible into p.
467
469
func (r * messageReader ) Read (p []byte ) (int , error ) {
468
- n , err := r .read (p )
470
+ return r .exportedRead (p , true )
471
+ }
472
+
473
+ func (r * messageReader ) exportedRead (p []byte , lock bool ) (int , error ) {
474
+ n , err := r .read (p , lock )
469
475
if err != nil {
470
476
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
471
477
// isn't used widely yet.
@@ -477,17 +483,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
477
483
return n , nil
478
484
}
479
485
480
- func (r * messageReader ) read (p []byte ) (int , error ) {
481
- if r .c .readerShouldLock {
482
- err := r .c .acquireLock (r .c .readerMsgCtx , r .c .readLock )
483
- if err != nil {
484
- return 0 , err
486
+ func (r * messageReader ) readUnlocked (p []byte ) (int , error ) {
487
+ return r .exportedRead (p , false )
488
+ }
489
+
490
+ func (r * messageReader ) read (p []byte , lock bool ) (int , error ) {
491
+ if lock {
492
+ // If we cannot acquire the read lock, then
493
+ // there is either a concurrent read or the close handshake
494
+ // is proceeding.
495
+ select {
496
+ case r .c .readLock <- struct {}{}:
497
+ defer r .c .releaseLock (r .c .readLock )
498
+ default :
499
+ if r .c .closing .Load () == 1 {
500
+ <- r .c .closed
501
+ return 0 , r .c .closeErr
502
+ }
503
+ return 0 , errors .New ("concurrent read detected" )
485
504
}
486
- defer r .c .releaseLock (r .c .readLock )
487
505
}
488
506
489
507
if r .eof () {
490
- return 0 , fmt . Errorf ("cannot use EOFed reader" )
508
+ return 0 , errors . New ("cannot use EOFed reader" )
491
509
}
492
510
493
511
if r .c .readMsgLeft <= 0 {
@@ -950,8 +968,6 @@ func (c *Conn) waitClose() error {
950
968
return c .closeReceived
951
969
}
952
970
953
- c .readerShouldLock = false
954
-
955
971
b := bpool .Get ()
956
972
buf := b .Bytes ()
957
973
buf = buf [:cap (buf )]
@@ -965,7 +981,8 @@ func (c *Conn) waitClose() error {
965
981
}
966
982
}
967
983
968
- _ , err = io .CopyBuffer (ioutil .Discard , c .activeReader , buf )
984
+ r := readerFunc (c .activeReader .readUnlocked )
985
+ _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
969
986
if err != nil {
970
987
return err
971
988
}
@@ -1019,6 +1036,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
1019
1036
}
1020
1037
}
1021
1038
1039
+ type readerFunc func (p []byte ) (int , error )
1040
+
1041
+ func (f readerFunc ) Read (p []byte ) (int , error ) {
1042
+ return f (p )
1043
+ }
1044
+
1022
1045
type writerFunc func (p []byte ) (int , error )
1023
1046
1024
1047
func (f writerFunc ) Write (p []byte ) (int , error ) {
0 commit comments