From f628749790bc308f2c978f72b92666bd7fd5ae58 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 15:41:10 -0400 Subject: [PATCH 01/15] Remove readLoop Closes #93 --- README.md | 14 +-- accept.go | 4 - limitedreader.go | 7 +- websocket.go | 228 +++++++++++++++++----------------------------- websocket_test.go | 54 ++--------- wsjson/wsjson.go | 1 + wspb/wspb.go | 1 + 7 files changed, 103 insertions(+), 206 deletions(-) diff --git a/README.md b/README.md index 47cc3296..38541cfd 100644 --- a/README.md +++ b/README.md @@ -123,24 +123,18 @@ it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. Some more advantages of nhooyr/websocket are that it supports concurrent writes and makes it very easy to close the connection with a status code and reason. -nhooyr/websocket also responds to pings, pongs and close frames in a separate goroutine so that -your application doesn't always need to read from the connection unless it expects a data message. -gorilla/websocket requires you to constantly read from the connection to respond to control frames -even if you don't expect the peer to send any messages. - The ping API is also much nicer. gorilla/websocket requires registering a pong handler on the Conn which results in awkward control flow. With nhooyr/websocket you use the Ping method on the Conn that sends a ping and also waits for the pong. In terms of performance, the differences depend on your application code. nhooyr/websocket reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas -gorilla/websocket does not. As mentioned above, nhooyr/websocket also supports concurrent +gorilla/websocket does not at all. As mentioned above, nhooyr/websocket also supports concurrent writers out of the box. -The only performance con to nhooyr/websocket is that uses two extra goroutines. One for -reading pings, pongs and close frames async to application code and another to support -context.Context cancellation. This costs 4 KB of memory which is cheap compared -to the benefits. +The only performance con to nhooyr/websocket is that uses one extra goroutine to support +cancellation with context.Context and the net/http client side body upgrade. +This costs 2 KB of memory which is cheap compared to simplicity benefits. ### x/net/websocket diff --git a/accept.go b/accept.go index bf2ed3c8..e0054b2b 100644 --- a/accept.go +++ b/accept.go @@ -81,9 +81,6 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. In other words, by default it does not allow // cross origin requests. -// -// The returned connection will be bound by r.Context(). Use conn.Context() to change -// the bounding context. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) if err != nil { @@ -143,7 +140,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, closer: netConn, } c.init() - c.Context(r.Context()) return c, nil } diff --git a/limitedreader.go b/limitedreader.go index 63bf40c4..7957e794 100644 --- a/limitedreader.go +++ b/limitedreader.go @@ -1,7 +1,6 @@ package websocket import ( - "fmt" "io" "golang.org/x/xerrors" @@ -20,9 +19,9 @@ func (lr *limitedReader) Read(p []byte) (int, error) { } if lr.left <= 0 { - msg := fmt.Sprintf("read limited at %v bytes", lr.limit) - lr.c.Close(StatusPolicyViolation, msg) - return 0, xerrors.Errorf(msg) + err := xerrors.Errorf("read limited at %v bytes", lr.limit) + lr.c.Close(StatusMessageTooBig, err.Error()) + return 0, err } if int64(len(p)) > lr.left { diff --git a/websocket.go b/websocket.go index 37719932..3553707a 100644 --- a/websocket.go +++ b/websocket.go @@ -28,7 +28,7 @@ type Conn struct { br *bufio.Reader bw *bufio.Writer // writeBuf is used for masking, its the buffer in bufio.Writer. - // Only used by the client. + // Only used by the client for masking the bytes in the buffer. writeBuf []byte closer io.Closer client bool @@ -51,17 +51,9 @@ type Conn struct { previousReader *messageReader // readFrameLock is acquired to read from bw. readFrameLock chan struct{} - // readMsg is used by messageReader to receive frames from - // readLoop. - readMsg chan header - // readMsgDone is used to tell the readLoop to continue after - // messageReader has read a frame. - readMsgDone chan struct{} setReadTimeout chan context.Context setWriteTimeout chan context.Context - setConnContext chan context.Context - getConnContext chan context.Context activePingsMu sync.Mutex activePings map[string]chan<- struct{} @@ -76,13 +68,9 @@ func (c *Conn) init() { c.writeFrameLock = make(chan struct{}, 1) c.readFrameLock = make(chan struct{}, 1) - c.readMsg = make(chan header) - c.readMsgDone = make(chan struct{}) c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) - c.setConnContext = make(chan context.Context) - c.getConnContext = make(chan context.Context) c.activePings = make(map[string]chan<- struct{}) @@ -91,7 +79,6 @@ func (c *Conn) init() { }) go c.timeoutLoop() - go c.readLoop() } // Subprotocol returns the negotiated subprotocol. @@ -131,56 +118,23 @@ func (c *Conn) close(err error) { func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() - parentCtx := context.Background() for { select { case <-c.closed: return + case writeCtx = <-c.setWriteTimeout: case readCtx = <-c.setReadTimeout: + case <-readCtx.Done(): c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err())) case <-writeCtx.Done(): c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err())) - case <-parentCtx.Done(): - c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err())) - return - case parentCtx = <-c.setConnContext: - ctx, cancelCtx := context.WithCancel(parentCtx) - defer cancelCtx() - - select { - case <-c.closed: - return - case c.getConnContext <- ctx: - } } } } -// Context returns a context derived from parent that will be cancelled -// when the connection is closed or broken. -// If the parent context is cancelled, the connection will be closed. -func (c *Conn) Context(parent context.Context) context.Context { - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case c.setConnContext <- parent: - } - - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case ctx := <-c.getConnContext: - return ctx - } -} - func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { select { case <-ctx.Done(): @@ -210,30 +164,9 @@ func (c *Conn) releaseLock(lock chan struct{}) { } } -func (c *Conn) readLoop() { +func (c *Conn) readTillMsg(ctx context.Context) (header, error) { for { - h, err := c.readTillMsg() - if err != nil { - return - } - - select { - case <-c.closed: - return - case c.readMsg <- h: - } - - select { - case <-c.closed: - return - case <-c.readMsgDone: - } - } -} - -func (c *Conn) readTillMsg() (header, error) { - for { - h, err := c.readFrameHeader() + h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } @@ -245,7 +178,10 @@ func (c *Conn) readTillMsg() (header, error) { } if h.opcode.controlOp() { - c.handleControl(h) + err = c.handleControl(ctx, h) + if err != nil { + return header{}, err + } continue } @@ -260,43 +196,64 @@ func (c *Conn) readTillMsg() (header, error) { } } -func (c *Conn) readFrameHeader() (header, error) { +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { err := c.acquireLock(context.Background(), c.readFrameLock) if err != nil { return header{}, err } defer c.releaseLock(c.readFrameLock) + select { + case <-c.closed: + return header{}, c.closeErr + case c.setReadTimeout <- ctx: + } + h, err := readHeader(c.br) if err != nil { + select { + case <-c.closed: + return header{}, c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: + } err := xerrors.Errorf("failed to read header: %w", err) c.releaseLock(c.readFrameLock) c.close(err) return header{}, err } + select { + case <-c.closed: + return header{}, c.closeErr + case c.setReadTimeout <- context.Background(): + } + return h, nil } -func (c *Conn) handleControl(h header) { +func (c *Conn) handleControl(ctx context.Context, h header) error { if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return + err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength) + c.Close(StatusProtocolError, err.Error()) + return err } if !h.fin { - c.Close(StatusProtocolError, "control frame cannot be fragmented") - return + err := xerrors.Errorf("received fragmented control frame") + c.Close(StatusProtocolError, err.Error()) + return err } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() b := make([]byte, h.payloadLength) _, err := c.readFramePayload(ctx, b) if err != nil { - return + return err } if h.masked { @@ -305,7 +262,7 @@ func (c *Conn) handleControl(h header) { switch h.opcode { case opPing: - c.writePong(b) + return c.writePong(b) case opPong: c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] @@ -313,17 +270,20 @@ func (c *Conn) handleControl(h header) { if ok { close(pong) } + return nil case opClose: ce, err := parseClosePayload(b) if err != nil { - c.close(xerrors.Errorf("received invalid close payload: %w", err)) - return + err = xerrors.Errorf("received invalid close payload: %w", err) + c.close(err) + return err } if ce.Code == StatusNoStatusRcvd { c.writeClose(nil, ce) } else { c.Close(ce.Code, ce.Reason) } + return c.closeErr default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } @@ -335,11 +295,10 @@ func (c *Conn) handleControl(h header) { // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // -// Control (ping, pong, close) frames will be handled automatically -// in a separate goroutine so if you do not expect any data messages, -// you do not need to read from the connection. However, if the peer -// sends a data message, further pings, pongs and close frames will not -// be read if you do not read the message from the connection. +// You must read from the connection for close frames to be read. +// If you do not expect any data messages from the peer, just call +// Reader in a separate goroutine and close the connection with StatusPolicyViolation +// when it returns. Example at // TODO // // Only one Reader may be open at a time. // @@ -368,47 +327,39 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, xerrors.Errorf("previous message not read to completion") } - select { - case <-c.closed: - return 0, nil, c.closeErr - case <-ctx.Done(): - return 0, nil, ctx.Err() - case h := <-c.readMsg: - if c.previousReader != nil && !c.previousReader.done { - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err - } - - if !h.fin || h.payloadLength > 0 { - return 0, nil, xerrors.Errorf("previous message not read to completion") - } - - c.previousReader.done = true - - select { - case <-c.closed: - return 0, nil, c.closeErr - case c.readMsgDone <- struct{}{}: - } + h, err := c.readTillMsg(ctx) + if err != nil { + return 0, nil, err + } - return c.reader(ctx) - } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") + if c.previousReader != nil && !c.previousReader.done { + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") c.Close(StatusProtocolError, err.Error()) return 0, nil, err } - r := &messageReader{ - ctx: ctx, - c: c, - - h: &h, + if !h.fin || h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") } - c.previousReader = r - return MessageType(h.opcode), r, nil + + c.previousReader.done = true + + return c.reader(ctx) + } else if h.opcode == opContinuation { + err := xerrors.Errorf("received continuation frame not after data or text frame") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } + + r := &messageReader{ + ctx: ctx, + c: c, + + h: &h, } + c.previousReader = r + return MessageType(h.opcode), r, nil } // messageReader enables reading a data frame from the WebSocket connection. @@ -441,20 +392,17 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.h == nil { - select { - case <-r.c.closed: - return 0, r.c.closeErr - case <-r.ctx.Done(): - r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err())) - return 0, r.ctx.Err() - case h := <-r.c.readMsg: - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data frame without finishing the previous frame") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err - } - r.h = &h + h, err := r.c.readTillMsg(r.ctx) + if err != nil { + return 0, err } + + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data frame without finishing the previous frame") + r.c.Close(StatusProtocolError, err.Error()) + return 0, err + } + r.h = &h } if int64(len(p)) > r.h.payloadLength { @@ -473,12 +421,6 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.h.payloadLength == 0 { - select { - case <-r.c.closed: - return n, r.c.closeErr - case r.c.readMsgDone <- struct{}{}: - } - fin := r.h.fin // Need to nil this as Reader uses it to check @@ -539,7 +481,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { // // By default, the connection has a message read limit of 32768 bytes. // -// When the limit is hit, the connection will be closed with StatusPolicyViolation. +// When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { c.msgReadLimit = n } diff --git a/websocket_test.go b/websocket_test.go index 9d867b50..8d1e7b1d 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -383,6 +383,8 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") + go c.Reader(r.Context()) + err = c.Ping(r.Context()) if err != nil { return err @@ -403,10 +405,10 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") - err = c.Ping(ctx) - if err != nil { - return err - } + errc := make(chan error, 1) + go func() { + errc <- c.Ping(ctx) + }() _, _, err = c.Read(ctx) if err != nil { @@ -414,7 +416,7 @@ func TestHandshake(t *testing.T) { } c.Close(websocket.StatusNormalClosure, "") - return nil + return <-errc }, }, { @@ -439,6 +441,8 @@ func TestHandshake(t *testing.T) { } defer c.Close(websocket.StatusInternalError, "") + go c.Reader(ctx) + err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) if err != nil { return err @@ -454,46 +458,6 @@ func TestHandshake(t *testing.T) { return nil }, }, - { - name: "context", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second) - defer cancel() - - c.Context(ctx) - - for r.Context().Err() == nil { - err = c.Ping(ctx) - if err != nil { - return nil - } - } - - return xerrors.Errorf("all pings succeeded") - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - cctx := c.Context(ctx) - - select { - case <-ctx.Done(): - return xerrors.Errorf("child context never cancelled") - case <-cctx.Done(): - return nil - } - }, - }, } for _, tc := range testCases { diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 19e3e6d7..b72d562f 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -44,6 +44,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { err = json.Unmarshal(b.Bytes(), v) if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") return xerrors.Errorf("failed to unmarshal json: %w", err) } diff --git a/wspb/wspb.go b/wspb/wspb.go index 49c2ae54..56b14ee8 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -46,6 +46,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err = proto.Unmarshal(b.Bytes(), v) if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") return xerrors.Errorf("failed to unmarshal protobuf: %w", err) } From 5add79dcd311c286a695a022058d21e46bbc534c Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 17:21:27 -0400 Subject: [PATCH 02/15] Simplify and improve error messages --- internal/bpool/bpool_test.go | 1 - websocket.go | 36 +++++++++++++++++++----------------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/internal/bpool/bpool_test.go b/internal/bpool/bpool_test.go index 2b302a47..5dfe56e6 100644 --- a/internal/bpool/bpool_test.go +++ b/internal/bpool/bpool_test.go @@ -32,7 +32,6 @@ func BenchmarkSyncPool(b *testing.B) { p := sync.Pool{} - b.ResetTimer() for i := 0; i < b.N; i++ { buf := p.Get() if buf == nil { diff --git a/websocket.go b/websocket.go index 3553707a..91197537 100644 --- a/websocket.go +++ b/websocket.go @@ -180,7 +180,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { if h.opcode.controlOp() { err = c.handleControl(ctx, h) if err != nil { - return header{}, err + return header{}, xerrors.Errorf("failed to handle control frame: %w", err) } continue } @@ -274,15 +274,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { case opClose: ce, err := parseClosePayload(b) if err != nil { - err = xerrors.Errorf("received invalid close payload: %w", err) - c.close(err) - return err - } - if ce.Code == StatusNoStatusRcvd { - c.writeClose(nil, ce) - } else { - c.Close(ce.Code, ce.Reason) + c.Close(StatusProtocolError, "received invalid close payload") + return xerrors.Errorf("received invalid close payload: %w", err) } + c.writeClose(b, ce, false) return c.closeErr default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) @@ -398,7 +393,7 @@ func (r *messageReader) read(p []byte) (int, error) { } if h.opcode != opContinuation { - err := xerrors.Errorf("received new data frame without finishing the previous frame") + err := xerrors.Errorf("received new data message without finishing the previous message") r.c.Close(StatusProtocolError, err.Error()) return 0, err } @@ -461,7 +456,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { err = ctx.Err() default: } - err = xerrors.Errorf("failed to read from connection: %w", err) + err = xerrors.Errorf("failed to read frame payload: %w", err) c.releaseLock(c.readFrameLock) c.close(err) return n, err @@ -651,7 +646,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte default: } - err = xerrors.Errorf("failed to write to connection: %w", err) + err = xerrors.Errorf("failed to write frame: %w", err) // We need to release the lock first before closing the connection to ensure // the lock can be acquired inside close to ensure no one can access c.bw. c.releaseLock(c.writeFrameLock) @@ -764,20 +759,27 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { p, _ = ce.bytes() } - return c.writeClose(p, ce) + return c.writeClose(p, ce, true) } -func (c *Conn) writeClose(p []byte, cerr CloseError) error { +func (c *Conn) writeClose(p []byte, err error, us bool) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opClose, p) + // If this fails, the connection had to have died. + err = c.writeControl(ctx, opClose, p) if err != nil { return err } - c.close(cerr) - if !xerrors.Is(c.closeErr, cerr) { + if us { + err = xerrors.Errorf("sent close frame: %w", err) + } else { + err = xerrors.Errorf("received close frame: %w", err) + } + + c.close(err) + if !xerrors.Is(c.closeErr, err) { return c.closeErr } From 5404d35122bc0e869b45a4b90330a6886ebb1d2d Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 17:33:41 -0400 Subject: [PATCH 03/15] Make CI pass --- websocket.go | 12 ++++++------ websocket_test.go | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/websocket.go b/websocket.go index 91197537..0e263dd3 100644 --- a/websocket.go +++ b/websocket.go @@ -762,24 +762,24 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { return c.writeClose(p, ce, true) } -func (c *Conn) writeClose(p []byte, err error, us bool) error { +func (c *Conn) writeClose(p []byte, cerr error, us bool) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() // If this fails, the connection had to have died. - err = c.writeControl(ctx, opClose, p) + err := c.writeControl(ctx, opClose, p) if err != nil { return err } if us { - err = xerrors.Errorf("sent close frame: %w", err) + cerr = xerrors.Errorf("sent close frame: %w", cerr) } else { - err = xerrors.Errorf("received close frame: %w", err) + cerr = xerrors.Errorf("received close frame: %w", cerr) } - c.close(err) - if !xerrors.Is(c.closeErr, err) { + c.close(cerr) + if !xerrors.Is(c.closeErr, cerr) { return c.closeErr } diff --git a/websocket_test.go b/websocket_test.go index 8d1e7b1d..dd72afb4 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -451,7 +451,7 @@ func TestHandshake(t *testing.T) { err = c.Ping(ctx) var ce websocket.CloseError - if !xerrors.As(err, &ce) || ce.Code != websocket.StatusPolicyViolation { + if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig { return xerrors.Errorf("unexpected error: %w", err) } From df60edfe1d248784520291315436581829b8d998 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 17:38:28 -0400 Subject: [PATCH 04/15] Reuse header and control payload buffers --- header.go | 13 ++++++++++--- header_test.go | 4 ++-- websocket.go | 11 ++++++++--- 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/header.go b/header.go index 62b30b38..b1aa2554 100644 --- a/header.go +++ b/header.go @@ -75,12 +75,19 @@ func marshalHeader(h header) []byte { return b } +func makeHeaderBuf() []byte { + return make([]byte, maxHeaderSize-2) +} + // readHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func readHeader(r io.Reader) (header, error) { - // We read the first two bytes directly so that we know +func readHeader(b []byte, r io.Reader) (header, error) { + if b == nil { + b = makeHeaderBuf() + } + // We read the first two bytes first so that we know // exactly how long the header is. - b := make([]byte, 2, maxHeaderSize-2) + b = b[:2] _, err := io.ReadFull(r, b) if err != nil { return header{}, err diff --git a/header_test.go b/header_test.go index b9cf351b..78d61899 100644 --- a/header_test.go +++ b/header_test.go @@ -32,7 +32,7 @@ func TestHeader(t *testing.T) { b[2] |= 1 << 7 r := bytes.NewReader(b) - _, err := readHeader(r) + _, err := readHeader(nil, r) if err == nil { t.Fatalf("unexpected error value: %+v", err) } @@ -92,7 +92,7 @@ func TestHeader(t *testing.T) { func testHeader(t *testing.T, h header) { b := marshalHeader(h) r := bytes.NewReader(b) - h2, err := readHeader(r) + h2, err := readHeader(nil, r) if err != nil { t.Logf("header: %#v", h) t.Logf("bytes: %b", b) diff --git a/websocket.go b/websocket.go index 0e263dd3..ebe12597 100644 --- a/websocket.go +++ b/websocket.go @@ -57,6 +57,9 @@ type Conn struct { activePingsMu sync.Mutex activePings map[string]chan<- struct{} + + headerBuf []byte + controlPayloadBuf []byte } func (c *Conn) init() { @@ -74,6 +77,9 @@ func (c *Conn) init() { c.activePings = make(map[string]chan<- struct{}) + c.headerBuf = makeHeaderBuf() + c.controlPayloadBuf = make([]byte, maxControlFramePayload) + runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) }) @@ -209,7 +215,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.setReadTimeout <- ctx: } - h, err := readHeader(c.br) + h, err := readHeader(c.headerBuf, c.br) if err != nil { select { case <-c.closed: @@ -249,8 +255,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - b := make([]byte, h.payloadLength) - + b := c.controlPayloadBuf[:h.payloadLength] _, err := c.readFramePayload(ctx, b) if err != nil { return err From ee1f3c601b22b62b9f82c2a2646c9611c1ea838e Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 18:14:56 -0400 Subject: [PATCH 05/15] Reuse write and read header buffers Next is reusing the header structures. --- header.go | 18 ++++++++++++++---- header_test.go | 4 ++-- websocket.go | 19 ++++++++++--------- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/header.go b/header.go index b1aa2554..16ab6474 100644 --- a/header.go +++ b/header.go @@ -31,10 +31,19 @@ type header struct { maskKey [4]byte } +func makeWriteHeaderBuf() []byte { + return make([]byte, maxHeaderSize) +} + // bytes returns the bytes of the header. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func marshalHeader(h header) []byte { - b := make([]byte, 2, maxHeaderSize) +func writeHeader(b []byte, h header) []byte { + if b == nil { + b = makeWriteHeaderBuf() + } + + b = b[:2] + b[0] = 0 if h.fin { b[0] |= 1 << 7 @@ -75,7 +84,7 @@ func marshalHeader(h header) []byte { return b } -func makeHeaderBuf() []byte { +func makeReadHeaderBuf() []byte { return make([]byte, maxHeaderSize-2) } @@ -83,8 +92,9 @@ func makeHeaderBuf() []byte { // See https://tools.ietf.org/html/rfc6455#section-5.2 func readHeader(b []byte, r io.Reader) (header, error) { if b == nil { - b = makeHeaderBuf() + b = makeReadHeaderBuf() } + // We read the first two bytes first so that we know // exactly how long the header is. b = b[:2] diff --git a/header_test.go b/header_test.go index 78d61899..b45854ea 100644 --- a/header_test.go +++ b/header_test.go @@ -24,7 +24,7 @@ func TestHeader(t *testing.T) { t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() - b := marshalHeader(header{ + b := writeHeader(nil, header{ payloadLength: 1<<16 + 1, }) @@ -90,7 +90,7 @@ func TestHeader(t *testing.T) { } func testHeader(t *testing.T, h header) { - b := marshalHeader(h) + b := writeHeader(nil, h) r := bytes.NewReader(b) h2, err := readHeader(nil, r) if err != nil { diff --git a/websocket.go b/websocket.go index ebe12597..375685e7 100644 --- a/websocket.go +++ b/websocket.go @@ -45,21 +45,21 @@ type Conn struct { // writeFrameLock is acquired to write a single frame. // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} + writeHeaderBuf []byte // Used to ensure the previous reader is read till EOF before allowing // a new one. previousReader *messageReader // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} + readFrameLock chan struct{} + readHeaderBuf []byte + controlPayloadBuf []byte setReadTimeout chan context.Context setWriteTimeout chan context.Context activePingsMu sync.Mutex activePings map[string]chan<- struct{} - - headerBuf []byte - controlPayloadBuf []byte } func (c *Conn) init() { @@ -77,7 +77,8 @@ func (c *Conn) init() { c.activePings = make(map[string]chan<- struct{}) - c.headerBuf = makeHeaderBuf() + c.writeHeaderBuf = makeWriteHeaderBuf() + c.readHeaderBuf = makeReadHeaderBuf() c.controlPayloadBuf = make([]byte, maxControlFramePayload) runtime.SetFinalizer(c, func(c *Conn) { @@ -215,7 +216,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.setReadTimeout <- ctx: } - h, err := readHeader(c.headerBuf, c.br) + h, err := readHeader(c.readHeaderBuf, c.br) if err != nil { select { case <-c.closed: @@ -628,7 +629,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } } - b2 := marshalHeader(h) + headerBytes := writeHeader(c.writeHeaderBuf, h) err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { @@ -651,7 +652,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte default: } - err = xerrors.Errorf("failed to write frame: %w", err) + err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err) // We need to release the lock first before closing the connection to ensure // the lock can be acquired inside close to ensure no one can access c.bw. c.releaseLock(c.writeFrameLock) @@ -660,7 +661,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte return err } - _, err = c.bw.Write(b2) + _, err = c.bw.Write(headerBytes) if err != nil { return 0, writeErr(err) } From 4357cbf9cd4e7f80e18aac7d14d0ae9c77aced75 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 18:25:15 -0400 Subject: [PATCH 06/15] Add WriteOnly example --- example_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ websocket.go | 5 ++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/example_test.go b/example_test.go index 57f0aa5e..bc10209e 100644 --- a/example_test.go +++ b/example_test.go @@ -59,3 +59,45 @@ func ExampleDial() { c.Close(websocket.StatusNormalClosure, "") } + +func ExampleWriteOnly() { + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + log.Println(err) + return + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) + defer cancel() + + go func() { + defer cancel() + _, _, err := c.Reader(ctx) + if err == nil { + c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages") + } + }() + + t := time.NewTicker(time.Second * 30) + defer t.Stop() + + for { + select { + case <-ctx.Done(): + c.Close(websocket.StatusNormalClosure, "") + return + case <-t.C: + err = wsjson.Write(ctx, c, "hi") + if err != nil { + log.Println(err) + return + } + } + } + }) + + err := http.ListenAndServe("localhost:8080", fn) + log.Fatal(err) +} diff --git a/websocket.go b/websocket.go index 375685e7..bd087d51 100644 --- a/websocket.go +++ b/websocket.go @@ -21,6 +21,9 @@ import ( // All methods may be called concurrently except for Reader, Read // and SetReadLimit. // +// You must always read from the connection. Otherwise control +// frames will not be handled. See the docs on Reader. +// // Please be sure to call Close on the connection when you // are finished with it to release the associated resources. type Conn struct { @@ -299,7 +302,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { // You must read from the connection for close frames to be read. // If you do not expect any data messages from the peer, just call // Reader in a separate goroutine and close the connection with StatusPolicyViolation -// when it returns. Example at // TODO +// when it returns. See the WriteOnly example. // // Only one Reader may be open at a time. // From 0ed9c744fa64fa1b20caaa0394e592298a36bd80 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 18:25:58 -0400 Subject: [PATCH 07/15] Fix race in writeFrame --- websocket.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/websocket.go b/websocket.go index bd087d51..70f40756 100644 --- a/websocket.go +++ b/websocket.go @@ -632,8 +632,6 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } } - headerBytes := writeHeader(c.writeHeaderBuf, h) - err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { return 0, err @@ -664,6 +662,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte return err } + headerBytes := writeHeader(c.writeHeaderBuf, h) _, err = c.bw.Write(headerBytes) if err != nil { return 0, writeErr(err) From d76d893a4a4421d406bff6712c7699658eb59f45 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 00:27:30 -0400 Subject: [PATCH 08/15] Improve write structure --- websocket.go | 92 ++++++++++++++++++++++++++++------------------------ 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/websocket.go b/websocket.go index 70f40756..0c525b58 100644 --- a/websocket.go +++ b/websocket.go @@ -286,7 +286,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { c.Close(StatusProtocolError, "received invalid close payload") return xerrors.Errorf("received invalid close payload: %w", err) } - c.writeClose(b, ce, false) + c.writeClose(b, xerrors.Errorf("received close frame: %w", ce)) return c.closeErr default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) @@ -644,38 +644,54 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte case c.setWriteTimeout <- ctx: } - writeErr := func(err error) error { - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - - err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) + n, err := c.realWriteFrame(ctx, h, p) + if err != nil { + return n, err + } - return err + // We already finished writing, no need to potentially brick the connection if + // the context expires. + select { + case <-c.closed: + return n, c.closeErr + case c.setWriteTimeout <- context.Background(): } + return n, nil +} + +func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error){ + defer func() { + if err != nil { + select { + case <-c.closed: + err = c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: + } + + err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err) + // We need to release the lock first before closing the connection to ensure + // the lock can be acquired inside close to ensure no one can access c.bw. + c.releaseLock(c.writeFrameLock) + c.close(err) + } + }() + headerBytes := writeHeader(c.writeHeaderBuf, h) _, err = c.bw.Write(headerBytes) if err != nil { - return 0, writeErr(err) + return 0, err } - var n int if c.client { var keypos int for len(p) > 0 { if c.bw.Available() == 0 { err = c.bw.Flush() if err != nil { - return n, writeErr(err) + return n, err } } @@ -689,7 +705,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte n2, err := c.bw.Write(p2) if err != nil { - return n, writeErr(err) + return n, err } keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2]) @@ -700,25 +716,17 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } else { n, err = c.bw.Write(p) if err != nil { - return n, writeErr(err) + return n, err } } - if fin { + if h.fin { err = c.bw.Flush() if err != nil { - return n, writeErr(err) + return n, err } } - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return n, c.closeErr - case c.setWriteTimeout <- context.Background(): - } - return n, nil } @@ -767,10 +775,19 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { p, _ = ce.bytes() } - return c.writeClose(p, ce, true) + err = c.writeClose(p, xerrors.Errorf("sent close frame: %w", ce)) + if err != nil { + return err + } + + if !xerrors.Is(c.closeErr, ce) { + return c.closeErr + } + + return nil } -func (c *Conn) writeClose(p []byte, cerr error, us bool) error { +func (c *Conn) writeClose(p []byte, cerr error) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -780,16 +797,7 @@ func (c *Conn) writeClose(p []byte, cerr error, us bool) error { return err } - if us { - cerr = xerrors.Errorf("sent close frame: %w", cerr) - } else { - cerr = xerrors.Errorf("received close frame: %w", cerr) - } - c.close(cerr) - if !xerrors.Is(c.closeErr, cerr) { - return c.closeErr - } return nil } From 4234de22a59c15774a05f73580f34f0d1c71b86b Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 00:27:44 -0400 Subject: [PATCH 09/15] Fix docs --- example_test.go | 4 +++- websocket.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/example_test.go b/example_test.go index bc10209e..eef0e98d 100644 --- a/example_test.go +++ b/example_test.go @@ -60,7 +60,9 @@ func ExampleDial() { c.Close(websocket.StatusNormalClosure, "") } -func ExampleWriteOnly() { +// This example shows how to correctly handle a WebSocket connection +// on which you will only write and do not expect to read data messages. +func Example_writeOnly() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { diff --git a/websocket.go b/websocket.go index 0c525b58..8d6088c8 100644 --- a/websocket.go +++ b/websocket.go @@ -302,7 +302,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { // You must read from the connection for close frames to be read. // If you do not expect any data messages from the peer, just call // Reader in a separate goroutine and close the connection with StatusPolicyViolation -// when it returns. See the WriteOnly example. +// when it returns. See the writeOnly example. // // Only one Reader may be open at a time. // From 029e4124defb2c9c845979a53ffe849bbf6ef926 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 00:32:01 -0400 Subject: [PATCH 10/15] Fix CI --- websocket.go | 2 +- websocket_test.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/websocket.go b/websocket.go index 8d6088c8..129f82ff 100644 --- a/websocket.go +++ b/websocket.go @@ -660,7 +660,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte return n, nil } -func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error){ +func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) { defer func() { if err != nil { select { diff --git a/websocket_test.go b/websocket_test.go index dd72afb4..17444642 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -415,8 +415,9 @@ func TestHandshake(t *testing.T) { return err } + err = <-errc c.Close(websocket.StatusNormalClosure, "") - return <-errc + return err }, }, { From 0b8b974d4148b600a9d1817738768e0987ea8fcf Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 03:42:53 -0400 Subject: [PATCH 11/15] Reduce allocation overhea to absolute minimum Can't go any lower than this afaict. 16 bytes per Writer and 24 bytes per Reader. go tool pprof agrees with me on bytes per op but says the allocs per op are 3 instead of 4 and thinks echoLoop is allocating. I don't know. Lots of cleanup can be performed. Closes #95 --- limitedreader.go | 33 ----------- websocket.go | 142 ++++++++++++++++++++++++++-------------------- websocket_test.go | 2 +- xor_test.go | 6 ++ 4 files changed, 86 insertions(+), 97 deletions(-) delete mode 100644 limitedreader.go diff --git a/limitedreader.go b/limitedreader.go deleted file mode 100644 index 7957e794..00000000 --- a/limitedreader.go +++ /dev/null @@ -1,33 +0,0 @@ -package websocket - -import ( - "io" - - "golang.org/x/xerrors" -) - -type limitedReader struct { - c *Conn - r io.Reader - left int64 - limit int64 -} - -func (lr *limitedReader) Read(p []byte) (int, error) { - if lr.limit == 0 { - lr.limit = lr.left - } - - if lr.left <= 0 { - err := xerrors.Errorf("read limited at %v bytes", lr.limit) - lr.c.Close(StatusMessageTooBig, err.Error()) - return 0, err - } - - if int64(len(p)) > lr.left { - p = p[:lr.left] - } - n, err := lr.r.Read(p) - lr.left -= int64(n) - return n, err -} diff --git a/websocket.go b/websocket.go index 129f82ff..2efc485d 100644 --- a/websocket.go +++ b/websocket.go @@ -49,6 +49,11 @@ type Conn struct { // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} writeHeaderBuf []byte + writeHeader *header + + // messageWriter state. + writeMsgOpcode opcode + writeMsgCtx context.Context // Used to ensure the previous reader is read till EOF before allowing // a new one. @@ -58,6 +63,12 @@ type Conn struct { readHeaderBuf []byte controlPayloadBuf []byte + // messageReader state + readMsgCtx context.Context + readMsgHeader header + readFrameEOF bool + readMaskPos int + setReadTimeout chan context.Context setWriteTimeout chan context.Context @@ -81,6 +92,7 @@ func (c *Conn) init() { c.activePings = make(map[string]chan<- struct{}) c.writeHeaderBuf = makeWriteHeaderBuf() + c.writeHeader = &header{} c.readHeaderBuf = makeReadHeaderBuf() c.controlPayloadBuf = make([]byte, maxControlFramePayload) @@ -315,15 +327,11 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - return typ, &limitedReader{ - c: c, - r: r, - left: c.msgReadLimit, - }, nil + return typ, r, nil } func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.previousReader != nil && c.previousReader.h != nil { + if c.previousReader != nil && !c.readFrameEOF { // The only way we know for sure the previous reader is not yet complete is // if there is an active frame not yet fully read. // Otherwise, a user may have read the last byte but not the EOF if the EOF @@ -336,7 +344,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, err } - if c.previousReader != nil && !c.previousReader.done { + if c.previousReader != nil && !c.previousReader.eof { if h.opcode != opContinuation { err := xerrors.Errorf("received new data message without finishing the previous message") c.Close(StatusProtocolError, err.Error()) @@ -347,20 +355,26 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, xerrors.Errorf("previous message not read to completion") } - c.previousReader.done = true + c.previousReader.eof = true - return c.reader(ctx) + h, err = c.readTillMsg(ctx) + if err != nil { + return 0, nil, err + } } else if h.opcode == opContinuation { err := xerrors.Errorf("received continuation frame not after data or text frame") c.Close(StatusProtocolError, err.Error()) return 0, nil, err } - r := &messageReader{ - ctx: ctx, - c: c, + c.readMsgCtx = ctx + c.readMsgHeader = h + c.readFrameEOF = false + c.readMaskPos = 0 - h: &h, + r := &messageReader{ + c: c, + left: c.msgReadLimit, } c.previousReader = r return MessageType(h.opcode), r, nil @@ -368,12 +382,9 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { - ctx context.Context - c *Conn - - h *header - maskPos int - done bool + c *Conn + left int64 + eof bool } // Read reads as many bytes as possible into p. @@ -391,12 +402,22 @@ func (r *messageReader) Read(p []byte) (int, error) { } func (r *messageReader) read(p []byte) (int, error) { - if r.done { + if r.eof { return 0, xerrors.Errorf("cannot use EOFed reader") } - if r.h == nil { - h, err := r.c.readTillMsg(r.ctx) + if r.left <= 0 { + err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) + r.c.Close(StatusMessageTooBig, err.Error()) + return 0, err + } + + if int64(len(p)) > r.left { + p = p[:r.left] + } + + if r.c.readFrameEOF { + h, err := r.c.readTillMsg(r.c.readMsgCtx) if err != nil { return 0, err } @@ -406,38 +427,37 @@ func (r *messageReader) read(p []byte) (int, error) { r.c.Close(StatusProtocolError, err.Error()) return 0, err } - r.h = &h + + r.c.readMsgHeader = h + r.c.readFrameEOF = false + r.c.readMaskPos = 0 } - if int64(len(p)) > r.h.payloadLength { - p = p[:r.h.payloadLength] + h := r.c.readMsgHeader + if int64(len(p)) > h.payloadLength { + p = p[:h.payloadLength] } - n, err := r.c.readFramePayload(r.ctx, p) + n, err := r.c.readFramePayload(r.c.readMsgCtx, p) - r.h.payloadLength -= int64(n) - if r.h.masked { - r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + h.payloadLength -= int64(n) + r.left -= int64(n) + if h.masked { + r.c.readMaskPos = fastXOR(h.maskKey, r.c.readMaskPos, p) } + r.c.readMsgHeader = h if err != nil { return n, err } - if r.h.payloadLength == 0 { - fin := r.h.fin - - // Need to nil this as Reader uses it to check - // whether there is active data on the previous reader and - // now there isn't. - r.h = nil + if h.payloadLength == 0 { + r.c.readFrameEOF = true - if fin { - r.done = true + if h.fin { + r.eof = true return n, io.EOF } - - r.maskPos = 0 } return n, nil @@ -524,10 +544,10 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err if err != nil { return nil, err } + c.writeMsgCtx = ctx + c.writeMsgOpcode = opcode(typ) return &messageWriter{ - ctx: ctx, - opcode: opcode(typ), - c: c, + c: c, }, nil } @@ -556,8 +576,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error // messageWriter enables writing to a WebSocket connection. type messageWriter struct { - ctx context.Context - opcode opcode c *Conn closed bool } @@ -575,11 +593,11 @@ func (w *messageWriter) write(p []byte) (int, error) { if w.closed { return 0, xerrors.Errorf("cannot use closed writer") } - n, err := w.c.writeFrame(w.ctx, false, w.opcode, p) + n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p) if err != nil { return n, xerrors.Errorf("failed to write data frame: %w", err) } - w.opcode = opContinuation + w.c.writeMsgOpcode = opContinuation return n, nil } @@ -599,7 +617,7 @@ func (w *messageWriter) close() error { } w.closed = true - _, err := w.c.writeFrame(w.ctx, true, w.opcode, nil) + _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } @@ -618,20 +636,6 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // writeFrame handles all writes to the connection. func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - h := header{ - fin: fin, - opcode: opcode, - masked: c.client, - payloadLength: int64(len(p)), - } - - if c.client { - _, err := io.ReadFull(cryptorand.Reader, h.maskKey[:]) - if err != nil { - return 0, xerrors.Errorf("failed to generate masking key: %w", err) - } - } - err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { return 0, err @@ -644,7 +648,19 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte case c.setWriteTimeout <- ctx: } - n, err := c.realWriteFrame(ctx, h, p) + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.masked = c.client + c.writeHeader.payloadLength = int64(len(p)) + + if c.client { + _, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:]) + if err != nil { + return 0, xerrors.Errorf("failed to generate masking key: %w", err) + } + } + + n, err := c.realWriteFrame(ctx, *c.writeHeader, p) if err != nil { return n, err } diff --git a/websocket_test.go b/websocket_test.go index 17444642..adcc8aeb 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -879,7 +879,7 @@ func BenchmarkConn(b *testing.B) { b.Run("echo", func(b *testing.B) { for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) + benchConn(b, true, false, size) }) } }) diff --git a/xor_test.go b/xor_test.go index 634af606..be766227 100644 --- a/xor_test.go +++ b/xor_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "strconv" "testing" + "unsafe" "github.com/google/go-cmp/cmp" ) @@ -80,3 +81,8 @@ func BenchmarkXOR(b *testing.B) { }) } } + +func TestFoo(t *testing.T) { + t.Log(unsafe.Sizeof(messageWriter{})) + t.Log(unsafe.Sizeof(messageReader{})) +} From 5eff0e397a3a46c3e3e8cd51cb531b4619f5ea02 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 03:47:46 -0400 Subject: [PATCH 12/15] Update performance comparison --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 38541cfd..1ba912a4 100644 --- a/README.md +++ b/README.md @@ -127,10 +127,9 @@ The ping API is also much nicer. gorilla/websocket requires registering a pong h which results in awkward control flow. With nhooyr/websocket you use the Ping method on the Conn that sends a ping and also waits for the pong. -In terms of performance, the differences depend on your application code. nhooyr/websocket -reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas -gorilla/websocket does not at all. As mentioned above, nhooyr/websocket also supports concurrent -writers out of the box. +In terms of performance, the differences mostly depend on your application code. nhooyr/websocket +reuses message buffers out of the box if you use the wsjson and wspb subpackages. +As mentioned above, nhooyr/websocket also supports concurrent writers. The only performance con to nhooyr/websocket is that uses one extra goroutine to support cancellation with context.Context and the net/http client side body upgrade. From 7b05f53672061cd623e64515a6b211001e964ad1 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 10:27:57 -0400 Subject: [PATCH 13/15] Fix writeOnly example --- example_test.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/example_test.go b/example_test.go index eef0e98d..050af907 100644 --- a/example_test.go +++ b/example_test.go @@ -76,10 +76,8 @@ func Example_writeOnly() { go func() { defer cancel() - _, _, err := c.Reader(ctx) - if err == nil { - c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages") - } + c.Reader(ctx) + c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages") }() t := time.NewTicker(time.Second * 30) From 3e007c6d1af687b6460b37635067b6b25ea86f8a Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 10:49:05 -0400 Subject: [PATCH 14/15] Remove unneeded Foo test --- websocket_test.go | 6 +++--- xor_test.go | 6 ------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/websocket_test.go b/websocket_test.go index adcc8aeb..5209e2d7 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -809,7 +809,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { defer c.Close(websocket.StatusInternalError, "") msg := []byte(strings.Repeat("2", size)) - buf := make([]byte, len(msg)) + readBuf := make([]byte, len(msg)) b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() @@ -842,7 +842,7 @@ func benchConn(b *testing.B, echo, stream bool, size int) { b.Fatal(err) } - _, err = io.ReadFull(r, buf) + _, err = io.ReadFull(r, readBuf) if err != nil { b.Fatal(err) } @@ -879,7 +879,7 @@ func BenchmarkConn(b *testing.B) { b.Run("echo", func(b *testing.B) { for _, size := range sizes { b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, false, size) + benchConn(b, false, false, size) }) } }) diff --git a/xor_test.go b/xor_test.go index be766227..634af606 100644 --- a/xor_test.go +++ b/xor_test.go @@ -4,7 +4,6 @@ import ( "crypto/rand" "strconv" "testing" - "unsafe" "github.com/google/go-cmp/cmp" ) @@ -81,8 +80,3 @@ func BenchmarkXOR(b *testing.B) { }) } } - -func TestFoo(t *testing.T) { - t.Log(unsafe.Sizeof(messageWriter{})) - t.Log(unsafe.Sizeof(messageReader{})) -} From 73d39e21ebe2c59ccc70a68831808b4058367251 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 10 Jun 2019 11:08:32 -0400 Subject: [PATCH 15/15] Fix a error resp in Accept --- accept.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/accept.go b/accept.go index e0054b2b..ca1eeeaf 100644 --- a/accept.go +++ b/accept.go @@ -106,7 +106,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, hj, ok := w.(http.Hijacker) if !ok { err = xerrors.New("passed ResponseWriter does not implement http.Hijacker") - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err }