Skip to content

Commit 63f27e2

Browse files
committed
Reduce Reader/Writer allocations
Closes #116
1 parent 80ddbb4 commit 63f27e2

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

websocket.go

+27-16
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,16 @@ type Conn struct {
5656
// read limit for a message in bytes.
5757
msgReadLimit int64
5858

59+
// Used to ensure a previous writer is not used after being closed.
60+
activeWriter *messageWriter
5961
// messageWriter state.
6062
writeMsgOpcode opcode
6163
writeMsgCtx context.Context
6264
readMsgLeft int64
6365

6466
// Used to ensure the previous reader is read till EOF before allowing
6567
// a new one.
66-
previousReader *messageReader
68+
activeReader *messageReader
6769
// readFrameLock is acquired to read from bw.
6870
readFrameLock chan struct{}
6971
readClosed int64
@@ -358,7 +360,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
358360
}
359361

360362
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
361-
if c.previousReader != nil && !c.readFrameEOF {
363+
if c.activeReader != nil && !c.readFrameEOF {
362364
// The only way we know for sure the previous reader is not yet complete is
363365
// if there is an active frame not yet fully read.
364366
// Otherwise, a user may have read the last byte but not the EOF if the EOF
@@ -371,7 +373,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
371373
return 0, nil, err
372374
}
373375

374-
if c.previousReader != nil && !c.previousReader.eof {
376+
if c.activeReader != nil && !c.activeReader.eof() {
375377
if h.opcode != opContinuation {
376378
err := xerrors.Errorf("received new data message without finishing the previous message")
377379
c.Close(StatusProtocolError, err.Error())
@@ -382,7 +384,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
382384
return 0, nil, xerrors.Errorf("previous message not read to completion")
383385
}
384386

385-
c.previousReader.eof = true
387+
c.activeReader = nil
386388

387389
h, err = c.readTillMsg(ctx)
388390
if err != nil {
@@ -403,7 +405,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
403405
r := &messageReader{
404406
c: c,
405407
}
406-
c.previousReader = r
408+
c.activeReader = r
407409
return MessageType(h.opcode), r, nil
408410
}
409411

@@ -430,8 +432,11 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
430432

431433
// messageReader enables reading a data frame from the WebSocket connection.
432434
type messageReader struct {
433-
c *Conn
434-
eof bool
435+
c *Conn
436+
}
437+
438+
func (r *messageReader) eof() bool {
439+
return r.c.activeReader != r
435440
}
436441

437442
// Read reads as many bytes as possible into p.
@@ -449,7 +454,7 @@ func (r *messageReader) Read(p []byte) (int, error) {
449454
}
450455

451456
func (r *messageReader) read(p []byte) (int, error) {
452-
if r.eof {
457+
if r.eof() {
453458
return 0, xerrors.Errorf("cannot use EOFed reader")
454459
}
455460

@@ -502,7 +507,7 @@ func (r *messageReader) read(p []byte) (int, error) {
502507
r.c.readFrameEOF = true
503508

504509
if h.fin {
505-
r.eof = true
510+
r.c.activeReader = nil
506511
return n, io.EOF
507512
}
508513
}
@@ -593,9 +598,11 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
593598
}
594599
c.writeMsgCtx = ctx
595600
c.writeMsgOpcode = opcode(typ)
596-
return &messageWriter{
601+
w := &messageWriter{
597602
c: c,
598-
}, nil
603+
}
604+
c.activeWriter = w
605+
return w, nil
599606
}
600607

601608
// Write is a convenience method to write a message to the connection.
@@ -622,8 +629,11 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
622629

623630
// messageWriter enables writing to a WebSocket connection.
624631
type messageWriter struct {
625-
c *Conn
626-
closed bool
632+
c *Conn
633+
}
634+
635+
func (w *messageWriter) closed() bool {
636+
return w != w.c.activeWriter
627637
}
628638

629639
// Write writes the given bytes to the WebSocket connection.
@@ -636,7 +646,7 @@ func (w *messageWriter) Write(p []byte) (int, error) {
636646
}
637647

638648
func (w *messageWriter) write(p []byte) (int, error) {
639-
if w.closed {
649+
if w.closed() {
640650
return 0, xerrors.Errorf("cannot use closed writer")
641651
}
642652
n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p)
@@ -658,16 +668,17 @@ func (w *messageWriter) Close() error {
658668
}
659669

660670
func (w *messageWriter) close() error {
661-
if w.closed {
671+
if w.closed() {
662672
return xerrors.Errorf("cannot use closed writer")
663673
}
664-
w.closed = true
674+
w.closed()
665675

666676
_, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil)
667677
if err != nil {
668678
return xerrors.Errorf("failed to write fin frame: %w", err)
669679
}
670680

681+
w.c.activeWriter = nil
671682
w.c.releaseLock(w.c.writeMsgLock)
672683
return nil
673684
}

0 commit comments

Comments
 (0)