@@ -56,14 +56,16 @@ type Conn struct {
56
56
// read limit for a message in bytes.
57
57
msgReadLimit int64
58
58
59
+ // Used to ensure a previous writer is not used after being closed.
60
+ activeWriter * messageWriter
59
61
// messageWriter state.
60
62
writeMsgOpcode opcode
61
63
writeMsgCtx context.Context
62
64
readMsgLeft int64
63
65
64
66
// Used to ensure the previous reader is read till EOF before allowing
65
67
// a new one.
66
- previousReader * messageReader
68
+ activeReader * messageReader
67
69
// readFrameLock is acquired to read from bw.
68
70
readFrameLock chan struct {}
69
71
readClosed int64
@@ -358,7 +360,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
358
360
}
359
361
360
362
func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
361
- if c .previousReader != nil && ! c .readFrameEOF {
363
+ if c .activeReader != nil && ! c .readFrameEOF {
362
364
// The only way we know for sure the previous reader is not yet complete is
363
365
// if there is an active frame not yet fully read.
364
366
// Otherwise, a user may have read the last byte but not the EOF if the EOF
@@ -371,7 +373,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
371
373
return 0 , nil , err
372
374
}
373
375
374
- if c .previousReader != nil && ! c .previousReader .eof {
376
+ if c .activeReader != nil && ! c .activeReader .eof () {
375
377
if h .opcode != opContinuation {
376
378
err := xerrors .Errorf ("received new data message without finishing the previous message" )
377
379
c .Close (StatusProtocolError , err .Error ())
@@ -382,7 +384,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
382
384
return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
383
385
}
384
386
385
- c .previousReader . eof = true
387
+ c .activeReader = nil
386
388
387
389
h , err = c .readTillMsg (ctx )
388
390
if err != nil {
@@ -403,7 +405,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
403
405
r := & messageReader {
404
406
c : c ,
405
407
}
406
- c .previousReader = r
408
+ c .activeReader = r
407
409
return MessageType (h .opcode ), r , nil
408
410
}
409
411
@@ -430,8 +432,11 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
430
432
431
433
// messageReader enables reading a data frame from the WebSocket connection.
432
434
type messageReader struct {
433
- c * Conn
434
- eof bool
435
+ c * Conn
436
+ }
437
+
438
+ func (r * messageReader ) eof () bool {
439
+ return r .c .activeReader != r
435
440
}
436
441
437
442
// Read reads as many bytes as possible into p.
@@ -449,7 +454,7 @@ func (r *messageReader) Read(p []byte) (int, error) {
449
454
}
450
455
451
456
func (r * messageReader ) read (p []byte ) (int , error ) {
452
- if r .eof {
457
+ if r .eof () {
453
458
return 0 , xerrors .Errorf ("cannot use EOFed reader" )
454
459
}
455
460
@@ -502,7 +507,7 @@ func (r *messageReader) read(p []byte) (int, error) {
502
507
r .c .readFrameEOF = true
503
508
504
509
if h .fin {
505
- r .eof = true
510
+ r .c . activeReader = nil
506
511
return n , io .EOF
507
512
}
508
513
}
@@ -593,9 +598,11 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
593
598
}
594
599
c .writeMsgCtx = ctx
595
600
c .writeMsgOpcode = opcode (typ )
596
- return & messageWriter {
601
+ w := & messageWriter {
597
602
c : c ,
598
- }, nil
603
+ }
604
+ c .activeWriter = w
605
+ return w , nil
599
606
}
600
607
601
608
// Write is a convenience method to write a message to the connection.
@@ -622,8 +629,11 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
622
629
623
630
// messageWriter enables writing to a WebSocket connection.
624
631
type messageWriter struct {
625
- c * Conn
626
- closed bool
632
+ c * Conn
633
+ }
634
+
635
+ func (w * messageWriter ) closed () bool {
636
+ return w != w .c .activeWriter
627
637
}
628
638
629
639
// Write writes the given bytes to the WebSocket connection.
@@ -636,7 +646,7 @@ func (w *messageWriter) Write(p []byte) (int, error) {
636
646
}
637
647
638
648
func (w * messageWriter ) write (p []byte ) (int , error ) {
639
- if w .closed {
649
+ if w .closed () {
640
650
return 0 , xerrors .Errorf ("cannot use closed writer" )
641
651
}
642
652
n , err := w .c .writeFrame (w .c .writeMsgCtx , false , w .c .writeMsgOpcode , p )
@@ -658,16 +668,17 @@ func (w *messageWriter) Close() error {
658
668
}
659
669
660
670
func (w * messageWriter ) close () error {
661
- if w .closed {
671
+ if w .closed () {
662
672
return xerrors .Errorf ("cannot use closed writer" )
663
673
}
664
- w .closed = true
674
+ w .closed ()
665
675
666
676
_ , err := w .c .writeFrame (w .c .writeMsgCtx , true , w .c .writeMsgOpcode , nil )
667
677
if err != nil {
668
678
return xerrors .Errorf ("failed to write fin frame: %w" , err )
669
679
}
670
680
681
+ w .c .activeWriter = nil
671
682
w .c .releaseLock (w .c .writeMsgLock )
672
683
return nil
673
684
}
0 commit comments