From 8604dee32ef3a441729179f610eb3dce2e40c5ff Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 14 Oct 2019 21:31:46 -0400 Subject: [PATCH 01/55] Increase TestWASM timeout --- conn_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn_test.go b/conn_test.go index 83f09dbf..d03a7214 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2377,7 +2377,7 @@ func TestWASM(t *testing.T) { wsURL := strings.Replace(s.URL, "http", "ws", 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") From e55ac18137f04b40dc74556bfb2b92b242db32b5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 14 Oct 2019 16:49:51 -0400 Subject: [PATCH 02/55] Document compression API So it begins :) --- handshake.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/handshake.go b/handshake.go index 2c01cab6..81ebf48a 100644 --- a/handshake.go +++ b/handshake.go @@ -45,6 +45,11 @@ type AcceptOptions struct { // If you do, remember that if you store secure data in cookies, you wil need to verify the // Origin header yourself otherwise you are exposing yourself to a CSRF attack. InsecureSkipVerify bool + + // Compression sets the compression options. + // By default, compression is disabled. + // See docs on the CompressionOptions type. + Compression *CompressionOptions } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { @@ -240,6 +245,49 @@ type DialOptions struct { // Subprotocols lists the subprotocols to negotiate with the server. Subprotocols []string + + // Compression sets the compression options. + // By default, compression is disabled. + // See docs on the CompressionOptions type. + Compression CompressionOptions +} + +// CompressionOptions describes the available compression options. +// +// See https://tools.ietf.org/html/rfc7692 +// +// Enabling compression may spike memory usage as each flate.Writer takes up 1.2 MB. +// See https://github.com/gorilla/websocket/issues/203 +// Benchmark before enabling in production. +// +// This API is experimental and subject to change. +type CompressionOptions struct { + // ContextTakeover controls whether context takeover is enabled. + // + // If ContextTakeover == false, then a flate.Writer will be grabbed + // from the pool as needed for every message written to the connection. + // + // If ContextTakeover == true, then a flate.Writer will be allocated for each connection. + // This allows more efficient compression as the sliding window from previous + // messages will be used instead of resetting in between every message. + // The downside is that for every connection there will be a fixed allocation + // for the flate.Writer. + // + // See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. + ContextTakeover bool + + // Level controls the compression level negotiated. + // Defaults to flate.BestSpeed. + Level int + + // Threshold controls the minimum message size in bytes before compression is used. + // In the case of ContextTakeover == false, a flate.Writer will not be grabbed + // from the pool until the message exceeds this threshold. + // + // Must not be greater than 4096 as that is the write buffer's size. + // + // Defaults to 512. + Threshold int } // Dial performs a WebSocket handshake on the given url with the given options. From e142e08cbe82354cbee73f4b023623f04a55924d Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 11 Nov 2019 19:28:39 -0500 Subject: [PATCH 03/55] Improve compression docs --- README.md | 8 ++++---- handshake.go | 44 +++++++++++++++++++++++++++++--------------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index c426423a..b5adc59c 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ # websocket -[![GitHub Release](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) -[![GoDoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) -[![Coveralls](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket) -[![CI Status](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions) +[![version](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) +[![docs](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) +[![coverage](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket) +[![ci](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions) websocket is a minimal and idiomatic WebSocket library for Go. diff --git a/handshake.go b/handshake.go index 81ebf48a..2cde6ae2 100644 --- a/handshake.go +++ b/handshake.go @@ -249,34 +249,45 @@ type DialOptions struct { // Compression sets the compression options. // By default, compression is disabled. // See docs on the CompressionOptions type. - Compression CompressionOptions + Compression *CompressionOptions } // CompressionOptions describes the available compression options. // // See https://tools.ietf.org/html/rfc7692 // -// Enabling compression may spike memory usage as each flate.Writer takes up 1.2 MB. +// The NoContextTakeover variables control whether a flate.Writer or flate.Reader is allocated +// for every connection (context takeover) versus shared from a pool (no context takeover). +// +// The advantage to context takeover is more efficient compression as the sliding window from previous +// messages will be used instead of being reset between every message. +// +// The advantage to no context takeover is that the flate structures are allocated as needed +// and shared between connections instead of giving each connection a fixed flate.Writer and +// flate.Reader. +// +// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. +// +// Enabling compression will increase memory and CPU usage. +// Thus it is not ideal for every use case and disabled by default. // See https://github.com/gorilla/websocket/issues/203 -// Benchmark before enabling in production. +// Profile before enabling in production. // // This API is experimental and subject to change. type CompressionOptions struct { - // ContextTakeover controls whether context takeover is enabled. - // - // If ContextTakeover == false, then a flate.Writer will be grabbed - // from the pool as needed for every message written to the connection. + // ServerNoContextTakeover controls whether the server should use context takeover. + // See docs on CompressionOptions for discussion regarding context takeover. // - // If ContextTakeover == true, then a flate.Writer will be allocated for each connection. - // This allows more efficient compression as the sliding window from previous - // messages will be used instead of resetting in between every message. - // The downside is that for every connection there will be a fixed allocation - // for the flate.Writer. + // If set by the client, will guarantee that the server does not use context takeover. + ServerNoContextTakeover bool + + // ClientNoContextTakeover controls whether the client should use context takeover. + // See docs on CompressionOptions for discussion regarding context takeover. // - // See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. - ContextTakeover bool + // If set by the server, will guarantee that the client does not use context takeover. + ClientNoContextTakeover bool - // Level controls the compression level negotiated. + // Level controls the compression level used. // Defaults to flate.BestSpeed. Level int @@ -355,6 +366,9 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } + if opts.Compression != nil { + req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } resp, err := opts.HTTPClient.Do(req) if err != nil { From 53c1aea0c6ec1169acb4359dd2361e938e910455 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 11 Nov 2019 21:01:08 -0500 Subject: [PATCH 04/55] Implement compression extension negotiation --- conn.go | 1 + doc.go | 1 + handshake.go | 190 +++++++++++++++++++++++++++++++++++----------- handshake_test.go | 2 +- 4 files changed, 149 insertions(+), 45 deletions(-) diff --git a/conn.go b/conn.go index 26906c79..14d93cf6 100644 --- a/conn.go +++ b/conn.go @@ -42,6 +42,7 @@ type Conn struct { writeBuf []byte closer io.Closer client bool + copts *CompressionOptions closeOnce sync.Once closeErrOnce sync.Once diff --git a/doc.go b/doc.go index b29d2cdd..804665fb 100644 --- a/doc.go +++ b/doc.go @@ -31,6 +31,7 @@ // - Accept and AcceptOptions // - Conn.Ping // - HTTPClient and HTTPHeader fields in DialOptions +// - CompressionOptions // // The *http.Response returned by Dial will always either be nil or &http.Response{} as // we do not have access to the handshake response in the browser. diff --git a/handshake.go b/handshake.go index 2cde6ae2..787fee2c 100644 --- a/handshake.go +++ b/handshake.go @@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { return err } - if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { + if !headerContainsToken(r.Header, "Connection", "Upgrade") { err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) http.Error(w, err.Error(), http.StatusBadRequest) return err } - if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { + if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) http.Error(w, err.Error(), http.StatusBadRequest) return err @@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, w.Header().Set("Sec-WebSocket-Protocol", subproto) } + var copts *CompressionOptions + if opts.Compression != nil { + copts, err = negotiateCompression(r.Header, opts.Compression) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + if copts != nil { + copts.setHeader(w.Header()) + } + } + w.WriteHeader(http.StatusSwitchingProtocols) netConn, brw, err := hj.Hijack() @@ -162,17 +174,23 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, br: brw.Reader, bw: brw.Writer, closer: netConn, + copts: copts, } c.init() return c, nil } -func headerValuesContainsToken(h http.Header, key, token string) bool { +func headerContainsToken(h http.Header, key, token string) bool { key = textproto.CanonicalMIMEHeaderKey(key) - for _, val2 := range h[key] { - if headerValueContainsToken(val2, token) { + token = strings.ToLower(token) + match := func(t string) bool { + return t == token + } + + for _, v := range h[key] { + if searchHeaderTokens(v, match) != "" { return true } } @@ -180,22 +198,41 @@ func headerValuesContainsToken(h http.Header, key, token string) bool { return false } -func headerValueContainsToken(val2, token string) bool { - val2 = strings.TrimSpace(val2) +func headerTokenHasPrefix(h http.Header, key, prefix string) string { + key = textproto.CanonicalMIMEHeaderKey(key) - for _, val2 := range strings.Split(val2, ",") { - val2 = strings.TrimSpace(val2) - if strings.EqualFold(val2, token) { - return true + prefix = strings.ToLower(prefix) + match := func(t string) bool { + return strings.HasPrefix(t, prefix) + } + + for _, v := range h[key] { + found := searchHeaderTokens(v, match) + if found != "" { + return found } } - return false + return "" +} + +func searchHeaderTokens(v string, match func(val string) bool) string { + v = strings.TrimSpace(v) + + for _, v2 := range strings.Split(v, ",") { + v2 = strings.TrimSpace(v2) + v2 = strings.ToLower(v2) + if match(v2) { + return v2 + } + } + + return "" } func selectSubprotocol(r *http.Request, subprotocols []string) string { for _, sp := range subprotocols { - if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { + if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { return sp } } @@ -268,36 +305,32 @@ type DialOptions struct { // // See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. // -// Enabling compression will increase memory and CPU usage. -// Thus it is not ideal for every use case and disabled by default. +// Enabling compression will increase memory and CPU usage and should +// be profiled before enabling in production. // See https://github.com/gorilla/websocket/issues/203 -// Profile before enabling in production. // // This API is experimental and subject to change. type CompressionOptions struct { - // ServerNoContextTakeover controls whether the server should use context takeover. - // See docs on CompressionOptions for discussion regarding context takeover. - // - // If set by the client, will guarantee that the server does not use context takeover. - ServerNoContextTakeover bool - // ClientNoContextTakeover controls whether the client should use context takeover. // See docs on CompressionOptions for discussion regarding context takeover. // // If set by the server, will guarantee that the client does not use context takeover. ClientNoContextTakeover bool + // ServerNoContextTakeover controls whether the server should use context takeover. + // See docs on CompressionOptions for discussion regarding context takeover. + // + // If set by the client, will guarantee that the server does not use context takeover. + ServerNoContextTakeover bool + // Level controls the compression level used. // Defaults to flate.BestSpeed. Level int // Threshold controls the minimum message size in bytes before compression is used. - // In the case of ContextTakeover == false, a flate.Writer will not be grabbed - // from the pool until the message exceeds this threshold. - // // Must not be greater than 4096 as that is the write buffer's size. // - // Defaults to 512. + // Defaults to 256. Threshold int } @@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon return c, r, nil } -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { +func (opts *DialOptions) ensure() (*DialOptions, error) { if opts == nil { opts = &DialOptions{} + } else { + opts = &*opts } - // Shallow copy to ensure defaults do not affect user passed options. - opts2 := *opts - opts = &opts2 - if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } if opts.HTTPClient.Timeout > 0 { - return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } + return opts, nil +} + +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + opts, err = opts.ensure() + if err != nil { + return nil, nil, err + } + parsedURL, err := url.Parse(u) if err != nil { return nil, nil, fmt.Errorf("failed to parse url: %w", err) @@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if opts.Compression != nil { - req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + opts.Compression.setHeader(req.Header) } resp, err := opts.HTTPClient.Do(req) @@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re } }() - err = verifyServerResponse(req, resp) + copts, err := verifyServerResponse(req, resp, opts) if err != nil { return nil, resp, err } @@ -400,6 +440,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re bw: getBufioWriter(rwc), closer: rwc, client: true, + copts: copts, } c.extractBufioWriterBuf(rwc) c.init() @@ -407,31 +448,40 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re return c, resp, nil } -func verifyServerResponse(r *http.Request, resp *http.Response) error { +func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*CompressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { - return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } - if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { - return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } - if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { - return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), r.Header.Get("Sec-WebSocket-Key"), ) } - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { + return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } - return nil + var copts *CompressionOptions + if opts.Compression != nil { + var err error + copts, err = negotiateCompression(resp.Header, opts.Compression) + if err != nil { + return nil, err + } + } + + return copts, nil } // The below pools can only be used by the client because http.Hijacker will always @@ -477,3 +527,55 @@ func makeSecWebSocketKey() (string, error) { } return base64.StdEncoding.EncodeToString(b), nil } + +func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) { + deflate := headerTokenHasPrefix(h, "Sec-WebSocket-Extensions", "permessage-deflate") + if deflate == "" { + return nil, nil + } + + // Ensures our changes do not modify the real compression options. + copts = &*copts + + params := strings.Split(deflate, ";") + for i := range params { + params[i] = strings.TrimSpace(params[i]) + } + + if params[0] != "permessage-deflate" { + return nil, fmt.Errorf("unexpected header format for permessage-deflate extension: %q", deflate) + } + + for _, p := range params[1:] { + switch p { + case "client_no_context_takeover": + copts.ClientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.ServerNoContextTakeover = true + continue + case "client_max_window_bits", "server-max-window-bits": + server := h.Get("Sec-WebSocket-Key") != "" + if server { + // If we are the server, we are allowed to ignore these parameters. + // However, if we are the client, we must obey them but because of + // https://github.com/golang/go/issues/3155 we cannot. + continue + } + } + return nil, fmt.Errorf("unsupported permessage-deflate parameter %q in header: %q", p, deflate) + } + + return copts, nil +} + +func (copts *CompressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.ClientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.ServerNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} diff --git a/handshake_test.go b/handshake_test.go index cb09353f..82f958e0 100644 --- a/handshake_test.go +++ b/handshake_test.go @@ -377,7 +377,7 @@ func Test_verifyServerHandshake(t *testing.T) { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - err = verifyServerResponse(r, resp) + _, err = verifyServerResponse(r, resp, &DialOptions{}) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) } From 2cf6c28875c3511edfee7409b5f25a994d2edbf3 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 11 Nov 2019 21:29:08 -0500 Subject: [PATCH 05/55] Implement compression writer and reader pooling --- conn.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/conn.go b/conn.go index 14d93cf6..32dfa81b 100644 --- a/conn.go +++ b/conn.go @@ -4,6 +4,7 @@ package websocket import ( "bufio" + "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -1063,3 +1064,55 @@ func (c *Conn) extractBufioWriterBuf(w io.Writer) { c.bw.Reset(w) } + +var flateWriterPoolsMu sync.Mutex +var flateWriterPools = make(map[int]*sync.Pool) + +func getFlateWriterPool(level int) *sync.Pool { + flateWriterPoolsMu.Lock() + defer flateWriterPoolsMu.Unlock() + + p, ok := flateWriterPools[level] + if !ok { + p = &sync.Pool{ + New: func() interface{} { + w, err := flate.NewWriter(nil, level) + if err != nil { + panic("websocket: unexpected error from flate.NewWriter: " + err.Error()) + } + return w + }, + } + flateWriterPools[level] = p + } + + return p +} + +func getFlateWriter(w io.Writer, level int) *flate.Writer { + p := getFlateWriterPool(level) + fw := p.Get().(*flate.Writer) + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer, level int) { + p := getFlateWriterPool(level) + p.Put(w) +} + +var flateReaderPool = &sync.Pool{ + New: func() interface{} { + return flate.NewReader(nil) + }, +} + +func getFlateReader(r flate.Reader) io.ReadCloser { + fr := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(r, nil) + return fr +} + +func putFlateReader(fr io.ReadCloser) { + flateReaderPool.Put(fr) +} From a01afeace4a00b64f92eb94a6d5c40d22b6386e3 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 11 Nov 2019 22:55:47 -0500 Subject: [PATCH 06/55] Support x-webkit-deflate-frame extension for Safari --- handshake.go | 130 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 93 insertions(+), 37 deletions(-) diff --git a/handshake.go b/handshake.go index 787fee2c..03331039 100644 --- a/handshake.go +++ b/handshake.go @@ -152,7 +152,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, return nil, err } if copts != nil { - copts.setHeader(w.Header()) + copts.setHeader(w.Header(), false) } } @@ -190,7 +190,7 @@ func headerContainsToken(h http.Header, key, token string) bool { } for _, v := range h[key] { - if searchHeaderTokens(v, match) != "" { + if searchHeaderTokens(v, match) { return true } } @@ -198,36 +198,54 @@ func headerContainsToken(h http.Header, key, token string) bool { return false } -func headerTokenHasPrefix(h http.Header, key, prefix string) string { - key = textproto.CanonicalMIMEHeaderKey(key) - - prefix = strings.ToLower(prefix) +// readCompressionExtensionHeader extracts compression extension info from h. +// The standard says we should support multiple compression extension configurations +// from the client but we don't need to as there is only a single deflate extension +// and we support every configuration without error so we only need to check the first +// and thus preferred configuration. +func readCompressionExtensionHeader(h http.Header) (xWebkitDeflateFrame bool, params []string, ok bool) { match := func(t string) bool { - return strings.HasPrefix(t, prefix) + vals := strings.Split(t, ";") + for i := range vals { + vals[i] = strings.TrimSpace(vals[i]) + } + params = vals[1:] + + if vals[0] == "permessage-deflate" { + return true + } + + // See https://bugs.webkit.org/show_bug.cgi?id=115504 + if vals[0] == "x-webkit-deflate-frame" { + xWebkitDeflateFrame = true + return true + } + + return false } + key := textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Extensions") for _, v := range h[key] { - found := searchHeaderTokens(v, match) - if found != "" { - return found + if searchHeaderTokens(v, match) { + return xWebkitDeflateFrame, params, true } } - return "" + return false, nil, false } -func searchHeaderTokens(v string, match func(val string) bool) string { +func searchHeaderTokens(v string, match func(val string) bool) bool { + v = strings.ToLower(v) v = strings.TrimSpace(v) for _, v2 := range strings.Split(v, ",") { v2 = strings.TrimSpace(v2) - v2 = strings.ToLower(v2) if match(v2) { - return v2 + return true } } - return "" + return false } func selectSubprotocol(r *http.Request, subprotocols []string) string { @@ -332,6 +350,10 @@ type CompressionOptions struct { // // Defaults to 256. Threshold int + + // This is used for supporting Safari as it still uses x-webkit-deflate-frame. + // See negotiateCompression. + xWebkitDeflateFrame bool } // Dial performs a WebSocket handshake on the given url with the given options. @@ -407,7 +429,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } if opts.Compression != nil { - opts.Compression.setHeader(req.Header) + opts.Compression.setHeader(req.Header, true) } resp, err := opts.HTTPClient.Do(req) @@ -529,24 +551,30 @@ func makeSecWebSocketKey() (string, error) { } func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) { - deflate := headerTokenHasPrefix(h, "Sec-WebSocket-Extensions", "permessage-deflate") - if deflate == "" { + xWebkitDeflateFrame, params, ok := readCompressionExtensionHeader(h) + if !ok { return nil, nil } // Ensures our changes do not modify the real compression options. copts = &*copts - - params := strings.Split(deflate, ";") - for i := range params { - params[i] = strings.TrimSpace(params[i]) - } - - if params[0] != "permessage-deflate" { - return nil, fmt.Errorf("unexpected header format for permessage-deflate extension: %q", deflate) + copts.xWebkitDeflateFrame = xWebkitDeflateFrame + + // We are the client if the header contains the accept header, meaning its from the server. + client := h.Get("Sec-WebSocket-Accept") == "" + + if copts.xWebkitDeflateFrame { + // The other endpoint dictates whether or not we can + // use context takeover on our side. We cannot force it. + // Likewise, we tell the other side so we can force that. + if client { + copts.ClientNoContextTakeover = false + } else { + copts.ServerNoContextTakeover = false + } } - for _, p := range params[1:] { + for _, p := range params { switch p { case "client_no_context_takeover": copts.ClientNoContextTakeover = true @@ -555,27 +583,55 @@ func negotiateCompression(h http.Header, copts *CompressionOptions) (*Compressio copts.ServerNoContextTakeover = true continue case "client_max_window_bits", "server-max-window-bits": - server := h.Get("Sec-WebSocket-Key") != "" - if server { + if !client { // If we are the server, we are allowed to ignore these parameters. // However, if we are the client, we must obey them but because of // https://github.com/golang/go/issues/3155 we cannot. continue } + case "no_context_takeover": + if copts.xWebkitDeflateFrame { + if client { + copts.ClientNoContextTakeover = true + } else { + copts.ServerNoContextTakeover = true + } + continue + } + + // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead + // of ignoring it as the draft spec is unclear. It says the server can ignore it + // but the server has no way of signalling to the client it was ignored as parameters + // are set one way. + // Thus us ignoring it would make the client think we understood it which would cause issues. + // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 + // + // Either way, we're only implementing this for webkit which never sends the max_window_bits + // parameter so we don't need to worry about it. } - return nil, fmt.Errorf("unsupported permessage-deflate parameter %q in header: %q", p, deflate) + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil } -func (copts *CompressionOptions) setHeader(h http.Header) { - s := "permessage-deflate" - if copts.ClientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.ServerNoContextTakeover { - s += "; server_no_context_takeover" +func (copts *CompressionOptions) setHeader(h http.Header, client bool) { + var s string + if !copts.xWebkitDeflateFrame { + s := "permessage-deflate" + if copts.ClientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.ServerNoContextTakeover { + s += "; server_no_context_takeover" + } + } else { + s = "x-webkit-deflate-frame" + // We can only set no context takeover for the peer. + if client && copts.ServerNoContextTakeover || !client && copts.ClientNoContextTakeover { + s += "; no_context_takeover" + } } h.Set("Sec-WebSocket-Extensions", s) } From 531d4fab2b30955df6ca43aea0417eb7aa60d515 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 12 Nov 2019 11:15:17 -0500 Subject: [PATCH 07/55] Improve general compression API and write docs --- README.md | 46 +- accept.go | 330 +++++++++ handshake_test.go => accept_test.go | 143 ---- assert_test.go | 56 +- ci/fmt.mk | 2 +- close.go | 181 +++++ close_test.go | 196 ++++++ compress.go | 78 +++ conn.go | 297 +++++--- dial.go | 219 ++++++ dial_test.go | 149 ++++ doc.go | 2 +- frame.go | 445 ------------ frame_test.go | 457 ------------- handshake.go | 637 ------------------ internal/assert/assert.go | 18 +- internal/atomicint/atomicint.go | 32 + internal/{bpool/bpool.go => bufpool/buf.go} | 2 +- .../bpool_test.go => bufpool/buf_test.go} | 2 +- internal/bufpool/bufio.go | 40 ++ internal/wsframe/frame.go | 194 ++++++ .../wsframe/frame_stringer.go | 20 +- internal/wsframe/frame_test.go | 157 +++++ internal/wsframe/mask.go | 128 ++++ internal/wsframe/mask_test.go | 118 ++++ js_test.go | 50 ++ conn_common.go => netconn.go | 78 --- reader.go | 31 + websocket_js_test.go | 52 -- writer.go | 5 + websocket_js.go => ws_js.go | 58 +- ws_js_test.go | 22 + wsjson/wsjson.go | 7 +- wspb/wspb.go | 10 +- 34 files changed, 2243 insertions(+), 2019 deletions(-) create mode 100644 accept.go rename handshake_test.go => accept_test.go (62%) create mode 100644 close.go create mode 100644 close_test.go create mode 100644 compress.go create mode 100644 dial.go create mode 100644 dial_test.go delete mode 100644 frame.go delete mode 100644 frame_test.go delete mode 100644 handshake.go create mode 100644 internal/atomicint/atomicint.go rename internal/{bpool/bpool.go => bufpool/buf.go} (95%) rename internal/{bpool/bpool_test.go => bufpool/buf_test.go} (97%) create mode 100644 internal/bufpool/bufio.go create mode 100644 internal/wsframe/frame.go rename frame_stringer.go => internal/wsframe/frame_stringer.go (90%) create mode 100644 internal/wsframe/frame_test.go create mode 100644 internal/wsframe/mask.go create mode 100644 internal/wsframe/mask_test.go create mode 100644 js_test.go rename conn_common.go => netconn.go (60%) create mode 100644 reader.go delete mode 100644 websocket_js_test.go create mode 100644 writer.go rename websocket_js.go => ws_js.go (88%) create mode 100644 ws_js_test.go diff --git a/README.md b/README.md index b5adc59c..17c7c838 100644 --- a/README.md +++ b/README.md @@ -22,13 +22,14 @@ go get nhooyr.io/websocket - [Zero dependencies](https://godoc.org/nhooyr.io/websocket?imports) - JSON and ProtoBuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - Highly optimized by default + - Zero alloc reads and writes - Concurrent writes out of the box - [Complete Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) support - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) +- Full support of [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression extension ## Roadmap -- [ ] Compression Extensions [#163](https://github.com/nhooyr/websocket/pull/163) - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) ## Examples @@ -84,22 +85,12 @@ if err != nil { c.Close(websocket.StatusNormalClosure, "") ``` -## Design justifications - -- A minimal API is easier to maintain due to less docs, tests and bugs -- A minimal API is also easier to use and learn -- Context based cancellation is more ergonomic and robust than setting deadlines -- net.Conn is never exposed as WebSocket over HTTP/2 will not have a net.Conn. -- Using net/http's Client for dialing means we do not have to reinvent dialing hooks - and configurations like other WebSocket libraries - ## Comparison -Before the comparison, I want to point out that both gorilla/websocket and gobwas/ws were -extremely useful in implementing the WebSocket protocol correctly so _big thanks_ to the -authors of both. In particular, I made sure to go through the issue tracker of gorilla/websocket -to ensure I implemented details correctly and understood how people were using WebSockets in -production. +Before the comparison, I want to point out that gorilla/websocket was extremely useful in implementing the +WebSocket protocol correctly so _big thanks_ to its authors. In particular, I made sure to go through the +issue tracker of gorilla/websocket to ensure I implemented details correctly and understood how people were +using WebSockets in production. ### gorilla/websocket @@ -121,7 +112,7 @@ more code to test, more code to document and more surface area for bugs. Moreover, nhooyr.io/websocket supports newer Go idioms such as context.Context. It also uses net/http's Client and ResponseWriter directly for WebSocket handshakes. gorilla/websocket writes its handshakes to the underlying net.Conn. -Thus it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. +Thus it has to reinvent hooks for TLS and proxies and prevents easy support of HTTP/2. Some more advantages of nhooyr.io/websocket are that it supports concurrent writes and makes it very easy to close the connection with a status code and reason. In fact, @@ -138,10 +129,14 @@ In terms of performance, the differences mostly depend on your application code. reuses message buffers out of the box if you use the wsjson and wspb subpackages. As mentioned above, nhooyr.io/websocket also supports concurrent writers. -The WebSocket masking algorithm used by this package is also [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) -faster than gorilla/websocket or gobwas/ws while using only pure safe Go. +The WebSocket masking algorithm used by this package is [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) +faster than gorilla/websocket while using only pure safe Go. -The only performance con to nhooyr.io/websocket is that it uses one extra goroutine to support +The [permessage-deflate compression extension](https://tools.ietf.org/html/rfc7692) is fully supported by this library +whereas gorilla only supports no context takeover mode. See our godoc for the differences. This will make a big +difference on bandwidth used in most use cases. + +The only performance con to nhooyr.io/websocket is that it uses a goroutine to support cancellation with context.Context. This costs 2 KB of memory which is cheap compared to the benefits. @@ -160,14 +155,15 @@ https://github.com/gobwas/ws This library has an extremely flexible API but that comes at the cost of usability and clarity. -This library is fantastic in terms of performance. The author put in significant -effort to ensure its speed and I have applied as many of its optimizations as -I could into nhooyr.io/websocket. Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb) -about performant WebSocket servers. +Due to its flexibility, it can be used in a event driven style for performance. +Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb) about performant WebSocket servers. If you want a library that gives you absolute control over everything, this is the library. -But for 99.9% of use cases, nhooyr.io/websocket will fit better. It's nearly as performant -but much easier to use. +But for 99.9% of use cases, nhooyr.io/websocket will fit better as it is both easier and +faster for normal idiomatic Go. The masking implementation is [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) +faster, the compression extensions are fully supported and as much as possible is reused by default. + +See the gorilla/websocket comparison for more performance details. ## Contributing diff --git a/accept.go b/accept.go new file mode 100644 index 00000000..5ff2ea41 --- /dev/null +++ b/accept.go @@ -0,0 +1,330 @@ +package websocket + +import ( + "bytes" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/textproto" + "net/url" + "strings" +) + +// AcceptOptions represents the options available to pass to Accept. +type AcceptOptions struct { + // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. + // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to + // reject it, close the connection if c.Subprotocol() == "". + Subprotocols []string + + // InsecureSkipVerify disables Accept's origin verification + // behaviour. By default Accept only allows the handshake to + // succeed if the javascript that is initiating the handshake + // is on the same domain as the server. This is to prevent CSRF + // attacks when secure data is stored in a cookie as there is no same + // origin policy for WebSockets. In other words, javascript from + // any domain can perform a WebSocket dial on an arbitrary server. + // This dial will include cookies which means the arbitrary javascript + // can perform actions as the authenticated user. + // + // See https://stackoverflow.com/a/37837709/4283659 + // + // The only time you need this is if your javascript is running on a different domain + // than your WebSocket server. + // Think carefully about whether you really need this option before you use it. + // If you do, remember that if you store secure data in cookies, you wil need to verify the + // Origin header yourself otherwise you are exposing yourself to a CSRF attack. + InsecureSkipVerify bool + + // CompressionMode sets the compression mode. + // See docs on the CompressionMode type and defined constants. + CompressionMode CompressionMode +} + +// Accept accepts a WebSocket HTTP handshake from a client and upgrades the +// the connection to a WebSocket. +// +// Accept will reject the handshake if the Origin domain is not the same as the Host unless +// the InsecureSkipVerify option is set. In other words, by default it does not allow +// cross origin requests. +// +// If an error occurs, Accept will write a response with a safe error message to w. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + c, err := accept(w, r, opts) + if err != nil { + return nil, fmt.Errorf("failed to accept websocket connection: %w", err) + } + return c, nil +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + if opts == nil { + opts = &AcceptOptions{} + } + + err := verifyClientRequest(w, r) + if err != nil { + return nil, err + } + + if !opts.InsecureSkipVerify { + err = authenticateOrigin(r) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } + } + + hj, ok := w.(http.Hijacker) + if !ok { + err = errors.New("passed ResponseWriter does not implement http.Hijacker") + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return nil, err + } + + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + + handleSecWebSocketKey(w, r) + + subproto := selectSubprotocol(r, opts.Subprotocols) + if subproto != "" { + w.Header().Set("Sec-WebSocket-Protocol", subproto) + } + + copts, err := acceptCompression(r, w, opts.CompressionMode) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + w.WriteHeader(http.StatusSwitchingProtocols) + + netConn, brw, err := hj.Hijack() + if err != nil { + err = fmt.Errorf("failed to hijack connection: %w", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return nil, err + } + + // https://github.com/golang/go/issues/32314 + b, _ := brw.Reader.Peek(brw.Reader.Buffered()) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) + + c := &Conn{ + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + br: brw.Reader, + bw: brw.Writer, + closer: netConn, + copts: copts, + } + c.init() + + return c, nil +} + +func authenticateOrigin(r *http.Request) error { + origin := r.Header.Get("Origin") + if origin == "" { + return nil + } + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + if !strings.EqualFold(u.Host, r.Host) { + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + } + return nil +} + +func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { + if !r.ProtoAtLeast(1, 1) { + err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if !headerContainsToken(r.Header, "Connection", "Upgrade") { + err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { + err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Method != "GET" { + err := fmt.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Header.Get("Sec-WebSocket-Version") != "13" { + err := fmt.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Header.Get("Sec-WebSocket-Key") == "" { + err := errors.New("websocket protocol violation: missing Sec-WebSocket-Key") + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + return nil +} + +func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { + key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) +} + +func selectSubprotocol(r *http.Request, subprotocols []string) string { + for _, sp := range subprotocols { + if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { + return sp + } + } + return "" +} + +func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { + if mode == CompressionDisabled { + return nil, nil + } + + for _, ext := range websocketExtensions(r.Header) { + switch ext.name { + case "permessage-deflate": + return acceptDeflate(w, ext, mode) + case "x-webkit-deflate-frame": + return acceptWebkitDeflate(w, ext, mode) + } + } + return nil, nil +} + +func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + case "client_max_window_bits", "server-max-window-bits": + continue + } + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + } + + copts.setHeader(w.Header()) + + return copts, nil +} + +func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + // The peer must explicitly request it. + copts.serverNoContextTakeover = false + + for _, p := range ext.params { + if p == "no_context_takeover" { + copts.serverNoContextTakeover = true + continue + } + + // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead + // of ignoring it as the draft spec is unclear. It says the server can ignore it + // but the server has no way of signalling to the client it was ignored as the parameters + // are set one way. + // Thus us ignoring it would make the client think we understood it which would cause issues. + // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 + // + // Either way, we're only implementing this for webkit which never sends the max_window_bits + // parameter so we don't need to worry about it. + return nil, fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + } + + s := "x-webkit-deflate-frame" + if copts.clientNoContextTakeover { + s += "; no_context_takeover" + } + w.Header().Set("Sec-WebSocket-Extensions", s) + + return copts, nil +} + + +func headerContainsToken(h http.Header, key, token string) bool { + token = strings.ToLower(token) + + for _, t := range headerTokens(h, key) { + if t == token { + return true + } + } + return false +} + +type websocketExtension struct { + name string + params []string +} + +func websocketExtensions(h http.Header) []websocketExtension { + var exts []websocketExtension + extStrs := headerTokens(h, "Sec-WebSocket-Extensions") + for _, extStr := range extStrs { + if extStr == "" { + continue + } + + vals := strings.Split(extStr, ";") + for i := range vals { + vals[i] = strings.TrimSpace(vals[i]) + } + + e := websocketExtension{ + name: vals[0], + params: vals[1:], + } + + exts = append(exts, e) + } + return exts +} + +func headerTokens(h http.Header, key string) []string { + key = textproto.CanonicalMIMEHeaderKey(key) + var tokens []string + for _, v := range h[key] { + v = strings.TrimSpace(v) + for _, t := range strings.Split(v, ",") { + t = strings.ToLower(t) + tokens = append(tokens, t) + } + } + return tokens +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func secWebSocketAccept(secWebSocketKey string) string { + h := sha1.New() + h.Write([]byte(secWebSocketKey)) + h.Write(keyGUID) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/handshake_test.go b/accept_test.go similarity index 62% rename from handshake_test.go rename to accept_test.go index 82f958e0..9598cd58 100644 --- a/handshake_test.go +++ b/accept_test.go @@ -1,14 +1,9 @@ -// +build !js - package websocket import ( - "context" - "net/http" "net/http/httptest" "strings" "testing" - "time" ) func TestAccept(t *testing.T) { @@ -246,141 +241,3 @@ func Test_authenticateOrigin(t *testing.T) { }) } } - -func TestBadDials(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - url string - opts *DialOptions - }{ - { - name: "badURL", - url: "://noscheme", - }, - { - name: "badURLScheme", - url: "ftp://nhooyr.io", - }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, - { - name: "badTLS", - url: "wss://totallyfake.nhooyr.io", - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - _, _, err := Dial(ctx, tc.url, tc.opts) - if err == nil { - t.Fatalf("expected non nil error: %+v", err) - } - }) - } -} - -func Test_verifyServerHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - response func(w http.ResponseWriter) - success bool - }{ - { - name: "badStatus", - response: func(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) - }, - success: false, - }, - { - name: "badConnection", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badUpgrade", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badSecWebSocketAccept", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Accept", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badSecWebSocketProtocol", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Protocol", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "success", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: true, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - w := httptest.NewRecorder() - tc.response(w) - resp := w.Result() - - r := httptest.NewRequest("GET", "/", nil) - key, err := makeSecWebSocketKey() - if err != nil { - t.Fatal(err) - } - r.Header.Set("Sec-WebSocket-Key", key) - - if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) - } - - _, err = verifyServerResponse(r, resp, &DialOptions{}) - if (err == nil) != tc.success { - t.Fatalf("unexpected error: %+v", err) - } - }) - } -} diff --git a/assert_test.go b/assert_test.go index 26fd1d48..af300999 100644 --- a/assert_test.go +++ b/assert_test.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "strings" + "testing" "time" "nhooyr.io/websocket" @@ -15,36 +16,30 @@ func init() { rand.Seed(time.Now().UnixNano()) } -func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error { +func randBytes(n int) []byte { + b := make([]byte, n) + rand.Read(b) + return b +} + +func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { exp := randString(n) err := wsjson.Write(ctx, c, exp) - if err != nil { - return err - } + assert.Success(t, err) var act interface{} err = wsjson.Read(ctx, c, &act) - if err != nil { - return err - } + assert.Success(t, err) - return assert.Equalf(exp, act, "unexpected JSON") + assert.Equalf(t, exp, act, "unexpected JSON") } -func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { +func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { var act interface{} err := wsjson.Read(ctx, c, &act) - if err != nil { - return err - } - - return assert.Equalf(exp, act, "unexpected JSON") -} + assert.Success(t, err) -func randBytes(n int) []byte { - b := make([]byte, n) - rand.Read(b) - return b + assert.Equalf(t, exp, act, "unexpected JSON") } func randString(n int) string { @@ -60,23 +55,18 @@ func randString(n int) string { return s } -func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) error { +func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { p := randBytes(n) err := c.Write(ctx, typ, p) - if err != nil { - return err - } + assert.Success(t, err) + typ2, p2, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(typ, typ2, "unexpected data type") - if err != nil { - return err - } - return assert.Equalf(p, p2, "unexpected payload") + assert.Success(t, err) + + assert.Equalf(t, typ, typ2, "unexpected data type") + assert.Equalf(t, p, p2, "unexpected payload") } -func assertSubprotocol(c *websocket.Conn, exp string) error { - return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol") +func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { + assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol") } diff --git a/ci/fmt.mk b/ci/fmt.mk index 8e61bc24..3637c1ac 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -22,4 +22,4 @@ prettier: prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md") gen: - go generate ./... + stringer -type=Opcode,MessageType,StatusCode -output=websocket_stringer.go diff --git a/close.go b/close.go new file mode 100644 index 00000000..4f48f1b3 --- /dev/null +++ b/close.go @@ -0,0 +1,181 @@ +package websocket + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "nhooyr.io/websocket/internal/wsframe" +) + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// These codes were retrieved from: +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// The defined constants only represent the status codes registered with IANA. +// The 4000-4999 range of status codes is reserved for arbitrary use by applications. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so not exported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // an explicit status. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether the connection was closed or not or what happened. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError represents a WebSocket close frame. +// It is returned by Conn's methods when a WebSocket close frame is received from +// the peer. +// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, +// to check for this error. See the CloseError example. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around errors.As to grab +// the status code from a *CloseError. If the passed error is nil +// or not a *CloseError, the returned StatusCode will be -1. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + code, reason, err := wsframe.ParseClosePayload(p) + if err != nil { + return CloseError{}, err + } + + ce := CloseError{ + Code: StatusCode(code), + Reason: reason, + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + // TODO move check into frame write + if len(ce.Reason) > wsframe.MaxControlFramePayload-2 { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", wsframe.MaxControlFramePayload-2, ce.Reason, len(ce.Reason)) + } + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +// CloseRead will start a goroutine to read from the connection until it is closed or a data message +// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. +// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. +// After calling this method, you cannot read any data messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// Use this when you do not want to read data messages from the connection anymore but will +// want to write messages to it. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + // We use the unexported reader method so that we don't get the read closed error. + c.reader(ctx, true) + // Either the connection is already closed since there was a read error + // or the context was cancelled or a message was read and we should close + // the connection. + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("websocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/close_test.go b/close_test.go new file mode 100644 index 00000000..78096d7e --- /dev/null +++ b/close_test.go @@ -0,0 +1,196 @@ +package websocket + +import ( + "github.com/google/go-cmp/cmp" + "io" + "math" + "nhooyr.io/websocket/internal/assert" + "nhooyr.io/websocket/internal/wsframe" + "strings" + "testing" +) + +func TestCloseError(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + ce CloseError + success bool + }{ + { + name: "normal", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + }, + success: true, + }, + { + name: "bigReason", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-1), + }, + success: false, + }, + { + name: "bigCode", + ce: CloseError{ + Code: math.MaxUint16, + Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + }, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := tc.ce.bytes() + if (err == nil) != tc.success { + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if (err == nil) != tc.success { + t.Fatalf("unexpected expected error value: %+v", err) + } + + if tc.success && tc.ce != ce { + t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if valid := validWireCloseCode(tc.code); tc.valid != valid { + t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) + } + }) + } +} + +func TestCloseStatus(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in error + exp StatusCode + }{ + { + name: "nil", + in: nil, + exp: -1, + }, + { + name: "io.EOF", + in: io.EOF, + exp: -1, + }, + { + name: "StatusInternalError", + in: CloseError{ + Code: StatusInternalError, + }, + exp: StatusInternalError, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + assert.Equalf(t, tc.exp, CloseStatus(tc.in), "unexpected close status") + }) + } +} diff --git a/compress.go b/compress.go new file mode 100644 index 00000000..5b5fdce5 --- /dev/null +++ b/compress.go @@ -0,0 +1,78 @@ +// +build !js + +package websocket + +import ( + "net/http" +) + +// CompressionMode controls the modes available RFC 7692's deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// +// A compatibility layer is implemented for the older deflate-frame extension used +// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 +// It will work the same in every way except that we cannot signal to the peer we +// want to use no context takeover on our side, we can only signal that they should. +type CompressionMode int + +const ( + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this is the default. + // + // The message will only be compressed if greater than or equal to 128 bytes. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover CompressionMode = iota + + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be much lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than or equal to 512 bytes. + CompressionNoContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) + +func (m CompressionMode) opts() *compressionOptions { + if m == CompressionDisabled { + return nil + } + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" diff --git a/conn.go b/conn.go index 32dfa81b..791d9b4c 100644 --- a/conn.go +++ b/conn.go @@ -13,13 +13,28 @@ import ( "io" "io/ioutil" "log" + "nhooyr.io/websocket/internal/atomicint" + "nhooyr.io/websocket/internal/wsframe" "runtime" "strconv" + "strings" "sync" "sync/atomic" "time" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like Protobufs. + MessageBinary ) // Conn represents a WebSocket connection. @@ -36,20 +51,20 @@ import ( // This applies to the Read methods in the wsjson/wspb subpackages as well. type Conn struct { subprotocol string - br *bufio.Reader + fw *flate.Writer bw *bufio.Writer // writeBuf is used for masking, its the buffer in bufio.Writer. // Only used by the client for masking the bytes in the buffer. writeBuf []byte closer io.Closer client bool - copts *CompressionOptions + copts *compressionOptions closeOnce sync.Once closeErrOnce sync.Once closeErr error closed chan struct{} - closing *atomicInt64 + closing *atomicint.Int64 closeReceived error // messageWriter state. @@ -61,35 +76,18 @@ type Conn struct { writeHeaderBuf []byte writeHeader *header // read limit for a message in bytes. - msgReadLimit *atomicInt64 + msgReadLimit *atomicint.Int64 // Used to ensure a previous writer is not used after being closed. activeWriter atomic.Value // messageWriter state. writeMsgOpcode opcode writeMsgCtx context.Context - readMsgLeft int64 - - // Used to ensure the previous reader is read till EOF before allowing - // a new one. - activeReader *messageReader - // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} - isReadClosed *atomicInt64 - readHeaderBuf []byte - controlPayloadBuf []byte - readLock chan struct{} - - // messageReader state. - readerMsgCtx context.Context - readerMsgHeader header - readerFrameEOF bool - readerMaskKey uint32 setReadTimeout chan context.Context setWriteTimeout chan context.Context - pingCounter *atomicInt64 + pingCounter *atomicint.Int64 activePingsMu sync.Mutex activePings map[string]chan<- struct{} @@ -98,9 +96,9 @@ type Conn struct { func (c *Conn) init() { c.closed = make(chan struct{}) - c.closing = &atomicInt64{} + c.closing = &atomicint.Int64{} - c.msgReadLimit = &atomicInt64{} + c.msgReadLimit = &atomicint.Int64{} c.msgReadLimit.Store(32768) c.writeMsgLock = make(chan struct{}, 1) @@ -108,17 +106,18 @@ func (c *Conn) init() { c.readFrameLock = make(chan struct{}, 1) c.readLock = make(chan struct{}, 1) + c.payloadReader = framePayloadReader{c} c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) - c.pingCounter = &atomicInt64{} + c.pingCounter = &atomicint.Int64{} c.activePings = make(map[string]chan<- struct{}) c.writeHeaderBuf = makeWriteHeaderBuf() c.writeHeader = &header{} c.readHeaderBuf = makeReadHeaderBuf() - c.isReadClosed = &atomicInt64{} + c.isReadClosed = &atomicint.Int64{} c.controlPayloadBuf = make([]byte, maxControlFramePayload) runtime.SetFinalizer(c, func(c *Conn) { @@ -127,6 +126,15 @@ func (c *Conn) init() { c.logf = log.Printf + if c.copts != nil { + if !c.readNoContextTakeOver() { + c.fr = getFlateReader(c.payloadReader) + } + if !c.writeNoContextTakeOver() { + c.fw = getFlateWriter(c.bw) + } + } + go c.timeoutLoop() } @@ -148,19 +156,26 @@ func (c *Conn) close(err error) { // closeErr. c.closer.Close() - // See comment on bufioReaderPool in handshake.go + // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer + // and we can safely return them. + // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent + // a deadlock. + // As of now, this is in writeFrame, readFramePayload and readHeader. + c.readFrameLock <- struct{}{} if c.client { - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readFramePayload and readHeader. - c.readFrameLock <- struct{}{} returnBufioReader(c.br) + } + if c.fr != nil { + putFlateReader(c.fr) + } - c.writeFrameLock <- struct{}{} + c.writeFrameLock <- struct{}{} + if c.client { returnBufioWriter(c.bw) } + if c.fw != nil { + putFlateWriter(c.fw) + } }) } @@ -230,7 +245,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { return header{}, err } - if h.rsv1 || h.rsv2 || h.rsv3 { + if (h.rsv1 && (c.copts == nil || h.opcode.controlOp() || h.opcode == opContinuation)) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.exportedClose(StatusProtocolError, err.Error(), false) return header{}, err @@ -448,6 +463,13 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e c.readerMsgCtx = ctx c.readerMsgHeader = h + + c.readerPayloadCompressed = h.rsv1 + + if c.readerPayloadCompressed { + c.readerCompressTail.Reset(deflateMessageTail) + } + c.readerFrameEOF = false c.readerMaskKey = h.maskKey c.readMsgLeft = c.msgReadLimit.Load() @@ -456,9 +478,67 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e c: c, } c.activeReader = r + if c.readerPayloadCompressed && c.readNoContextTakeOver() { + c.fr = getFlateReader(c.payloadReader) + } return MessageType(h.opcode), r, nil } +type framePayloadReader struct { + c *Conn +} + +func (r framePayloadReader) Read(p []byte) (int, error) { + if r.c.readerFrameEOF { + if r.c.readerPayloadCompressed && r.c.readerMsgHeader.fin { + n, _ := r.c.readerCompressTail.Read(p) + return n, nil + } + + h, err := r.c.readTillMsg(r.c.readerMsgCtx) + if err != nil { + return 0, err + } + + if h.opcode != opContinuation { + err := errors.New("received new data message without finishing the previous message") + r.c.exportedClose(StatusProtocolError, err.Error(), false) + return 0, err + } + + r.c.readerMsgHeader = h + r.c.readerFrameEOF = false + r.c.readerMaskKey = h.maskKey + } + + h := r.c.readerMsgHeader + if int64(len(p)) > h.payloadLength { + p = p[:h.payloadLength] + } + + n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) + + h.payloadLength -= int64(n) + if h.masked { + r.c.readerMaskKey = mask(r.c.readerMaskKey, p) + } + r.c.readerMsgHeader = h + + if err != nil { + return n, err + } + + if h.payloadLength == 0 { + r.c.readerFrameEOF = true + + if h.fin && !r.c.readerPayloadCompressed { + return n, io.EOF + } + } + + return n, nil +} + // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { c *Conn @@ -521,51 +601,27 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) { p = p[:r.c.readMsgLeft] } - if r.c.readerFrameEOF { - h, err := r.c.readTillMsg(r.c.readerMsgCtx) - if err != nil { - return 0, err - } - - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - r.c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, err - } - - r.c.readerMsgHeader = h - r.c.readerFrameEOF = false - r.c.readerMaskKey = h.maskKey + pr := io.Reader(r.c.payloadReader) + if r.c.readerPayloadCompressed { + pr = r.c.fr } - h := r.c.readerMsgHeader - if int64(len(p)) > h.payloadLength { - p = p[:h.payloadLength] - } - - n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) + n, err := pr.Read(p) - h.payloadLength -= int64(n) r.c.readMsgLeft -= int64(n) - if h.masked { - r.c.readerMaskKey = mask(r.c.readerMaskKey, p) - } - r.c.readerMsgHeader = h - if err != nil { - return n, err - } - - if h.payloadLength == 0 { - r.c.readerFrameEOF = true - - if h.fin { - r.c.activeReader = nil - return n, io.EOF + if r.c.readerFrameEOF && r.c.readerMsgHeader.fin { + if r.c.readerPayloadCompressed && r.c.readNoContextTakeOver() { + putFlateReader(r.c.fr) + r.c.fr = nil + } + r.c.activeReader = nil + if err == nil { + err = io.EOF } } - return n, nil + return n, err } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { @@ -971,10 +1027,10 @@ func (c *Conn) waitClose() error { return c.closeReceived } - b := bpool.Get() + b := bufpool.Get() buf := b.Bytes() buf = buf[:cap(buf)] - defer bpool.Put(b) + defer bufpool.Put(b) for { if c.activeReader == nil || c.readerFrameEOF { @@ -1065,40 +1121,21 @@ func (c *Conn) extractBufioWriterBuf(w io.Writer) { c.bw.Reset(w) } -var flateWriterPoolsMu sync.Mutex -var flateWriterPools = make(map[int]*sync.Pool) - -func getFlateWriterPool(level int) *sync.Pool { - flateWriterPoolsMu.Lock() - defer flateWriterPoolsMu.Unlock() - - p, ok := flateWriterPools[level] - if !ok { - p = &sync.Pool{ - New: func() interface{} { - w, err := flate.NewWriter(nil, level) - if err != nil { - panic("websocket: unexpected error from flate.NewWriter: " + err.Error()) - } - return w - }, - } - flateWriterPools[level] = p - } - - return p +var flateWriterPool = &sync.Pool{ + New: func() interface{} { + w, _ := flate.NewWriter(nil, flate.BestSpeed) + return w + }, } -func getFlateWriter(w io.Writer, level int) *flate.Writer { - p := getFlateWriterPool(level) - fw := p.Get().(*flate.Writer) +func getFlateWriter(w io.Writer) *flate.Writer { + fw := flateWriterPool.Get().(*flate.Writer) fw.Reset(w) return fw } -func putFlateWriter(w *flate.Writer, level int) { - p := getFlateWriterPool(level) - p.Put(w) +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) } var flateReaderPool = &sync.Pool{ @@ -1107,12 +1144,60 @@ var flateReaderPool = &sync.Pool{ }, } -func getFlateReader(r flate.Reader) io.ReadCloser { - fr := flateReaderPool.Get().(io.ReadCloser) +func getFlateReader(r io.Reader) io.Reader { + fr := flateReaderPool.Get().(io.Reader) fr.(flate.Resetter).Reset(r, nil) return fr } -func putFlateReader(fr io.ReadCloser) { +func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } + +func (c *Conn) writeNoContextTakeOver() bool { + return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover +} + +func (c *Conn) readNoContextTakeOver() bool { + return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover +} + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (w *trimLastFourBytesWriter) Write(p []byte) (int, error) { + extra := len(w.tail) + len(p) - 4 + + if extra <= 0 { + w.tail = append(w.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(w.tail) { + extra = len(w.tail) + } + if extra > 0 { + _, err := w.Write(w.tail[:extra]) + if err != nil { + return 0, err + } + w.tail = w.tail[extra:] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + w.tail = append(w.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + w.tail = append(w.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := w.w.Write(p) + return n + 4, err +} diff --git a/dial.go b/dial.go new file mode 100644 index 00000000..10088681 --- /dev/null +++ b/dial.go @@ -0,0 +1,219 @@ +package websocket + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "nhooyr.io/websocket/internal/bufpool" + "strings" +) + +// DialOptions represents the options available to pass to Dial. +type DialOptions struct { + // HTTPClient is the http client used for the handshake. + // Its Transport must return writable bodies + // for WebSocket handshakes. + // http.Transport does this correctly beginning with Go 1.12. + HTTPClient *http.Client + + // HTTPHeader specifies the HTTP headers included in the handshake request. + HTTPHeader http.Header + + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string + + // See docs on CompressionMode. + CompressionMode CompressionMode +} + +// Dial performs a WebSocket handshake on the given url with the given options. +// The response is the WebSocket handshake response from the server. +// If an error occurs, the returned response may be non nil. However, you can only +// read the first 1024 bytes of its body. +// +// You never need to close the resp.Body yourself. +// +// This function requires at least Go 1.12 to succeed as it uses a new feature +// in net/http to perform WebSocket handshakes and get a writable body +// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + c, r, err := dial(ctx, u, opts) + if err != nil { + return nil, r, fmt.Errorf("failed to websocket dial: %w", err) + } + return c, r, nil +} + +func (opts *DialOptions) fill() (*DialOptions, error) { + if opts == nil { + opts = &DialOptions{} + } else { + opts = &*opts + } + + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } + if opts.HTTPClient.Timeout > 0 { + return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + + return opts, nil +} + +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + opts, err = opts.fill() + if err != nil { + return nil, nil, err + } + + parsedURL, err := url.Parse(u) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse url: %w", err) + } + + switch parsedURL.Scheme { + case "ws": + parsedURL.Scheme = "http" + case "wss": + parsedURL.Scheme = "https" + default: + return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) + } + + req, _ := http.NewRequest("GET", parsedURL.String(), nil) + req = req.WithContext(ctx) + req.Header = opts.HTTPHeader + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + secWebSocketKey, err := secWebSocketKey() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + copts := opts.CompressionMode.opts() + copts.setHeader(req.Header) + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) + } + defer func() { + if err != nil { + // We read a bit of the body for easier debugging. + r := io.LimitReader(resp.Body, 1024) + b, _ := ioutil.ReadAll(r) + resp.Body.Close() + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + }() + + copts, err = verifyServerResponse(req, resp, opts) + if err != nil { + return nil, resp, err + } + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) + } + + c := &Conn{ + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + br: bufpool.GetReader(rwc), + bw: bufpool.GetWriter(rwc), + closer: rwc, + client: true, + copts: copts, + } + c.extractBufioWriterBuf(rwc) + c.init() + + return c, resp, nil +} + +func secWebSocketKey() (string, error) { + b := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*compressionOptions, error) { + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + } + + if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + } + + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { + return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + resp.Header.Get("Sec-WebSocket-Accept"), + r.Header.Get("Sec-WebSocket-Key"), + ) + } + + if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { + return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + } + + copts, err := verifyServerExtensions(resp.Header, opts.CompressionMode) + if err != nil { + return nil, err + } + + return copts, nil +} + +func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOptions, error) { + exts := websocketExtensions(h) + if len(exts) == 0 { + return nil, nil + } + + ext := exts[0] + if ext.name != "permessage-deflate" { + return nil, fmt.Errorf("unexpected extension from server: %q", ext) + } + + if len(exts) > 1 { + return nil, fmt.Errorf("unexpected extra extensions from server: %+v", exts[1:]) + } + + copts := mode.opts() + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + } + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + } + + return copts, nil +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 00000000..391aa1ce --- /dev/null +++ b/dial_test.go @@ -0,0 +1,149 @@ +// +build !js + +package websocket + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestBadDials(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + }, + }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, tc.url, tc.opts) + if err == nil { + t.Fatalf("expected non nil error: %+v", err) + } + }) + } +} + +func Test_verifyServerHandshake(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + response func(w http.ResponseWriter) + success bool + }{ + { + name: "badStatus", + response: func(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + }, + success: false, + }, + { + name: "badConnection", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badUpgrade", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketAccept", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Accept", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketProtocol", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Protocol", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "success", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + tc.response(w) + resp := w.Result() + + r := httptest.NewRequest("GET", "/", nil) + key, err := secWebSocketKey() + if err != nil { + t.Fatal(err) + } + r.Header.Set("Sec-WebSocket-Key", key) + + if resp.Header.Get("Sec-WebSocket-Accept") == "" { + resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + } + + _, err = verifyServerResponse(r, resp, &DialOptions{}) + if (err == nil) != tc.success { + t.Fatalf("unexpected error: %+v", err) + } + }) + } +} diff --git a/doc.go b/doc.go index 804665fb..5285a780 100644 --- a/doc.go +++ b/doc.go @@ -1,6 +1,6 @@ // +build !js -// Package websocket is a minimal and idiomatic implementation of the WebSocket protocol. +// Package websocket implements the RFC 6455 WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 // diff --git a/frame.go b/frame.go deleted file mode 100644 index e4bf931a..00000000 --- a/frame.go +++ /dev/null @@ -1,445 +0,0 @@ -package websocket - -import ( - "encoding/binary" - "errors" - "fmt" - "io" - "math" - "math/bits" -) - -//go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go - -// opcode represents a WebSocket Opcode. -type opcode int - -// opcode constants. -const ( - opContinuation opcode = iota - opText - opBinary - // 3 - 7 are reserved for further non-control frames. - _ - _ - _ - _ - _ - opClose - opPing - opPong - // 11-16 are reserved for further control frames. -) - -func (o opcode) controlOp() bool { - switch o { - case opClose, opPing, opPong: - return true - } - return false -} - -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. - MessageBinary -) - -// First byte contains fin, rsv1, rsv2, rsv3. -// Second byte contains mask flag and payload len. -// Next 8 bytes are the maximum extended payload length. -// Last 4 bytes are the mask key. -// https://tools.ietf.org/html/rfc6455#section-5.2 -const maxHeaderSize = 1 + 1 + 8 + 4 - -// header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type header struct { - fin bool - rsv1 bool - rsv2 bool - rsv3 bool - opcode opcode - - payloadLength int64 - - masked bool - maskKey uint32 -} - -func makeWriteHeaderBuf() []byte { - return make([]byte, maxHeaderSize) -} - -// bytes returns the bytes of the header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func writeHeader(b []byte, h header) []byte { - if b == nil { - b = makeWriteHeaderBuf() - } - - b = b[:2] - b[0] = 0 - - if h.fin { - b[0] |= 1 << 7 - } - if h.rsv1 { - b[0] |= 1 << 6 - } - if h.rsv2 { - b[0] |= 1 << 5 - } - if h.rsv3 { - b[0] |= 1 << 4 - } - - b[0] |= byte(h.opcode) - - switch { - case h.payloadLength < 0: - panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) - case h.payloadLength <= 125: - b[1] = byte(h.payloadLength) - case h.payloadLength <= math.MaxUint16: - b[1] = 126 - b = b[:len(b)+2] - binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) - default: - b[1] = 127 - b = b[:len(b)+8] - binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength)) - } - - if h.masked { - b[1] |= 1 << 7 - b = b[:len(b)+4] - binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) - } - - return b -} - -func makeReadHeaderBuf() []byte { - return make([]byte, maxHeaderSize-2) -} - -// readHeader reads a header from the reader. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func readHeader(b []byte, r io.Reader) (header, error) { - if b == nil { - b = makeReadHeaderBuf() - } - - // We read the first two bytes first so that we know - // exactly how long the header is. - b = b[:2] - _, err := io.ReadFull(r, b) - if err != nil { - return header{}, err - } - - var h header - h.fin = b[0]&(1<<7) != 0 - h.rsv1 = b[0]&(1<<6) != 0 - h.rsv2 = b[0]&(1<<5) != 0 - h.rsv3 = b[0]&(1<<4) != 0 - - h.opcode = opcode(b[0] & 0xf) - - var extra int - - h.masked = b[1]&(1<<7) != 0 - if h.masked { - extra += 4 - } - - payloadLength := b[1] &^ (1 << 7) - switch { - case payloadLength < 126: - h.payloadLength = int64(payloadLength) - case payloadLength == 126: - extra += 2 - case payloadLength == 127: - extra += 8 - } - - if extra == 0 { - return h, nil - } - - b = b[:extra] - _, err = io.ReadFull(r, b) - if err != nil { - return header{}, err - } - - switch { - case payloadLength == 126: - h.payloadLength = int64(binary.BigEndian.Uint16(b)) - b = b[2:] - case payloadLength == 127: - h.payloadLength = int64(binary.BigEndian.Uint64(b)) - if h.payloadLength < 0 { - return header{}, fmt.Errorf("header with negative payload length: %v", h.payloadLength) - } - b = b[8:] - } - - if h.masked { - h.maskKey = binary.LittleEndian.Uint32(b) - } - - return h, nil -} - -// StatusCode represents a WebSocket status code. -// https://tools.ietf.org/html/rfc6455#section-7.4 -type StatusCode int - -// These codes were retrieved from: -// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// -// The defined constants only represent the status codes registered with IANA. -// The 4000-4999 range of status codes is reserved for arbitrary use by applications. -const ( - StatusNormalClosure StatusCode = 1000 - StatusGoingAway StatusCode = 1001 - StatusProtocolError StatusCode = 1002 - StatusUnsupportedData StatusCode = 1003 - - // 1004 is reserved and so not exported. - statusReserved StatusCode = 1004 - - // StatusNoStatusRcvd cannot be sent in a close message. - // It is reserved for when a close message is received without - // an explicit status. - StatusNoStatusRcvd StatusCode = 1005 - - // StatusAbnormalClosure is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether the connection was closed or not or what happened. - StatusAbnormalClosure StatusCode = 1006 - - StatusInvalidFramePayloadData StatusCode = 1007 - StatusPolicyViolation StatusCode = 1008 - StatusMessageTooBig StatusCode = 1009 - StatusMandatoryExtension StatusCode = 1010 - StatusInternalError StatusCode = 1011 - StatusServiceRestart StatusCode = 1012 - StatusTryAgainLater StatusCode = 1013 - StatusBadGateway StatusCode = 1014 - - // StatusTLSHandshake is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether there was a TLS handshake failure. - StatusTLSHandshake StatusCode = 1015 -) - -// CloseError represents a WebSocket close frame. -// It is returned by Conn's methods when a WebSocket close frame is received from -// the peer. -// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, -// to check for this error. See the CloseError example. -type CloseError struct { - Code StatusCode - Reason string -} - -func (ce CloseError) Error() string { - return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) -} - -// CloseStatus is a convenience wrapper around errors.As to grab -// the status code from a *CloseError. If the passed error is nil -// or not a *CloseError, the returned StatusCode will be -1. -func CloseStatus(err error) StatusCode { - var ce CloseError - if errors.As(err, &ce) { - return ce.Code - } - return -1 -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -const maxControlFramePayload = 125 - -func (ce CloseError) bytes() ([]byte, error) { - if len(ce.Reason) > maxControlFramePayload-2 { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) - } - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -// fastMask applies the WebSocket masking algorithm to p -// with the given key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the correctly rotated key to -// to continue to mask/unmask the message. -// -// It is optimized for LittleEndian and expects the key -// to be in little endian. -// -// See https://github.com/golang/go/issues/31586 -func mask(key uint32, b []byte) uint32 { - if len(b) >= 8 { - key64 := uint64(key)<<32 | uint64(key) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - v = binary.LittleEndian.Uint64(b[64:72]) - binary.LittleEndian.PutUint64(b[64:72], v^key64) - v = binary.LittleEndian.Uint64(b[72:80]) - binary.LittleEndian.PutUint64(b[72:80], v^key64) - v = binary.LittleEndian.Uint64(b[80:88]) - binary.LittleEndian.PutUint64(b[80:88], v^key64) - v = binary.LittleEndian.Uint64(b[88:96]) - binary.LittleEndian.PutUint64(b[88:96], v^key64) - v = binary.LittleEndian.Uint64(b[96:104]) - binary.LittleEndian.PutUint64(b[96:104], v^key64) - v = binary.LittleEndian.Uint64(b[104:112]) - binary.LittleEndian.PutUint64(b[104:112], v^key64) - v = binary.LittleEndian.Uint64(b[112:120]) - binary.LittleEndian.PutUint64(b[112:120], v^key64) - v = binary.LittleEndian.Uint64(b[120:128]) - binary.LittleEndian.PutUint64(b[120:128], v^key64) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - b = b[8:] - } - } - - // Then we xor until b is less than 4 bytes. - for len(b) >= 4 { - v := binary.LittleEndian.Uint32(b) - binary.LittleEndian.PutUint32(b, v^key) - b = b[4:] - } - - // xor remaining bytes. - for i := range b { - b[i] ^= byte(key) - key = bits.RotateLeft32(key, -8) - } - - return key -} diff --git a/frame_test.go b/frame_test.go deleted file mode 100644 index 571e68fc..00000000 --- a/frame_test.go +++ /dev/null @@ -1,457 +0,0 @@ -// +build !js - -package websocket - -import ( - "bytes" - "encoding/binary" - "io" - "math" - "math/bits" - "math/rand" - "strconv" - "strings" - "testing" - "time" - _ "unsafe" - - "github.com/gobwas/ws" - "github.com/google/go-cmp/cmp" - _ "github.com/gorilla/websocket" - - "nhooyr.io/websocket/internal/assert" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBool() bool { - return rand.Intn(1) == 0 -} - -func TestHeader(t *testing.T) { - t.Parallel() - - t.Run("eof", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - bytes []byte - }{ - { - "start", - []byte{0xff}, - }, - { - "middle", - []byte{0xff, 0xff, 0xff}, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - b := bytes.NewBuffer(tc.bytes) - _, err := readHeader(nil, b) - if io.ErrUnexpectedEOF != err { - t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) - } - }) - } - }) - - t.Run("writeNegativeLength", func(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r == nil { - t.Fatal("failed to induce panic in writeHeader with negative payload length") - } - }() - - writeHeader(nil, header{ - payloadLength: -1, - }) - }) - - t.Run("readNegativeLength", func(t *testing.T) { - t.Parallel() - - b := writeHeader(nil, header{ - payloadLength: 1<<16 + 1, - }) - - // Make length negative - b[2] |= 1 << 7 - - r := bytes.NewReader(b) - _, err := readHeader(nil, r) - if err == nil { - t.Fatalf("unexpected error value: %+v", err) - } - }) - - t.Run("lengths", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 124, - 125, - 126, - 4096, - 16384, - 65535, - 65536, - 65537, - 131072, - } - - for _, n := range lengths { - n := n - t.Run(strconv.Itoa(n), func(t *testing.T) { - t.Parallel() - - testHeader(t, header{ - payloadLength: int64(n), - }) - }) - } - }) - - t.Run("fuzz", func(t *testing.T) { - t.Parallel() - - for i := 0; i < 10000; i++ { - h := header{ - fin: randBool(), - rsv1: randBool(), - rsv2: randBool(), - rsv3: randBool(), - opcode: opcode(rand.Intn(1 << 4)), - - masked: randBool(), - payloadLength: rand.Int63(), - } - - if h.masked { - h.maskKey = rand.Uint32() - } - - testHeader(t, h) - } - }) -} - -func testHeader(t *testing.T, h header) { - b := writeHeader(nil, h) - r := bytes.NewReader(b) - h2, err := readHeader(nil, r) - if err != nil { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read header: %v", err) - } - - if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) - } -} - -func TestCloseError(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - ce CloseError - success bool - }{ - { - name: "normal", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: true, - }, - { - name: "bigReason", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-1), - }, - success: false, - }, - { - name: "bigCode", - ce: CloseError{ - Code: math.MaxUint16, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - _, err := tc.ce.bytes() - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) - } - }) - } -} - -func Test_parseClosePayload(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - p []byte - success bool - ce CloseError - }{ - { - name: "normal", - p: append([]byte{0x3, 0xE8}, []byte("hello")...), - success: true, - ce: CloseError{ - Code: StatusNormalClosure, - Reason: "hello", - }, - }, - { - name: "nothing", - success: true, - ce: CloseError{ - Code: StatusNoStatusRcvd, - }, - }, - { - name: "oneByte", - p: []byte{0}, - success: false, - }, - { - name: "badStatusCode", - p: []byte{0x17, 0x70}, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ce, err := parseClosePayload(tc.p) - if (err == nil) != tc.success { - t.Fatalf("unexpected expected error value: %+v", err) - } - - if tc.success && tc.ce != ce { - t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) - } - }) - } -} - -func Test_validWireCloseCode(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - code StatusCode - valid bool - }{ - { - name: "normal", - code: StatusNormalClosure, - valid: true, - }, - { - name: "noStatus", - code: StatusNoStatusRcvd, - valid: false, - }, - { - name: "3000", - code: 3000, - valid: true, - }, - { - name: "4999", - code: 4999, - valid: true, - }, - { - name: "unknown", - code: 5000, - valid: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - if valid := validWireCloseCode(tc.code); tc.valid != valid { - t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) - } - }) - } -} - -func Test_mask(t *testing.T) { - t.Parallel() - - key := []byte{0xa, 0xb, 0xc, 0xff} - key32 := binary.LittleEndian.Uint32(key) - p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := mask(key32, p) - - if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { - t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) - } - - if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { - t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) - } -} - -func basicMask(maskKey [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= maskKey[pos&3] - pos++ - } - return pos & 3 -} - -//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes -func gorillaMaskBytes(key [4]byte, pos int, b []byte) int - -func Benchmark_mask(b *testing.B) { - sizes := []int{ - 2, - 3, - 4, - 8, - 16, - 32, - 128, - 512, - 4096, - 16384, - } - - fns := []struct { - name string - fn func(b *testing.B, key [4]byte, p []byte) - }{ - { - name: "basic", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - basicMask(key, 0, p) - } - }, - }, - - { - name: "nhooyr", - fn: func(b *testing.B, key [4]byte, p []byte) { - key32 := binary.LittleEndian.Uint32(key[:]) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mask(key32, p) - } - }, - }, - { - name: "gorilla", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - gorillaMaskBytes(key, 0, p) - } - }, - }, - { - name: "gobwas", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - ws.Cipher(p, key, 0) - } - }, - }, - } - - var key [4]byte - _, err := rand.Read(key[:]) - if err != nil { - b.Fatalf("failed to populate mask key: %v", err) - } - - for _, size := range sizes { - p := make([]byte, size) - - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { - b.SetBytes(int64(size)) - - fn.fn(b, key, p) - }) - } - }) - } -} - -func TestCloseStatus(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - in error - exp StatusCode - }{ - { - name: "nil", - in: nil, - exp: -1, - }, - { - name: "io.EOF", - in: io.EOF, - exp: -1, - }, - { - name: "StatusInternalError", - in: CloseError{ - Code: StatusInternalError, - }, - exp: StatusInternalError, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status") - if err != nil { - t.Fatal(err) - } - }) - } -} diff --git a/handshake.go b/handshake.go deleted file mode 100644 index 03331039..00000000 --- a/handshake.go +++ /dev/null @@ -1,637 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/textproto" - "net/url" - "strings" - "sync" -) - -// AcceptOptions represents the options available to pass to Accept. -type AcceptOptions struct { - // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. - // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to - // reject it, close the connection if c.Subprotocol() == "". - Subprotocols []string - - // InsecureSkipVerify disables Accept's origin verification - // behaviour. By default Accept only allows the handshake to - // succeed if the javascript that is initiating the handshake - // is on the same domain as the server. This is to prevent CSRF - // attacks when secure data is stored in a cookie as there is no same - // origin policy for WebSockets. In other words, javascript from - // any domain can perform a WebSocket dial on an arbitrary server. - // This dial will include cookies which means the arbitrary javascript - // can perform actions as the authenticated user. - // - // See https://stackoverflow.com/a/37837709/4283659 - // - // The only time you need this is if your javascript is running on a different domain - // than your WebSocket server. - // Think carefully about whether you really need this option before you use it. - // If you do, remember that if you store secure data in cookies, you wil need to verify the - // Origin header yourself otherwise you are exposing yourself to a CSRF attack. - InsecureSkipVerify bool - - // Compression sets the compression options. - // By default, compression is disabled. - // See docs on the CompressionOptions type. - Compression *CompressionOptions -} - -func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { - if !r.ProtoAtLeast(1, 1) { - err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if !headerContainsToken(r.Header, "Connection", "Upgrade") { - err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { - err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Method != "GET" { - err := fmt.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Header.Get("Sec-WebSocket-Version") != "13" { - err := fmt.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Header.Get("Sec-WebSocket-Key") == "" { - err := errors.New("websocket protocol violation: missing Sec-WebSocket-Key") - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - return nil -} - -// Accept accepts a WebSocket handshake from a client and upgrades the -// the connection to a WebSocket. -// -// Accept will reject the handshake if the Origin domain is not the same as the Host unless -// the InsecureSkipVerify option is set. In other words, by default it does not allow -// cross origin requests. -// -// If an error occurs, Accept will always write an appropriate response so you do not -// have to. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - c, err := accept(w, r, opts) - if err != nil { - return nil, fmt.Errorf("failed to accept websocket connection: %w", err) - } - return c, nil -} - -func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - if opts == nil { - opts = &AcceptOptions{} - } - - err := verifyClientRequest(w, r) - if err != nil { - return nil, err - } - - if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) - if err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return nil, err - } - } - - hj, ok := w.(http.Hijacker) - if !ok { - err = errors.New("passed ResponseWriter does not implement http.Hijacker") - http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) - return nil, err - } - - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Connection", "Upgrade") - - handleSecWebSocketKey(w, r) - - subproto := selectSubprotocol(r, opts.Subprotocols) - if subproto != "" { - w.Header().Set("Sec-WebSocket-Protocol", subproto) - } - - var copts *CompressionOptions - if opts.Compression != nil { - copts, err = negotiateCompression(r.Header, opts.Compression) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - if copts != nil { - copts.setHeader(w.Header(), false) - } - } - - w.WriteHeader(http.StatusSwitchingProtocols) - - netConn, brw, err := hj.Hijack() - if err != nil { - err = fmt.Errorf("failed to hijack connection: %w", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return nil, err - } - - // https://github.com/golang/go/issues/32314 - b, _ := brw.Reader.Peek(brw.Reader.Buffered()) - brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - - c := &Conn{ - subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), - br: brw.Reader, - bw: brw.Writer, - closer: netConn, - copts: copts, - } - c.init() - - return c, nil -} - -func headerContainsToken(h http.Header, key, token string) bool { - key = textproto.CanonicalMIMEHeaderKey(key) - - token = strings.ToLower(token) - match := func(t string) bool { - return t == token - } - - for _, v := range h[key] { - if searchHeaderTokens(v, match) { - return true - } - } - - return false -} - -// readCompressionExtensionHeader extracts compression extension info from h. -// The standard says we should support multiple compression extension configurations -// from the client but we don't need to as there is only a single deflate extension -// and we support every configuration without error so we only need to check the first -// and thus preferred configuration. -func readCompressionExtensionHeader(h http.Header) (xWebkitDeflateFrame bool, params []string, ok bool) { - match := func(t string) bool { - vals := strings.Split(t, ";") - for i := range vals { - vals[i] = strings.TrimSpace(vals[i]) - } - params = vals[1:] - - if vals[0] == "permessage-deflate" { - return true - } - - // See https://bugs.webkit.org/show_bug.cgi?id=115504 - if vals[0] == "x-webkit-deflate-frame" { - xWebkitDeflateFrame = true - return true - } - - return false - } - - key := textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Extensions") - for _, v := range h[key] { - if searchHeaderTokens(v, match) { - return xWebkitDeflateFrame, params, true - } - } - - return false, nil, false -} - -func searchHeaderTokens(v string, match func(val string) bool) bool { - v = strings.ToLower(v) - v = strings.TrimSpace(v) - - for _, v2 := range strings.Split(v, ",") { - v2 = strings.TrimSpace(v2) - if match(v2) { - return true - } - } - - return false -} - -func selectSubprotocol(r *http.Request, subprotocols []string) string { - for _, sp := range subprotocols { - if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { - return sp - } - } - return "" -} - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { - key := r.Header.Get("Sec-WebSocket-Key") - w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) -} - -func secWebSocketAccept(secWebSocketKey string) string { - h := sha1.New() - h.Write([]byte(secWebSocketKey)) - h.Write(keyGUID) - - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func authenticateOrigin(r *http.Request) error { - origin := r.Header.Get("Origin") - if origin == "" { - return nil - } - u, err := url.Parse(origin) - if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) - } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) - } - return nil -} - -// DialOptions represents the options available to pass to Dial. -type DialOptions struct { - // HTTPClient is the http client used for the handshake. - // Its Transport must return writable bodies - // for WebSocket handshakes. - // http.Transport does this correctly beginning with Go 1.12. - HTTPClient *http.Client - - // HTTPHeader specifies the HTTP headers included in the handshake request. - HTTPHeader http.Header - - // Subprotocols lists the subprotocols to negotiate with the server. - Subprotocols []string - - // Compression sets the compression options. - // By default, compression is disabled. - // See docs on the CompressionOptions type. - Compression *CompressionOptions -} - -// CompressionOptions describes the available compression options. -// -// See https://tools.ietf.org/html/rfc7692 -// -// The NoContextTakeover variables control whether a flate.Writer or flate.Reader is allocated -// for every connection (context takeover) versus shared from a pool (no context takeover). -// -// The advantage to context takeover is more efficient compression as the sliding window from previous -// messages will be used instead of being reset between every message. -// -// The advantage to no context takeover is that the flate structures are allocated as needed -// and shared between connections instead of giving each connection a fixed flate.Writer and -// flate.Reader. -// -// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. -// -// Enabling compression will increase memory and CPU usage and should -// be profiled before enabling in production. -// See https://github.com/gorilla/websocket/issues/203 -// -// This API is experimental and subject to change. -type CompressionOptions struct { - // ClientNoContextTakeover controls whether the client should use context takeover. - // See docs on CompressionOptions for discussion regarding context takeover. - // - // If set by the server, will guarantee that the client does not use context takeover. - ClientNoContextTakeover bool - - // ServerNoContextTakeover controls whether the server should use context takeover. - // See docs on CompressionOptions for discussion regarding context takeover. - // - // If set by the client, will guarantee that the server does not use context takeover. - ServerNoContextTakeover bool - - // Level controls the compression level used. - // Defaults to flate.BestSpeed. - Level int - - // Threshold controls the minimum message size in bytes before compression is used. - // Must not be greater than 4096 as that is the write buffer's size. - // - // Defaults to 256. - Threshold int - - // This is used for supporting Safari as it still uses x-webkit-deflate-frame. - // See negotiateCompression. - xWebkitDeflateFrame bool -} - -// Dial performs a WebSocket handshake on the given url with the given options. -// The response is the WebSocket handshake response from the server. -// If an error occurs, the returned response may be non nil. However, you can only -// read the first 1024 bytes of its body. -// -// You never need to close the resp.Body yourself. -// -// This function requires at least Go 1.12 to succeed as it uses a new feature -// in net/http to perform WebSocket handshakes and get a writable body -// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - c, r, err := dial(ctx, u, opts) - if err != nil { - return nil, r, fmt.Errorf("failed to websocket dial: %w", err) - } - return c, r, nil -} - -func (opts *DialOptions) ensure() (*DialOptions, error) { - if opts == nil { - opts = &DialOptions{} - } else { - opts = &*opts - } - - if opts.HTTPClient == nil { - opts.HTTPClient = http.DefaultClient - } - if opts.HTTPClient.Timeout > 0 { - return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - - return opts, nil -} - -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - opts, err = opts.ensure() - if err != nil { - return nil, nil, err - } - - parsedURL, err := url.Parse(u) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse url: %w", err) - } - - switch parsedURL.Scheme { - case "ws": - parsedURL.Scheme = "http" - case "wss": - parsedURL.Scheme = "https" - default: - return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) - } - - req, _ := http.NewRequest("GET", parsedURL.String(), nil) - req = req.WithContext(ctx) - req.Header = opts.HTTPHeader - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") - secWebSocketKey, err := makeSecWebSocketKey() - if err != nil { - return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) - } - req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) - if len(opts.Subprotocols) > 0 { - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) - } - if opts.Compression != nil { - opts.Compression.setHeader(req.Header, true) - } - - resp, err := opts.HTTPClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) - } - defer func() { - if err != nil { - // We read a bit of the body for easier debugging. - r := io.LimitReader(resp.Body, 1024) - b, _ := ioutil.ReadAll(r) - resp.Body.Close() - resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - } - }() - - copts, err := verifyServerResponse(req, resp, opts) - if err != nil { - return nil, resp, err - } - - rwc, ok := resp.Body.(io.ReadWriteCloser) - if !ok { - return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", resp.Body) - } - - c := &Conn{ - subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: getBufioReader(rwc), - bw: getBufioWriter(rwc), - closer: rwc, - client: true, - copts: copts, - } - c.extractBufioWriterBuf(rwc) - c.init() - - return c, resp, nil -} - -func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*CompressionOptions, error) { - if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) - } - - if !headerContainsToken(resp.Header, "Connection", "Upgrade") { - return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) - } - - if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { - return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) - } - - if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", - resp.Header.Get("Sec-WebSocket-Accept"), - r.Header.Get("Sec-WebSocket-Key"), - ) - } - - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) - } - - var copts *CompressionOptions - if opts.Compression != nil { - var err error - copts, err = negotiateCompression(resp.Header, opts.Compression) - if err != nil { - return nil, err - } - } - - return copts, nil -} - -// The below pools can only be used by the client because http.Hijacker will always -// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top. - -var bufioReaderPool = sync.Pool{ - New: func() interface{} { - return bufio.NewReader(nil) - }, -} - -func getBufioReader(r io.Reader) *bufio.Reader { - br := bufioReaderPool.Get().(*bufio.Reader) - br.Reset(r) - return br -} - -func returnBufioReader(br *bufio.Reader) { - bufioReaderPool.Put(br) -} - -var bufioWriterPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriter(nil) - }, -} - -func getBufioWriter(w io.Writer) *bufio.Writer { - bw := bufioWriterPool.Get().(*bufio.Writer) - bw.Reset(w) - return bw -} - -func returnBufioWriter(bw *bufio.Writer) { - bufioWriterPool.Put(bw) -} - -func makeSecWebSocketKey() (string, error) { - b := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, b) - if err != nil { - return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) - } - return base64.StdEncoding.EncodeToString(b), nil -} - -func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) { - xWebkitDeflateFrame, params, ok := readCompressionExtensionHeader(h) - if !ok { - return nil, nil - } - - // Ensures our changes do not modify the real compression options. - copts = &*copts - copts.xWebkitDeflateFrame = xWebkitDeflateFrame - - // We are the client if the header contains the accept header, meaning its from the server. - client := h.Get("Sec-WebSocket-Accept") == "" - - if copts.xWebkitDeflateFrame { - // The other endpoint dictates whether or not we can - // use context takeover on our side. We cannot force it. - // Likewise, we tell the other side so we can force that. - if client { - copts.ClientNoContextTakeover = false - } else { - copts.ServerNoContextTakeover = false - } - } - - for _, p := range params { - switch p { - case "client_no_context_takeover": - copts.ClientNoContextTakeover = true - continue - case "server_no_context_takeover": - copts.ServerNoContextTakeover = true - continue - case "client_max_window_bits", "server-max-window-bits": - if !client { - // If we are the server, we are allowed to ignore these parameters. - // However, if we are the client, we must obey them but because of - // https://github.com/golang/go/issues/3155 we cannot. - continue - } - case "no_context_takeover": - if copts.xWebkitDeflateFrame { - if client { - copts.ClientNoContextTakeover = true - } else { - copts.ServerNoContextTakeover = true - } - continue - } - - // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead - // of ignoring it as the draft spec is unclear. It says the server can ignore it - // but the server has no way of signalling to the client it was ignored as parameters - // are set one way. - // Thus us ignoring it would make the client think we understood it which would cause issues. - // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 - // - // Either way, we're only implementing this for webkit which never sends the max_window_bits - // parameter so we don't need to worry about it. - } - - return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - } - - return copts, nil -} - -func (copts *CompressionOptions) setHeader(h http.Header, client bool) { - var s string - if !copts.xWebkitDeflateFrame { - s := "permessage-deflate" - if copts.ClientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.ServerNoContextTakeover { - s += "; server_no_context_takeover" - } - } else { - s = "x-webkit-deflate-frame" - // We can only set no context takeover for the peer. - if client && copts.ServerNoContextTakeover || !client && copts.ClientNoContextTakeover { - s += "; no_context_takeover" - } - } - h.Set("Sec-WebSocket-Extensions", s) -} diff --git a/internal/assert/assert.go b/internal/assert/assert.go index e57abfd9..372d5465 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -1,8 +1,8 @@ package assert import ( - "fmt" "reflect" + "testing" "github.com/google/go-cmp/cmp" ) @@ -53,11 +53,15 @@ func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { } } -// Equalf compares exp to act and if they are not equal, returns -// an error describing an error. -func Equalf(exp, act interface{}, f string, v ...interface{}) error { - if diff := cmpDiff(exp, act); diff != "" { - return fmt.Errorf(f+": %v", append(v, diff)...) +func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) { + t.Helper() + diff := cmpDiff(exp, act) + if diff != "" { + t.Fatalf(f+": %v", append(v, diff)...) } - return nil +} + +func Success(t *testing.T, err error) { + t.Helper() + Equalf(t, error(nil), err, "unexpected failure") } diff --git a/internal/atomicint/atomicint.go b/internal/atomicint/atomicint.go new file mode 100644 index 00000000..668b3b4b --- /dev/null +++ b/internal/atomicint/atomicint.go @@ -0,0 +1,32 @@ +package atomicint + +import ( + "fmt" + "sync/atomic" +) + +// See https://github.com/nhooyr/websocket/issues/153 +type Int64 struct { + v int64 +} + +func (v *Int64) Load() int64 { + return atomic.LoadInt64(&v.v) +} + +func (v *Int64) Store(i int64) { + atomic.StoreInt64(&v.v, i) +} + +func (v *Int64) String() string { + return fmt.Sprint(v.Load()) +} + +// Increment increments the value and returns the new value. +func (v *Int64) Increment(delta int64) int64 { + return atomic.AddInt64(&v.v, delta) +} + +func (v *Int64) CAS(old, new int64) (swapped bool) { + return atomic.CompareAndSwapInt64(&v.v, old, new) +} diff --git a/internal/bpool/bpool.go b/internal/bufpool/buf.go similarity index 95% rename from internal/bpool/bpool.go rename to internal/bufpool/buf.go index 4266c236..324a17e1 100644 --- a/internal/bpool/bpool.go +++ b/internal/bufpool/buf.go @@ -1,4 +1,4 @@ -package bpool +package bufpool import ( "bytes" diff --git a/internal/bpool/bpool_test.go b/internal/bufpool/buf_test.go similarity index 97% rename from internal/bpool/bpool_test.go rename to internal/bufpool/buf_test.go index 5dfe56e6..42a2fea7 100644 --- a/internal/bpool/bpool_test.go +++ b/internal/bufpool/buf_test.go @@ -1,4 +1,4 @@ -package bpool +package bufpool import ( "strconv" diff --git a/internal/bufpool/bufio.go b/internal/bufpool/bufio.go new file mode 100644 index 00000000..875bbf4b --- /dev/null +++ b/internal/bufpool/bufio.go @@ -0,0 +1,40 @@ +package bufpool + +import ( + "bufio" + "io" + "sync" +) + +var readerPool = sync.Pool{ + New: func() interface{} { + return bufio.NewReader(nil) + }, +} + +func GetReader(r io.Reader) *bufio.Reader { + br := readerPool.Get().(*bufio.Reader) + br.Reset(r) + return br +} + +func PutReader(br *bufio.Reader) { + readerPool.Put(br) +} + +var writerPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriter(nil) + }, +} + +func GetWriter(w io.Writer) *bufio.Writer { + bw := writerPool.Get().(*bufio.Writer) + bw.Reset(w) + return bw +} + +func PutWriter(bw *bufio.Writer) { + writerPool.Put(bw) +} + diff --git a/internal/wsframe/frame.go b/internal/wsframe/frame.go new file mode 100644 index 00000000..50ff8c11 --- /dev/null +++ b/internal/wsframe/frame.go @@ -0,0 +1,194 @@ +package wsframe + +import ( + "encoding/binary" + "fmt" + "io" + "math" +) + +// Opcode represents a WebSocket Opcode. +type Opcode int + +// Opcode constants. +const ( + OpContinuation Opcode = iota + OpText + OpBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + OpClose + OpPing + OpPong + // 11-16 are reserved for further control frames. +) + +func (o Opcode) Control() bool { + switch o { + case OpClose, OpPing, OpPong: + return true + } + return false +} + +func (o Opcode) Data() bool { + switch o { + case OpText, OpBinary: + return true + } + return false +} + +// First byte contains fin, rsv1, rsv2, rsv3. +// Second byte contains mask flag and payload len. +// Next 8 bytes are the maximum extended payload length. +// Last 4 bytes are the mask key. +// https://tools.ietf.org/html/rfc6455#section-5.2 +const maxHeaderSize = 1 + 1 + 8 + 4 + +// Header represents a WebSocket frame Header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Header struct { + Fin bool + RSV1 bool + RSV2 bool + RSV3 bool + Opcode Opcode + + PayloadLength int64 + + Masked bool + MaskKey uint32 +} + +// bytes returns the bytes of the Header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func (h Header) Bytes(b []byte) []byte { + if b == nil { + b = make([]byte, maxHeaderSize) + } + + b = b[:2] + b[0] = 0 + + if h.Fin { + b[0] |= 1 << 7 + } + if h.RSV1 { + b[0] |= 1 << 6 + } + if h.RSV2 { + b[0] |= 1 << 5 + } + if h.RSV3 { + b[0] |= 1 << 4 + } + + b[0] |= byte(h.Opcode) + + switch { + case h.PayloadLength < 0: + panic(fmt.Sprintf("websocket: invalid Header: negative length: %v", h.PayloadLength)) + case h.PayloadLength <= 125: + b[1] = byte(h.PayloadLength) + case h.PayloadLength <= math.MaxUint16: + b[1] = 126 + b = b[:len(b)+2] + binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.PayloadLength)) + default: + b[1] = 127 + b = b[:len(b)+8] + binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.PayloadLength)) + } + + if h.Masked { + b[1] |= 1 << 7 + b = b[:len(b)+4] + binary.LittleEndian.PutUint32(b[len(b)-4:], h.MaskKey) + } + + return b +} + +func MakeReadHeaderBuf() []byte { + return make([]byte, maxHeaderSize-2) +} + +// ReadHeader reads a Header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func ReadHeader(r io.Reader, b []byte) (Header, error) { + // We read the first two bytes first so that we know + // exactly how long the Header is. + b = b[:2] + _, err := io.ReadFull(r, b) + if err != nil { + return Header{}, err + } + + var h Header + h.Fin = b[0]&(1<<7) != 0 + h.RSV1 = b[0]&(1<<6) != 0 + h.RSV2 = b[0]&(1<<5) != 0 + h.RSV3 = b[0]&(1<<4) != 0 + + h.Opcode = Opcode(b[0] & 0xf) + + var extra int + + h.Masked = b[1]&(1<<7) != 0 + if h.Masked { + extra += 4 + } + + payloadLength := b[1] &^ (1 << 7) + switch { + case payloadLength < 126: + h.PayloadLength = int64(payloadLength) + case payloadLength == 126: + extra += 2 + case payloadLength == 127: + extra += 8 + } + + if extra == 0 { + return h, nil + } + + b = b[:extra] + _, err = io.ReadFull(r, b) + if err != nil { + return Header{}, err + } + + switch { + case payloadLength == 126: + h.PayloadLength = int64(binary.BigEndian.Uint16(b)) + b = b[2:] + case payloadLength == 127: + h.PayloadLength = int64(binary.BigEndian.Uint64(b)) + if h.PayloadLength < 0 { + return Header{}, fmt.Errorf("Header with negative payload length: %v", h.PayloadLength) + } + b = b[8:] + } + + if h.Masked { + h.MaskKey = binary.LittleEndian.Uint32(b) + } + + return h, nil +} + +const MaxControlFramePayload = 125 + +func ParseClosePayload(p []byte) (uint16, string, error) { + if len(p) < 2 { + return 0, "", fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + return binary.BigEndian.Uint16(p), string(p[2:]), nil +} diff --git a/frame_stringer.go b/internal/wsframe/frame_stringer.go similarity index 90% rename from frame_stringer.go rename to internal/wsframe/frame_stringer.go index 72b865fc..b2e7f423 100644 --- a/frame_stringer.go +++ b/internal/wsframe/frame_stringer.go @@ -1,6 +1,6 @@ -// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. +// Code generated by "stringer -type=Opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. -package websocket +package wsframe import "strconv" @@ -8,12 +8,12 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[opContinuation-0] - _ = x[opText-1] - _ = x[opBinary-2] - _ = x[opClose-8] - _ = x[opPing-9] - _ = x[opPong-10] + _ = x[OpContinuation-0] + _ = x[OpText-1] + _ = x[OpBinary-2] + _ = x[OpClose-8] + _ = x[OpPing-9] + _ = x[OpPong-10] } const ( @@ -26,7 +26,7 @@ var ( _opcode_index_1 = [...]uint8{0, 7, 13, 19} ) -func (i opcode) String() string { +func (i Opcode) String() string { switch { case 0 <= i && i <= 2: return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] @@ -34,7 +34,7 @@ func (i opcode) String() string { i -= 8 return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] default: - return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + return "Opcode(" + strconv.FormatInt(int64(i), 10) + ")" } } func _() { diff --git a/internal/wsframe/frame_test.go b/internal/wsframe/frame_test.go new file mode 100644 index 00000000..d6b66e7e --- /dev/null +++ b/internal/wsframe/frame_test.go @@ -0,0 +1,157 @@ +// +build !js + +package wsframe + +import ( + "bytes" + "io" + "math/rand" + "strconv" + "testing" + "time" + _ "unsafe" + + "github.com/google/go-cmp/cmp" + _ "github.com/gorilla/websocket" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func randBool() bool { + return rand.Intn(1) == 0 +} + +func TestHeader(t *testing.T) { + t.Parallel() + + t.Run("eof", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + bytes []byte + }{ + { + "start", + []byte{0xff}, + }, + { + "middle", + []byte{0xff, 0xff, 0xff}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := bytes.NewBuffer(tc.bytes) + _, err := ReadHeader(nil, b) + if io.ErrUnexpectedEOF != err { + t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) + } + }) + } + }) + + t.Run("writeNegativeLength", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + t.Fatal("failed to induce panic in writeHeader with negative payload length") + } + }() + + Header{ + PayloadLength: -1, + }.Bytes(nil) + }) + + t.Run("readNegativeLength", func(t *testing.T) { + t.Parallel() + + b := Header{ + PayloadLength: 1<<16 + 1, + }.Bytes(nil) + + // Make length negative + b[2] |= 1 << 7 + + r := bytes.NewReader(b) + _, err := ReadHeader(nil, r) + if err == nil { + t.Fatalf("unexpected error value: %+v", err) + } + }) + + t.Run("lengths", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 124, + 125, + 126, + 4096, + 16384, + 65535, + 65536, + 65537, + 131072, + } + + for _, n := range lengths { + n := n + t.Run(strconv.Itoa(n), func(t *testing.T) { + t.Parallel() + + testHeader(t, Header{ + PayloadLength: int64(n), + }) + }) + } + }) + + t.Run("fuzz", func(t *testing.T) { + t.Parallel() + + for i := 0; i < 10000; i++ { + h := Header{ + Fin: randBool(), + RSV1: randBool(), + RSV2: randBool(), + RSV3: randBool(), + Opcode: Opcode(rand.Intn(1 << 4)), + + Masked: randBool(), + PayloadLength: rand.Int63(), + } + + if h.Masked { + h.MaskKey = rand.Uint32() + } + + testHeader(t, h) + } + }) +} + +func testHeader(t *testing.T, h Header) { + b := h.Bytes(nil) + r := bytes.NewReader(b) + h2, err := ReadHeader(r, nil) + if err != nil { + t.Logf("Header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("failed to read Header: %v", err) + } + + if !cmp.Equal(h, h2, cmp.AllowUnexported(Header{})) { + t.Logf("Header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("parsed and read Header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(Header{}))) + } +} diff --git a/internal/wsframe/mask.go b/internal/wsframe/mask.go new file mode 100644 index 00000000..2da4c11d --- /dev/null +++ b/internal/wsframe/mask.go @@ -0,0 +1,128 @@ +package wsframe + +import ( + "encoding/binary" + "math/bits" +) + +// Mask applies the WebSocket masking algorithm to p +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func Mask(key uint32, b []byte) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/internal/wsframe/mask_test.go b/internal/wsframe/mask_test.go new file mode 100644 index 00000000..fbd29892 --- /dev/null +++ b/internal/wsframe/mask_test.go @@ -0,0 +1,118 @@ +package wsframe_test + +import ( + "crypto/rand" + "encoding/binary" + "github.com/gobwas/ws" + "github.com/google/go-cmp/cmp" + "math/bits" + "nhooyr.io/websocket/internal/wsframe" + "strconv" + "testing" + _ "unsafe" +) + +func Test_mask(t *testing.T) { + t.Parallel() + + key := []byte{0xa, 0xb, 0xc, 0xff} + key32 := binary.LittleEndian.Uint32(key) + p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} + gotKey32 := wsframe.Mask(key32, p) + + if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { + t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) + } + + if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { + t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) + } +} + +func basicMask(maskKey [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} + +//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes +func gorillaMaskBytes(key [4]byte, pos int, b []byte) int + +func Benchmark_mask(b *testing.B) { + sizes := []int{ + 2, + 3, + 4, + 8, + 16, + 32, + 128, + 512, + 4096, + 16384, + } + + fns := []struct { + name string + fn func(b *testing.B, key [4]byte, p []byte) + }{ + { + name: "basic", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + basicMask(key, 0, p) + } + }, + }, + + { + name: "nhooyr", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + wsframe.Mask(key32, p) + } + }, + }, + { + name: "gorilla", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + gorillaMaskBytes(key, 0, p) + } + }, + }, + { + name: "gobwas", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + ws.Cipher(p, key, 0) + } + }, + }, + } + + var key [4]byte + _, err := rand.Read(key[:]) + if err != nil { + b.Fatalf("failed to populate mask key: %v", err) + } + + for _, size := range sizes { + p := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + b.SetBytes(int64(size)) + + fn.fn(b, key, p) + }) + } + }) + } +} diff --git a/js_test.go b/js_test.go new file mode 100644 index 00000000..80af7896 --- /dev/null +++ b/js_test.go @@ -0,0 +1,50 @@ +package websocket_test + +import ( + "context" + "fmt" + "net/http" + "nhooyr.io/websocket/internal/wsecho" + "os" + "os/exec" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestJS(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = wsecho.Loop(r.Context(), c) + if websocket.CloseStatus(err) != websocket.StatusNormalClosure { + return err + } + return nil + }, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) + + b, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("wasm test binary failed: %v:\n%s", err, b) + } +} diff --git a/conn_common.go b/netconn.go similarity index 60% rename from conn_common.go rename to netconn.go index 1247df6e..74a2c7c1 100644 --- a/conn_common.go +++ b/netconn.go @@ -1,6 +1,3 @@ -// This file contains *Conn symbols relevant to both -// Wasm and non Wasm builds. - package websocket import ( @@ -10,7 +7,6 @@ import ( "math" "net" "sync" - "sync/atomic" "time" ) @@ -169,77 +165,3 @@ func (c *netConn) SetReadDeadline(t time.Time) error { return nil } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. -// The returned context will be cancelled when the connection is closed. -// -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. -func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - - ctx, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - // We use the unexported reader method so that we don't get the read closed error. - c.reader(ctx, true) - // Either the connection is already closed since there was a read error - // or the context was cancelled or a message was read and we should close - // the connection. - c.Close(StatusPolicyViolation, "unexpected data message") - }() - return ctx -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusMessageTooBig. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit.Store(n) -} - -func (c *Conn) setCloseErr(err error) { - c.closeErrOnce.Do(func() { - c.closeErr = fmt.Errorf("websocket closed: %w", err) - }) -} - -// See https://github.com/nhooyr/websocket/issues/153 -type atomicInt64 struct { - v int64 -} - -func (v *atomicInt64) Load() int64 { - return atomic.LoadInt64(&v.v) -} - -func (v *atomicInt64) Store(i int64) { - atomic.StoreInt64(&v.v, i) -} - -func (v *atomicInt64) String() string { - return fmt.Sprint(v.Load()) -} - -// Increment increments the value and returns the new value. -func (v *atomicInt64) Increment(delta int64) int64 { - return atomic.AddInt64(&v.v, delta) -} - -func (v *atomicInt64) CAS(old, new int64) (swapped bool) { - return atomic.CompareAndSwapInt64(&v.v, old, new) -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/reader.go b/reader.go new file mode 100644 index 00000000..fe716569 --- /dev/null +++ b/reader.go @@ -0,0 +1,31 @@ +package websocket + +import ( + "bufio" + "context" + "io" + "nhooyr.io/websocket/internal/atomicint" + "nhooyr.io/websocket/internal/wsframe" + "strings" +) + +type reader struct { + // Acquired before performing any sort of read operation. + readLock chan struct{} + + c *Conn + + deflateReader io.Reader + br *bufio.Reader + + readClosed *atomicint.Int64 + readHeaderBuf []byte + controlPayloadBuf []byte + + msgCtx context.Context + msgCompressed bool + frameHeader wsframe.Header + frameMaskKey uint32 + frameEOF bool + deflateTail strings.Reader +} diff --git a/websocket_js_test.go b/websocket_js_test.go deleted file mode 100644 index 9b7bb813..00000000 --- a/websocket_js_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package websocket_test - -import ( - "context" - "net/http" - "os" - "testing" - "time" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" -) - -func TestConn(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - err = assertSubprotocol(c, "echo") - if err != nil { - t.Fatal(err) - } - - err = assert.Equalf(&http.Response{}, resp, "unexpected http response") - if err != nil { - t.Fatal(err) - } - - err = assertJSONEcho(ctx, c, 1024) - if err != nil { - t.Fatal(err) - } - - err = assertEcho(ctx, c, websocket.MessageBinary, 1024) - if err != nil { - t.Fatal(err) - } - - err = c.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatal(err) - } -} diff --git a/writer.go b/writer.go new file mode 100644 index 00000000..b31d57ad --- /dev/null +++ b/writer.go @@ -0,0 +1,5 @@ +package websocket + +type writer struct { + +} diff --git a/websocket_js.go b/ws_js.go similarity index 88% rename from websocket_js.go rename to ws_js.go index d27809cf..4c067430 100644 --- a/websocket_js.go +++ b/ws_js.go @@ -1,3 +1,5 @@ +// +build js + package websocket // import "nhooyr.io/websocket" import ( @@ -7,12 +9,13 @@ import ( "fmt" "io" "net/http" + "nhooyr.io/websocket/internal/atomicint" "reflect" "runtime" "sync" "syscall/js" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" "nhooyr.io/websocket/internal/wsjs" ) @@ -21,10 +24,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit *atomicInt64 + msgReadLimit *atomicint.Int64 closingMu sync.Mutex - isReadClosed *atomicInt64 + isReadClosed *atomicint.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -56,17 +59,20 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &atomicInt64{} + c.msgReadLimit = &atomicint.Int64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &atomicInt64{} + c.isReadClosed = &atomicint.Int64{} c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } - c.close(fmt.Errorf("received close: %w", err), e.WasClean) + // We do not know if we sent or received this close as + // its possible the browser triggered it without us + // explicitly sending it. + c.close(err, e.WasClean) c.releaseOnClose() c.releaseOnMessage() @@ -288,11 +294,6 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return typ, bytes.NewReader(p), nil } -// Only implemented for use by *Conn.CloseRead in conn_common.go -func (c *Conn) reader(ctx context.Context, _ bool) { - c.read(ctx) -} - // Writer returns a writer to write a WebSocket data message to the connection. // It buffers the entire message in memory and then sends it when the writer // is closed. @@ -301,7 +302,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err c: c, ctx: ctx, typ: typ, - b: bpool.Get(), + b: bufpool.Get(), }, nil } @@ -331,7 +332,7 @@ func (w writer) Close() error { return errors.New("cannot close closed writer") } w.closed = true - defer bpool.Put(w.b) + defer bufpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { @@ -339,3 +340,34 @@ func (w writer) Close() error { } return nil } + +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.read(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("websocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/ws_js_test.go b/ws_js_test.go new file mode 100644 index 00000000..abd950c7 --- /dev/null +++ b/ws_js_test.go @@ -0,0 +1,22 @@ +package websocket + +func TestEcho(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + }) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") + + assertSubprotocol(t, c, "echo") + assert.Equalf(t, &http.Response{}, resp, "unexpected http response") + assertJSONEcho(t, ctx, c, 1024) + assertEcho(t, ctx, c, websocket.MessageBinary, 1024) + + err = c.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index fe935fa1..9fa8b54c 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,9 +5,8 @@ import ( "context" "encoding/json" "fmt" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" ) // Read reads a json message from c into v. @@ -31,8 +30,8 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) } - b := bpool.Get() - defer bpool.Put(b) + b := bufpool.Get() + defer bufpool.Put(b) _, err = b.ReadFrom(r) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index 3c9e0f76..52ddcd57 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -9,7 +9,7 @@ import ( "github.com/golang/protobuf/proto" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" ) // Read reads a protobuf message from c into v. @@ -33,8 +33,8 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) } - b := bpool.Get() - defer bpool.Put(b) + b := bufpool.Get() + defer bufpool.Put(b) _, err = b.ReadFrom(r) if err != nil { @@ -61,10 +61,10 @@ func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { } func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - b := bpool.Get() + b := bufpool.Get() pb := proto.NewBuffer(b.Bytes()) defer func() { - bpool.Put(bytes.NewBuffer(pb.Bytes())) + bufpool.Put(bytes.NewBuffer(pb.Bytes())) }() err := pb.Marshal(v) From d0a80496108cf7cdd4e20c24e4689cd5934b5b89 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 18 Nov 2019 22:52:18 -0500 Subject: [PATCH 08/55] Rewrite core Too many improvements and changes to list. Will include a detailed changelog for release. --- accept.go | 63 +- assert_test.go | 14 + autobahn_test.go | 252 ++ close.go | 158 +- close_test.go | 9 +- compress.go | 86 + conn.go | 1133 +------- conn_export_test.go | 129 - conn_test.go | 2382 +---------------- dial.go | 78 +- dial_test.go | 2 +- example_echo_test.go | 3 +- internal/wsframe/mask.go => frame.go | 162 +- .../wsframe/mask_test.go => frame_test.go | 108 +- internal/assert/assert.go | 40 +- internal/atomicint/atomicint.go | 32 - internal/bufpool/buf.go | 6 +- internal/bufpool/bufio.go | 40 - internal/errd/errd.go | 11 + internal/wsecho/wsecho.go | 55 - internal/wsframe/frame.go | 194 -- internal/wsframe/frame_stringer.go | 91 - internal/wsframe/frame_test.go | 157 -- internal/wsgrace/wsgrace.go | 50 - js_test.go | 50 - read.go | 479 ++++ reader.go | 31 - write.go | 348 +++ writer.go | 5 - ws_js.go | 12 +- wsjson/wsjson.go | 2 + 31 files changed, 1844 insertions(+), 4338 deletions(-) create mode 100644 autobahn_test.go delete mode 100644 conn_export_test.go rename internal/wsframe/mask.go => frame.go (57%) rename internal/wsframe/mask_test.go => frame_test.go (51%) delete mode 100644 internal/atomicint/atomicint.go delete mode 100644 internal/bufpool/bufio.go create mode 100644 internal/errd/errd.go delete mode 100644 internal/wsecho/wsecho.go delete mode 100644 internal/wsframe/frame.go delete mode 100644 internal/wsframe/frame_stringer.go delete mode 100644 internal/wsframe/frame_test.go delete mode 100644 internal/wsgrace/wsgrace.go delete mode 100644 js_test.go create mode 100644 read.go delete mode 100644 reader.go create mode 100644 write.go delete mode 100644 writer.go diff --git a/accept.go b/accept.go index 5ff2ea41..2028d4b2 100644 --- a/accept.go +++ b/accept.go @@ -60,10 +60,15 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, return c, nil } -func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { +func (opts *AcceptOptions) ensure() *AcceptOptions { if opts == nil { - opts = &AcceptOptions{} + return &AcceptOptions{} } + return opts +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + opts = opts.ensure() err := verifyClientRequest(w, r) if err != nil { @@ -114,31 +119,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - c := &Conn{ + return newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, br: brw.Reader, bw: brw.Writer, - closer: netConn, - copts: copts, - } - c.init() - - return c, nil -} - -func authenticateOrigin(r *http.Request) error { - origin := r.Header.Get("Origin") - if origin == "" { - return nil - } - u, err := url.Parse(origin) - if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) - } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) - } - return nil + }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { @@ -181,15 +169,37 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { return nil } +func authenticateOrigin(r *http.Request) error { + origin := r.Header.Get("Origin") + if origin == "" { + return nil + } + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + if !strings.EqualFold(u.Host, r.Host) { + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + } + return nil +} + func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { + cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") + if len(cps) == 0 { + return "" + } + for _, sp := range subprotocols { - if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { - return sp + for _, cp := range cps { + if strings.EqualFold(sp, cp) { + return cp + } } } return "" @@ -266,7 +276,6 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com return copts, nil } - func headerContainsToken(h http.Header, key, token string) bool { token = strings.ToLower(token) diff --git a/assert_test.go b/assert_test.go index af300999..0cc9dfe3 100644 --- a/assert_test.go +++ b/assert_test.go @@ -23,6 +23,8 @@ func randBytes(n int) []byte { } func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { + t.Helper() + exp := randString(n) err := wsjson.Write(ctx, c, exp) assert.Success(t, err) @@ -35,6 +37,8 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { + t.Helper() + var act interface{} err := wsjson.Read(ctx, c, &act) assert.Success(t, err) @@ -56,6 +60,8 @@ func randString(n int) string { } func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { + t.Helper() + p := randBytes(n) err := c.Write(ctx, typ, p) assert.Success(t, err) @@ -68,5 +74,13 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { + t.Helper() + assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol") } + +func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { + t.Helper() + + assert.Equalf(t, exp, websocket.CloseStatus(err), "unexpected status code") +} diff --git a/autobahn_test.go b/autobahn_test.go new file mode 100644 index 00000000..27f8a1b4 --- /dev/null +++ b/autobahn_test.go @@ -0,0 +1,252 @@ +package websocket_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "nhooyr.io/websocket" + "os" + "os/exec" + "strconv" + "strings" + "testing" + "time" +) + +func TestAutobahn(t *testing.T) { + // This test contains the old autobahn test suite tests that use the + // python binary. The approach is clunky and slow so new tests + // have been written in pure Go in websocket_test.go. + // These have been kept for correctness purposes and are occasionally ran. + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite") + } + + t.Run("server", testServerAutobahnPython) + t.Run("client", testClientAutobahnPython) +} + +// https://github.com/crossbario/autobahn-python/tree/master/wstest +func testServerAutobahnPython(t *testing.T) { + t.Parallel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) + if err != nil { + t.Logf("server handshake failed: %+v", err) + return + } + echoLoop(r.Context(), c) + })) + defer s.Close() + + spec := map[string]interface{}{ + "outdir": "ci/out/wstestServerReports", + "servers": []interface{}{ + map[string]interface{}{ + "agent": "main", + "url": strings.Replace(s.URL, "http", "ws", 1), + }, + }, + "cases": []string{"*"}, + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. 7.5.1 is the same. + "exclude-cases": []string{"6.*", "7.5.1"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + wstest := exec.CommandContext(ctx, "wstest", args...) + out, err := wstest.CombinedOutput() + if err != nil { + t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + } + + checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") +} + +func unusedListenAddr() (string, error) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py +func testClientAutobahnPython(t *testing.T) { + t.Parallel() + + if os.Getenv("AUTOBAHN_PYTHON") == "" { + t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") + } + + serverAddr, err := unusedListenAddr() + if err != nil { + t.Fatalf("failed to get unused listen addr for wstest: %v", err) + } + + wsServerURL := "ws://" + serverAddr + + spec := map[string]interface{}{ + "url": wsServerURL, + "outdir": "ci/out/wstestClientReports", + "cases": []string{"*"}, + // See TestAutobahnServer for the reasons why we exclude these. + "exclude-cases": []string{"6.*", "7.5.1"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), + // Disables some server that runs as part of fuzzingserver mode. + // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 + "--webport=0", + } + wstest := exec.CommandContext(ctx, "wstest", args...) + err = wstest.Start() + if err != nil { + t.Fatal(err) + } + defer func() { + err := wstest.Process.Kill() + if err != nil { + t.Error(err) + } + }() + + // Let it come up. + time.Sleep(time.Second * 5) + + var cases int + func() { + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + _, r, err := c.Reader(ctx) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + cases, err = strconv.Atoi(string(b)) + if err != nil { + t.Fatal(err) + } + + c.Close(websocket.StatusNormalClosure, "") + }() + + for i := 1; i <= cases; i++ { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*45) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) + if err != nil { + t.Fatal(err) + } + echoLoop(ctx, c) + }() + } + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) + if err != nil { + t.Fatal(err) + } + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func checkWSTestIndex(t *testing.T, path string) { + wstestOut, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("failed to read index.json: %v", err) + } + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + BehaviorClose string `json:"behaviorClose"` + } + err = json.Unmarshal(wstestOut, &indexJSON) + if err != nil { + t.Fatalf("failed to unmarshal index.json: %v", err) + } + + var failed bool + for _, tests := range indexJSON { + for test, result := range tests { + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + failed = true + t.Errorf("test %v failed", test) + } + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + failed = true + t.Errorf("bad close behaviour for test %v", test) + } + } + } + + if failed { + path = strings.Replace(path, ".json", ".html", 1) + if os.Getenv("CI") == "" { + t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) + } + } +} diff --git a/close.go b/close.go index 4f48f1b3..b1bc50e9 100644 --- a/close.go +++ b/close.go @@ -5,7 +5,9 @@ import ( "encoding/binary" "errors" "fmt" - "nhooyr.io/websocket/internal/wsframe" + "log" + "nhooyr.io/websocket/internal/bufpool" + "time" ) // StatusCode represents a WebSocket status code. @@ -74,6 +76,87 @@ func CloseStatus(err error) StatusCode { return -1 } +// Close closes the WebSocket connection with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// Thus, it implements the full WebSocket close handshake. +// All data messages received from the peer during the close handshake +// will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes otherwise an internal +// error will be sent to the peer. For this reason, you should avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + err := c.closeHandshake(code, reason) + if err != nil { + return fmt.Errorf("failed to close websocket: %w", err) + } + return nil +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.cw.sendClose(code, reason) + if err != nil { + return err + } + + return c.cr.waitClose() +} + +func (cw *connWriter) error(code StatusCode, err error) { + cw.c.setCloseErr(err) + cw.sendClose(code, err.Error()) + cw.c.close(nil) +} + +func (cw *connWriter) sendClose(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, + } + + cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + var p []byte + if ce.Code != StatusNoStatusRcvd { + p = ce.bytes() + } + + return cw.control(context.Background(), opClose, p) +} + +func (cr *connReader) waitClose() error { + defer cr.c.close(nil) + + return nil + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := cr.mu.Lock(ctx) + if err != nil { + return err + } + defer cr.mu.Unlock() + + b := bufpool.Get() + buf := b.Bytes() + buf = buf[:cap(buf)] + defer bufpool.Put(b) + + for { + // TODO + return nil + } +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -81,14 +164,13 @@ func parseClosePayload(p []byte) (CloseError, error) { }, nil } - code, reason, err := wsframe.ParseClosePayload(p) - if err != nil { - return CloseError{}, err + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ - Code: StatusCode(code), - Reason: reason, + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { @@ -116,11 +198,25 @@ func validWireCloseCode(code StatusCode) bool { return false } -func (ce CloseError) bytes() ([]byte, error) { - // TODO move check into frame write - if len(ce.Reason) > wsframe.MaxControlFramePayload-2 { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", wsframe.MaxControlFramePayload-2, ce.Reason, len(ce.Reason)) +func (ce CloseError) bytes() []byte { + p, err := ce.bytesErr() + if err != nil { + log.Printf("websocket: failed to marshal close frame: %+v", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() } + return p +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + if !validWireCloseCode(ce.Code) { return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } @@ -131,44 +227,16 @@ func (ce CloseError) bytes() ([]byte, error) { return buf, nil } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. -// The returned context will be cancelled when the connection is closed. -// -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. -func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - - ctx, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - // We use the unexported reader method so that we don't get the read closed error. - c.reader(ctx, true) - // Either the connection is already closed since there was a read error - // or the context was cancelled or a message was read and we should close - // the connection. - c.Close(StatusPolicyViolation, "unexpected data message") - }() - return ctx -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusMessageTooBig. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit.Store(n) +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrNoLock(err) + c.closeMu.Unlock() } -func (c *Conn) setCloseErr(err error) { - c.closeErrOnce.Do(func() { +func (c *Conn) setCloseErrNoLock(err error) { + if c.closeErr == nil { c.closeErr = fmt.Errorf("websocket closed: %w", err) - }) + } } func (c *Conn) isClosed() bool { diff --git a/close_test.go b/close_test.go index 78096d7e..ee10cd3f 100644 --- a/close_test.go +++ b/close_test.go @@ -5,7 +5,6 @@ import ( "io" "math" "nhooyr.io/websocket/internal/assert" - "nhooyr.io/websocket/internal/wsframe" "strings" "testing" ) @@ -22,7 +21,7 @@ func TestCloseError(t *testing.T) { name: "normal", ce: CloseError{ Code: StatusNormalClosure, - Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + Reason: strings.Repeat("x", maxCloseReason), }, success: true, }, @@ -30,7 +29,7 @@ func TestCloseError(t *testing.T) { name: "bigReason", ce: CloseError{ Code: StatusNormalClosure, - Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-1), + Reason: strings.Repeat("x", maxCloseReason+1), }, success: false, }, @@ -38,7 +37,7 @@ func TestCloseError(t *testing.T) { name: "bigCode", ce: CloseError{ Code: math.MaxUint16, - Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + Reason: strings.Repeat("x", maxCloseReason), }, success: false, }, @@ -49,7 +48,7 @@ func TestCloseError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, err := tc.ce.bytes() + _, err := tc.ce.bytesErr() if (err == nil) != tc.success { t.Fatalf("unexpected error value: %+v", err) } diff --git a/compress.go b/compress.go index 5b5fdce5..9e075430 100644 --- a/compress.go +++ b/compress.go @@ -3,7 +3,10 @@ package websocket import ( + "compress/flate" + "io" "net/http" + "sync" ) // CompressionMode controls the modes available RFC 7692's deflate extension. @@ -76,3 +79,86 @@ func (copts *compressionOptions) setHeader(h http.Header) { // we need to add them back otherwise flate.Reader keeps // trying to return more bytes. const deflateMessageTail = "\x00\x00\xff\xff" + +func (c *Conn) writeNoContextTakeOver() bool { + return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover +} + +func (c *Conn) readNoContextTakeOver() bool { + return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover +} + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + tw.tail = tw.tail[:0] +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + tw.tail = tw.tail[extra:] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReader(r) + } + fr.(flate.Resetter).Reset(r, nil) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} diff --git a/conn.go b/conn.go index 791d9b4c..e3f24171 100644 --- a/conn.go +++ b/conn.go @@ -4,25 +4,14 @@ package websocket import ( "bufio" - "compress/flate" "context" - "crypto/rand" - "encoding/binary" "errors" "fmt" "io" - "io/ioutil" - "log" - "nhooyr.io/websocket/internal/atomicint" - "nhooyr.io/websocket/internal/wsframe" "runtime" "strconv" - "strings" "sync" "sync/atomic" - "time" - - "nhooyr.io/websocket/internal/bufpool" ) // MessageType represents the type of a WebSocket message. @@ -51,91 +40,54 @@ const ( // This applies to the Read methods in the wsjson/wspb subpackages as well. type Conn struct { subprotocol string - fw *flate.Writer - bw *bufio.Writer - // writeBuf is used for masking, its the buffer in bufio.Writer. - // Only used by the client for masking the bytes in the buffer. - writeBuf []byte - closer io.Closer - client bool - copts *compressionOptions - - closeOnce sync.Once - closeErrOnce sync.Once - closeErr error - closed chan struct{} - closing *atomicint.Int64 - closeReceived error + rwc io.ReadWriteCloser + client bool + copts *compressionOptions - // messageWriter state. - // writeMsgLock is acquired to write a data message. - writeMsgLock chan struct{} - // writeFrameLock is acquired to write a single frame. - // Effectively meaning whoever holds it gets to write to bw. - writeFrameLock chan struct{} - writeHeaderBuf []byte - writeHeader *header - // read limit for a message in bytes. - msgReadLimit *atomicint.Int64 + cr connReader + cw connWriter - // Used to ensure a previous writer is not used after being closed. - activeWriter atomic.Value - // messageWriter state. - writeMsgOpcode opcode - writeMsgCtx context.Context + closed chan struct{} - setReadTimeout chan context.Context - setWriteTimeout chan context.Context + closeMu sync.Mutex + closeErr error + closeHandshakeErr error - pingCounter *atomicint.Int64 + pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} - - logf func(format string, v ...interface{}) } -func (c *Conn) init() { - c.closed = make(chan struct{}) - c.closing = &atomicint.Int64{} - - c.msgReadLimit = &atomicint.Int64{} - c.msgReadLimit.Store(32768) +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions - c.writeMsgLock = make(chan struct{}, 1) - c.writeFrameLock = make(chan struct{}, 1) + bw *bufio.Writer + br *bufio.Reader +} - c.readFrameLock = make(chan struct{}, 1) - c.readLock = make(chan struct{}, 1) - c.payloadReader = framePayloadReader{c} +func newConn(cfg connConfig) *Conn { + c := &Conn{} + c.subprotocol = cfg.subprotocol + c.rwc = cfg.rwc + c.client = cfg.client + c.copts = cfg.copts - c.setReadTimeout = make(chan context.Context) - c.setWriteTimeout = make(chan context.Context) + c.cr.init(c, cfg.br) + c.cw.init(c, cfg.bw) - c.pingCounter = &atomicint.Int64{} + c.closed = make(chan struct{}) c.activePings = make(map[string]chan<- struct{}) - c.writeHeaderBuf = makeWriteHeaderBuf() - c.writeHeader = &header{} - c.readHeaderBuf = makeReadHeaderBuf() - c.isReadClosed = &atomicint.Int64{} - c.controlPayloadBuf = make([]byte, maxControlFramePayload) - runtime.SetFinalizer(c, func(c *Conn) { c.close(errors.New("connection garbage collected")) }) - c.logf = log.Printf - - if c.copts != nil { - if !c.readNoContextTakeOver() { - c.fr = getFlateReader(c.payloadReader) - } - if !c.writeNoContextTakeOver() { - c.fw = getFlateWriter(c.bw) - } - } - go c.timeoutLoop() + + return c } // Subprotocol returns the negotiated subprotocol. @@ -145,38 +97,25 @@ func (c *Conn) Subprotocol() string { } func (c *Conn) close(err error) { - c.closeOnce.Do(func() { - runtime.SetFinalizer(c, nil) + c.closeMu.Lock() + defer c.closeMu.Unlock() - c.setCloseErr(err) - close(c.closed) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.closer.Close() + if c.isClosed() { + return + } + close(c.closed) + runtime.SetFinalizer(c, nil) + c.setCloseErrNoLock(err) - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readFramePayload and readHeader. - c.readFrameLock <- struct{}{} - if c.client { - returnBufioReader(c.br) - } - if c.fr != nil { - putFlateReader(c.fr) - } + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() - c.writeFrameLock <- struct{}{} - if c.client { - returnBufioWriter(c.bw) - } - if c.fw != nil { - putFlateWriter(c.fw) - } - }) + go func() { + c.cr.close() + c.cw.close() + }() } func (c *Conn) timeoutLoop() { @@ -188,20 +127,13 @@ func (c *Conn) timeoutLoop() { case <-c.closed: return - case writeCtx = <-c.setWriteTimeout: - case readCtx = <-c.setReadTimeout: + case writeCtx = <-c.cw.timeout: + case readCtx = <-c.cr.timeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - // Guaranteed to eventually close the connection since we can only ever send - // one close frame. - go func() { - c.exportedClose(StatusPolicyViolation, "read timed out", true) - // Ensure the connection closes, i.e if we already sent a close frame and timed out - // to read the peer's close frame. - c.close(nil) - }() - readCtx = context.Background() + c.cw.error(StatusPolicyViolation, errors.New("timed out")) + return case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return @@ -209,843 +141,8 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { - select { - case <-ctx.Done(): - var err error - switch lock { - case c.writeFrameLock, c.writeMsgLock: - err = fmt.Errorf("could not acquire write lock: %v", ctx.Err()) - case c.readFrameLock, c.readLock: - err = fmt.Errorf("could not acquire read lock: %v", ctx.Err()) - default: - panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) - } - c.close(err) - return ctx.Err() - case <-c.closed: - return c.closeErr - case lock <- struct{}{}: - return nil - } -} - -func (c *Conn) releaseLock(lock chan struct{}) { - // Allow multiple releases. - select { - case <-lock: - default: - } -} - -func (c *Conn) readTillMsg(ctx context.Context) (header, error) { - for { - h, err := c.readFrameHeader(ctx) - if err != nil { - return header{}, err - } - - if (h.rsv1 && (c.copts == nil || h.opcode.controlOp() || h.opcode == opContinuation)) || h.rsv2 || h.rsv3 { - err := fmt.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - c.exportedClose(StatusProtocolError, err.Error(), false) - return header{}, err - } - - if h.opcode.controlOp() { - err = c.handleControl(ctx, h) - if err != nil { - // Pass through CloseErrors when receiving a close frame. - if h.opcode == opClose && CloseStatus(err) != -1 { - return header{}, err - } - return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) - } - continue - } - - switch h.opcode { - case opBinary, opText, opContinuation: - return h, nil - default: - err := fmt.Errorf("received unknown opcode %v", h.opcode) - c.exportedClose(StatusProtocolError, err.Error(), false) - return header{}, err - } - } -} - -func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { - wrap := func(err error) error { - return fmt.Errorf("failed to read frame header: %w", err) - } - defer func() { - if err != nil { - err = wrap(err) - } - }() - - err = c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return header{}, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return header{}, c.closeErr - case c.setReadTimeout <- ctx: - } - - h, err := readHeader(c.readHeaderBuf, c.br) - if err != nil { - select { - case <-c.closed: - return header{}, c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.releaseLock(c.readFrameLock) - c.close(wrap(err)) - return header{}, err - } - - select { - case <-c.closed: - return header{}, c.closeErr - case c.setReadTimeout <- context.Background(): - } - - return h, nil -} - -func (c *Conn) handleControl(ctx context.Context, h header) error { - if h.payloadLength > maxControlFramePayload { - err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) - c.exportedClose(StatusProtocolError, err.Error(), false) - return err - } - - if !h.fin { - err := errors.New("received fragmented control frame") - c.exportedClose(StatusProtocolError, err.Error(), false) - return err - } - - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - b := c.controlPayloadBuf[:h.payloadLength] - _, err := c.readFramePayload(ctx, b) - if err != nil { - return err - } - - if h.masked { - mask(h.maskKey, b) - } - - switch h.opcode { - case opPing: - return c.writeControl(ctx, opPong, b) - case opPong: - c.activePingsMu.Lock() - pong, ok := c.activePings[string(b)] - c.activePingsMu.Unlock() - if ok { - close(pong) - } - return nil - case opClose: - ce, err := parseClosePayload(b) - if err != nil { - err = fmt.Errorf("received invalid close payload: %w", err) - c.exportedClose(StatusProtocolError, err.Error(), false) - c.closeReceived = err - return err - } - - err = fmt.Errorf("received close: %w", ce) - c.closeReceived = err - c.writeClose(b, err, false) - - if ctx.Err() != nil { - // The above close probably has been returned by the peer in response - // to our read timing out so we have to return the read timed out error instead. - return fmt.Errorf("read timed out: %w", ctx.Err()) - } - - return err - default: - panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) - } -} - -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. -// The passed context will also bound the reader. -// Ensure you read to EOF otherwise the connection will hang. -// -// All returned errors will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. -// -// You must read from the connection for control frames to be handled. -// Thus if you expect messages to take a long time to be responded to, -// you should handle such messages async to reading from the connection -// to ensure control frames are promptly handled. -// -// If you do not expect any data messages from the peer, call CloseRead. -// -// Only one Reader may be open at a time. -// -// If you need a separate timeout on the Reader call and then the message -// Read, use time.AfterFunc to cancel the context passed in early. -// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 -// Most users should not need this. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.isReadClosed.Load() == 1 { - return 0, nil, errors.New("websocket connection read closed") - } - - typ, r, err := c.reader(ctx, true) - if err != nil { - return 0, nil, fmt.Errorf("failed to get reader: %w", err) - } - return typ, r, nil -} - -func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, error) { - if lock { - err := c.acquireLock(ctx, c.readLock) - if err != nil { - return 0, nil, err - } - defer c.releaseLock(c.readLock) - } - - if c.activeReader != nil && !c.readerFrameEOF { - // The only way we know for sure the previous reader is not yet complete is - // if there is an active frame not yet fully read. - // Otherwise, a user may have read the last byte but not the EOF if the EOF - // is in the next frame so we check for that below. - return 0, nil, errors.New("previous message not read to completion") - } - - h, err := c.readTillMsg(ctx) - if err != nil { - return 0, nil, err - } - - if c.activeReader != nil && !c.activeReader.eof() { - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, nil, err - } - - if !h.fin || h.payloadLength > 0 { - return 0, nil, fmt.Errorf("previous message not read to completion") - } - - c.activeReader = nil - - h, err = c.readTillMsg(ctx) - if err != nil { - return 0, nil, err - } - } else if h.opcode == opContinuation { - err := errors.New("received continuation frame not after data or text frame") - c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, nil, err - } - - c.readerMsgCtx = ctx - c.readerMsgHeader = h - - c.readerPayloadCompressed = h.rsv1 - - if c.readerPayloadCompressed { - c.readerCompressTail.Reset(deflateMessageTail) - } - - c.readerFrameEOF = false - c.readerMaskKey = h.maskKey - c.readMsgLeft = c.msgReadLimit.Load() - - r := &messageReader{ - c: c, - } - c.activeReader = r - if c.readerPayloadCompressed && c.readNoContextTakeOver() { - c.fr = getFlateReader(c.payloadReader) - } - return MessageType(h.opcode), r, nil -} - -type framePayloadReader struct { - c *Conn -} - -func (r framePayloadReader) Read(p []byte) (int, error) { - if r.c.readerFrameEOF { - if r.c.readerPayloadCompressed && r.c.readerMsgHeader.fin { - n, _ := r.c.readerCompressTail.Read(p) - return n, nil - } - - h, err := r.c.readTillMsg(r.c.readerMsgCtx) - if err != nil { - return 0, err - } - - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - r.c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, err - } - - r.c.readerMsgHeader = h - r.c.readerFrameEOF = false - r.c.readerMaskKey = h.maskKey - } - - h := r.c.readerMsgHeader - if int64(len(p)) > h.payloadLength { - p = p[:h.payloadLength] - } - - n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) - - h.payloadLength -= int64(n) - if h.masked { - r.c.readerMaskKey = mask(r.c.readerMaskKey, p) - } - r.c.readerMsgHeader = h - - if err != nil { - return n, err - } - - if h.payloadLength == 0 { - r.c.readerFrameEOF = true - - if h.fin && !r.c.readerPayloadCompressed { - return n, io.EOF - } - } - - return n, nil -} - -// messageReader enables reading a data frame from the WebSocket connection. -type messageReader struct { - c *Conn -} - -func (r *messageReader) eof() bool { - return r.c.activeReader != r -} - -// Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (int, error) { - return r.exportedRead(p, true) -} - -func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) { - n, err := r.read(p, lock) - if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as errors.Is - // isn't used widely yet. - if errors.Is(err, io.EOF) { - return n, io.EOF - } - return n, fmt.Errorf("failed to read: %w", err) - } - return n, nil -} - -func (r *messageReader) readUnlocked(p []byte) (int, error) { - return r.exportedRead(p, false) -} - -func (r *messageReader) read(p []byte, lock bool) (int, error) { - if lock { - // If we cannot acquire the read lock, then - // there is either a concurrent read or the close handshake - // is proceeding. - select { - case r.c.readLock <- struct{}{}: - defer r.c.releaseLock(r.c.readLock) - default: - if r.c.closing.Load() == 1 { - <-r.c.closed - return 0, r.c.closeErr - } - return 0, errors.New("concurrent read detected") - } - } - - if r.eof() { - return 0, errors.New("cannot use EOFed reader") - } - - if r.c.readMsgLeft <= 0 { - err := fmt.Errorf("read limited at %v bytes", r.c.msgReadLimit) - r.c.exportedClose(StatusMessageTooBig, err.Error(), false) - return 0, err - } - - if int64(len(p)) > r.c.readMsgLeft { - p = p[:r.c.readMsgLeft] - } - - pr := io.Reader(r.c.payloadReader) - if r.c.readerPayloadCompressed { - pr = r.c.fr - } - - n, err := pr.Read(p) - - r.c.readMsgLeft -= int64(n) - - if r.c.readerFrameEOF && r.c.readerMsgHeader.fin { - if r.c.readerPayloadCompressed && r.c.readNoContextTakeOver() { - putFlateReader(r.c.fr) - r.c.fr = nil - } - r.c.activeReader = nil - if err == nil { - err = io.EOF - } - } - - return n, err -} - -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { - wrap := func(err error) error { - return fmt.Errorf("failed to read frame payload: %w", err) - } - defer func() { - if err != nil { - err = wrap(err) - } - }() - - err = c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- ctx: - } - - n, err := io.ReadFull(c.br, p) - if err != nil { - select { - case <-c.closed: - return n, c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.releaseLock(c.readFrameLock) - c.close(wrap(err)) - return n, err - } - - select { - case <-c.closed: - return n, c.closeErr - case c.setReadTimeout <- context.Background(): - } - - return n, err -} - -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method if you want to be able to reuse buffers or want to stream a message. -// The docs on Reader apply to this method as well. -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - typ, r, err := c.Reader(ctx) - if err != nil { - return 0, nil, err - } - - b, err := ioutil.ReadAll(r) - return typ, b, err -} - -// Writer returns a writer bounded by the context that will write -// a WebSocket message of type dataType to the connection. -// -// You must close the writer once you have written the entire message. -// -// Only one writer can be open at a time, multiple calls will block until the previous writer -// is closed. -func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - wc, err := c.writer(ctx, typ) - if err != nil { - return nil, fmt.Errorf("failed to get writer: %w", err) - } - return wc, nil -} - -func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return nil, err - } - c.writeMsgCtx = ctx - c.writeMsgOpcode = opcode(typ) - w := &messageWriter{ - c: c, - } - c.activeWriter.Store(w) - return w, nil -} - -// Write is a convenience method to write a message to the connection. -// -// See the Writer method if you want to stream a message. -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - _, err := c.write(ctx, typ, p) - if err != nil { - return fmt.Errorf("failed to write msg: %w", err) - } - return nil -} - -func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.writeMsgLock) - - n, err := c.writeFrame(ctx, true, opcode(typ), p) - return n, err -} - -// messageWriter enables writing to a WebSocket connection. -type messageWriter struct { - c *Conn -} - -func (w *messageWriter) closed() bool { - return w != w.c.activeWriter.Load() -} - -// Write writes the given bytes to the WebSocket connection. -func (w *messageWriter) Write(p []byte) (int, error) { - n, err := w.write(p) - if err != nil { - return n, fmt.Errorf("failed to write: %w", err) - } - return n, nil -} - -func (w *messageWriter) write(p []byte) (int, error) { - if w.closed() { - return 0, fmt.Errorf("cannot use closed writer") - } - n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p) - if err != nil { - return n, fmt.Errorf("failed to write data frame: %w", err) - } - w.c.writeMsgOpcode = opContinuation - return n, nil -} - -// Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (w *messageWriter) Close() error { - err := w.close() - if err != nil { - return fmt.Errorf("failed to close writer: %w", err) - } - return nil -} - -func (w *messageWriter) close() error { - if w.closed() { - return fmt.Errorf("cannot use closed writer") - } - w.c.activeWriter.Store((*messageWriter)(nil)) - - _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) - if err != nil { - return fmt.Errorf("failed to write fin frame: %w", err) - } - - w.c.releaseLock(w.c.writeMsgLock) - return nil -} - -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - _, err := c.writeFrame(ctx, true, opcode, p) - if err != nil { - return fmt.Errorf("failed to write control frame %v: %w", opcode, err) - } - return nil -} - -// writeFrame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - err := c.acquireLock(ctx, c.writeFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.writeFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setWriteTimeout <- ctx: - } - - c.writeHeader.fin = fin - c.writeHeader.opcode = opcode - c.writeHeader.masked = c.client - c.writeHeader.payloadLength = int64(len(p)) - - if c.client { - err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) - if err != nil { - return 0, fmt.Errorf("failed to generate masking key: %w", err) - } - } - - n, err := c.realWriteFrame(ctx, *c.writeHeader, p) - if err != nil { - return n, err - } - - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return n, c.closeErr - case c.setWriteTimeout <- context.Background(): - } - - return n, nil -} - -func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) { - defer func() { - if err != nil { - select { - case <-c.closed: - err = c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - - err = fmt.Errorf("failed to write %v frame: %w", h.opcode, err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) - } - }() - - headerBytes := writeHeader(c.writeHeaderBuf, h) - _, err = c.bw.Write(headerBytes) - if err != nil { - return 0, err - } - - if c.client { - maskKey := h.maskKey - for len(p) > 0 { - if c.bw.Available() == 0 { - err = c.bw.Flush() - if err != nil { - return n, err - } - } - - // Start of next write in the buffer. - i := c.bw.Buffered() - - p2 := p - if len(p) > c.bw.Available() { - p2 = p[:c.bw.Available()] - } - - n2, err := c.bw.Write(p2) - if err != nil { - return n, err - } - - maskKey = mask(maskKey, c.writeBuf[i:i+n2]) - - p = p[n2:] - n += n2 - } - } else { - n, err = c.bw.Write(p) - if err != nil { - return n, err - } - } - - if h.fin { - err = c.bw.Flush() - if err != nil { - return n, err - } - } - - return n, nil -} - -// Close closes the WebSocket connection with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// Thus, it implements the full WebSocket close handshake. -// All data messages received from the peer during the close handshake -// will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - err := c.exportedClose(code, reason, true) - var ec errClosing - if errors.As(err, &ec) { - <-c.closed - // We wait until the connection closes. - // We use writeClose and not exportedClose to avoid a second failed to marshal close frame error. - err = c.writeClose(nil, ec.ce, true) - } - if err != nil { - return fmt.Errorf("failed to close websocket connection: %w", err) - } - return nil -} - -func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) error { - ce := CloseError{ - Code: code, - Reason: reason, - } - - // This function also will not wait for a close frame from the peer like the RFC - // wants because that makes no sense and I don't think anyone actually follows that. - // Definitely worth seeing what popular browsers do later. - p, err := ce.bytes() - if err != nil { - c.logf("websocket: failed to marshal close frame: %+v", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytes() - } - - return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake) -} - -type errClosing struct { - ce error -} - -func (e errClosing) Error() string { - return "already closing connection" -} - -func (c *Conn) writeClose(p []byte, ce error, handshake bool) error { - if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) - } - - if !c.closing.CAS(0, 1) { - // Normally, we would want to wait until the connection is closed, - // at least for when a user calls into Close, so we handle that case in - // the exported Close function. - // - // But for internal library usage, we always want to return early, e.g. - // if we are performing a close handshake and the peer sends their close frame, - // we do not want to block here waiting for c.closed to close because it won't, - // at least not until we return since the gorouine that will close it is this one. - return errClosing{ - ce: ce, - } - } - - // No matter what happens next, close error should be set. - c.setCloseErr(ce) - defer c.close(nil) - - err := c.writeControl(context.Background(), opClose, p) - if err != nil { - return err - } - - if handshake { - err = c.waitClose() - if CloseStatus(err) == -1 { - // waitClose exited not due to receiving a close frame. - return fmt.Errorf("failed to wait for peer close frame: %w", err) - } - } - - return nil -} - -func (c *Conn) waitClose() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := c.acquireLock(ctx, c.readLock) - if err != nil { - return err - } - defer c.releaseLock(c.readLock) - - if c.closeReceived != nil { - // goroutine reading just received the close. - return c.closeReceived - } - - b := bufpool.Get() - buf := b.Bytes() - buf = buf[:cap(buf)] - defer bufpool.Put(b) - - for { - if c.activeReader == nil || c.readerFrameEOF { - _, _, err := c.reader(ctx, false) - if err != nil { - return fmt.Errorf("failed to get reader: %w", err) - } - } - - r := readerFunc(c.activeReader.readUnlocked) - _, err = io.CopyBuffer(ioutil.Discard, r, buf) - if err != nil { - return err - } - } +func (c *Conn) deflateNegotiated() bool { + return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. @@ -1056,9 +153,9 @@ func (c *Conn) waitClose() error { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := c.pingCounter.Increment(1) + p := atomic.AddInt32(&c.pingCounter, 1) - err := c.ping(ctx, strconv.FormatInt(p, 10)) + err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } @@ -1078,7 +175,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { c.activePingsMu.Unlock() }() - err := c.writeControl(ctx, opPing, []byte(p)) + err := c.cw.control(ctx, opPing, []byte(p)) if err != nil { return err } @@ -1095,109 +192,37 @@ func (c *Conn) ping(ctx context.Context, p string) error { } } -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} - -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - -// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer -// and stores it in c.writeBuf. -func (c *Conn) extractBufioWriterBuf(w io.Writer) { - c.bw.Reset(writerFunc(func(p2 []byte) (int, error) { - c.writeBuf = p2[:cap(p2)] - return len(p2), nil - })) - - c.bw.WriteByte(0) - c.bw.Flush() - - c.bw.Reset(w) -} - -var flateWriterPool = &sync.Pool{ - New: func() interface{} { - w, _ := flate.NewWriter(nil, flate.BestSpeed) - return w - }, -} - -func getFlateWriter(w io.Writer) *flate.Writer { - fw := flateWriterPool.Get().(*flate.Writer) - fw.Reset(w) - return fw -} - -func putFlateWriter(w *flate.Writer) { - flateWriterPool.Put(w) +type mu struct { + once sync.Once + ch chan struct{} } -var flateReaderPool = &sync.Pool{ - New: func() interface{} { - return flate.NewReader(nil) - }, -} - -func getFlateReader(r io.Reader) io.Reader { - fr := flateReaderPool.Get().(io.Reader) - fr.(flate.Resetter).Reset(r, nil) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -func (c *Conn) writeNoContextTakeOver() bool { - return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover -} - -func (c *Conn) readNoContextTakeOver() bool { - return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover -} - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte +func (m *mu) init() { + m.once.Do(func() { + m.ch = make(chan struct{}, 1) + }) } -func (w *trimLastFourBytesWriter) Write(p []byte) (int, error) { - extra := len(w.tail) + len(p) - 4 - - if extra <= 0 { - w.tail = append(w.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(w.tail) { - extra = len(w.tail) - } - if extra > 0 { - _, err := w.Write(w.tail[:extra]) - if err != nil { - return 0, err - } - w.tail = w.tail[extra:] +func (m *mu) Lock(ctx context.Context) error { + m.init() + select { + case <-ctx.Done(): + return ctx.Err() + case m.ch <- struct{}{}: + return nil } +} - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - w.tail = append(w.tail, p...) - return len(p), nil +func (m *mu) TryLock() bool { + m.init() + select { + case m.ch <- struct{}{}: + return true + default: + return false } +} - // Otherwise, only the last 4 bytes are. - w.tail = append(w.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := w.w.Write(p) - return n + 4, err +func (m *mu) Unlock() { + <-m.ch } diff --git a/conn_export_test.go b/conn_export_test.go deleted file mode 100644 index d5f5aa24..00000000 --- a/conn_export_test.go +++ /dev/null @@ -1,129 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "context" - "fmt" -) - -type ( - Addr = websocketAddr - OpCode int -) - -const ( - OpClose = OpCode(opClose) - OpBinary = OpCode(opBinary) - OpText = OpCode(opText) - OpPing = OpCode(opPing) - OpPong = OpCode(opPong) - OpContinuation = OpCode(opContinuation) -) - -func (c *Conn) SetLogf(fn func(format string, v ...interface{})) { - c.logf = fn -} - -func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) { - h, err := c.readFrameHeader(ctx) - if err != nil { - return 0, nil, err - } - b := make([]byte, h.payloadLength) - _, err = c.readFramePayload(ctx, b) - if err != nil { - return 0, nil, err - } - if h.masked { - mask(h.maskKey, b) - } - return OpCode(h.opcode), b, nil -} - -func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) (int, error) { - return c.writeFrame(ctx, fin, opcode(opc), p) -} - -// header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type Header struct { - Fin bool - Rsv1 bool - Rsv2 bool - Rsv3 bool - OpCode OpCode - - PayloadLength int64 -} - -func (c *Conn) WriteHeader(ctx context.Context, h Header) error { - headerBytes := writeHeader(c.writeHeaderBuf, header{ - fin: h.Fin, - rsv1: h.Rsv1, - rsv2: h.Rsv2, - rsv3: h.Rsv3, - opcode: opcode(h.OpCode), - payloadLength: h.PayloadLength, - masked: c.client, - }) - _, err := c.bw.Write(headerBytes) - if err != nil { - return fmt.Errorf("failed to write header: %w", err) - } - if h.Fin { - err = c.Flush() - if err != nil { - return err - } - } - return nil -} - -func (c *Conn) PingWithPayload(ctx context.Context, p string) error { - return c.ping(ctx, p) -} - -func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { - return c.realWriteFrame(ctx, header{ - fin: true, - opcode: opBinary, - payloadLength: 10, - }, make([]byte, 5)) -} - -func (c *Conn) CloseUnderlyingConn() { - c.closer.Close() -} - -func (c *Conn) Flush() error { - return c.bw.Flush() -} - -func (c CloseError) Bytes() ([]byte, error) { - return c.bytes() -} - -func (c *Conn) BW() *bufio.Writer { - return c.bw -} - -func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) ([]byte, error) { - b, err := CloseError{ - Code: code, - Reason: reason, - }.Bytes() - if err != nil { - return nil, err - } - _, err = c.WriteFrame(ctx, true, OpClose, b) - if err != nil { - return nil, err - } - return b, nil -} - -func ParseClosePayload(p []byte) (CloseError, error) { - return parseClosePayload(p) -} diff --git a/conn_test.go b/conn_test.go index d03a7214..992c8861 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,969 +3,28 @@ package websocket_test import ( - "bytes" "context" - "encoding/binary" - "encoding/json" - "errors" "fmt" "io" - "io/ioutil" - "math/rand" - "net" "net/http" - "net/http/cookiejar" "net/http/httptest" - "net/url" - "os" - "os/exec" - "reflect" - "strconv" + "nhooyr.io/websocket/internal/assert" "strings" + "sync/atomic" "testing" "time" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/timestamp" - "go.uber.org/multierr" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" - "nhooyr.io/websocket/internal/wsecho" - "nhooyr.io/websocket/internal/wsgrace" - "nhooyr.io/websocket/wsjson" - "nhooyr.io/websocket/wspb" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func TestHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - client func(ctx context.Context, url string) error - server func(w http.ResponseWriter, r *http.Request) error - }{ - { - name: "badOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err == nil { - c.Close(websocket.StatusInternalError, "") - return errors.New("expected error regarding bad origin") - } - return assertErrorContains(err, "not authorized") - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "http://unauthorized.com") - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err == nil { - c.Close(websocket.StatusInternalError, "") - return errors.New("expected handshake failure") - } - return assertErrorContains(err, "403") - }, - }, - { - name: "acceptSecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", u) - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "acceptInsecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "https://example.com") - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "cookies", - server: func(w http.ResponseWriter, r *http.Request) error { - cookie, err := r.Cookie("mycookie") - if err != nil { - return fmt.Errorf("request is missing mycookie: %w", err) - } - err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value") - if err != nil { - return err - } - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - jar, err := cookiejar.New(nil) - if err != nil { - return fmt.Errorf("failed to create cookie jar: %w", err) - } - parsedURL, err := url.Parse(u) - if err != nil { - return fmt.Errorf("failed to parse url: %w", err) - } - parsedURL.Scheme = "http" - jar.SetCookies(parsedURL, []*http.Cookie{ - { - Name: "mycookie", - Value: "myvalue", - }, - }) - hc := &http.Client{ - Jar: jar, - } - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPClient: hc, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, tc.server, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := tc.client(ctx, wsURL) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - }) - } -} - -func TestConn(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - - acceptOpts *websocket.AcceptOptions - server func(ctx context.Context, c *websocket.Conn) error - - dialOpts *websocket.DialOptions - response func(resp *http.Response) error - client func(ctx context.Context, c *websocket.Conn) error - }{ - { - name: "handshake", - acceptOpts: &websocket.AcceptOptions{ - Subprotocols: []string{"myproto"}, - }, - dialOpts: &websocket.DialOptions{ - Subprotocols: []string{"myproto"}, - }, - response: func(resp *http.Response) error { - headers := map[string]string{ - "Connection": "Upgrade", - "Upgrade": "websocket", - "Sec-WebSocket-Protocol": "myproto", - } - for h, exp := range headers { - value := resp.Header.Get(h) - err := assert.Equalf(exp, value, "unexpected value for header %v", h) - if err != nil { - return err - } - } - return nil - }, - }, - { - name: "handshake/defaultSubprotocol", - server: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "") - }, - }, - { - name: "handshake/subprotocolPriority", - acceptOpts: &websocket.AcceptOptions{ - Subprotocols: []string{"echo", "lar"}, - }, - server: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "echo") - }, - dialOpts: &websocket.DialOptions{ - Subprotocols: []string{"poof", "echo"}, - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "echo") - }, - }, - { - name: "closeError", - server: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "hello") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := assertJSONRead(ctx, c, "hello") - if err != nil { - return err - } - - _, _, err = c.Reader(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "netConn", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - defer nc.Close() - - nc.SetWriteDeadline(time.Time{}) - time.Sleep(1) - nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr") - if err != nil { - return err - } - err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr") - if err != nil { - return err - } - - for i := 0; i < 3; i++ { - _, err := nc.Write([]byte("hello")) - if err != nil { - return err - } - } - - return nil - }, - client: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - - nc.SetReadDeadline(time.Time{}) - time.Sleep(1) - nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - - for i := 0; i < 3; i++ { - err := assertNetConnRead(nc, "hello") - if err != nil { - return err - } - } - - // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err2 := assertNetConnRead(nc, "hello") - err := assert.Equalf(io.EOF, err2, "unexpected error") - if err != nil { - return err - } - - err2 = assertNetConnRead(nc, "hello") - return assert.Equalf(io.EOF, err2, "unexpected error") - }, - }, - { - name: "netConn/badReadMsgType", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - - nc.SetDeadline(time.Now().Add(time.Second * 15)) - - _, err := nc.Read(make([]byte, 1)) - return assertErrorContains(err, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, "meow") - if err != nil { - return err - } - - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusUnsupportedData) - }, - }, - { - name: "netConn/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - defer nc.Close() - - nc.SetDeadline(time.Now().Add(time.Second * 15)) - - _, err2 := nc.Read(make([]byte, 1)) - err := assertCloseStatus(err2, websocket.StatusBadGateway) - if err != nil { - return err - } - - _, err2 = nc.Write([]byte{0xff}) - return assertErrorContains(err2, "websocket closed") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusBadGateway, "") - }, - }, - { - name: "wsjson/echo", - server: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "meow") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertJSONRead(ctx, c, "meow") - }, - }, - { - name: "protobuf/echo", - server: func(ctx context.Context, c *websocket.Conn) error { - return wspb.Write(ctx, c, ptypes.DurationProto(100)) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertProtobufRead(ctx, c, ptypes.DurationProto(100)) - }, - }, - { - name: "ping", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - err := c.Ping(ctx) - if err != nil { - return err - } - - err = wsjson.Write(ctx, c, "hi") - if err != nil { - return err - } - - <-ctx.Done() - err = c.Ping(context.Background()) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - // We read a message from the connection and then keep reading until - // the Ping completes. - pingErrc := make(chan error, 1) - go func() { - pingErrc <- c.Ping(ctx) - }() - - // Once this completes successfully, that means they sent their ping and we responded to it. - err := assertJSONRead(ctx, c, "hi") - if err != nil { - return err - } - - // Now we need to ensure we're reading for their pong from our ping. - // Need new var to not race with above goroutine. - ctx2 := c.CloseRead(ctx) - - // Now we wait for our pong. - select { - case err = <-pingErrc: - return err - case <-ctx2.Done(): - return fmt.Errorf("failed to wait for pong: %w", ctx2.Err()) - } - }, - }, - { - name: "readLimit", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err2 := c.Read(ctx) - return assertErrorContains(err2, "read limited at 32768 bytes") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) - if err != nil { - return err - } - - _, _, err2 := c.Read(ctx) - return assertCloseStatus(err2, websocket.StatusMessageTooBig) - }, - }, - { - name: "wsjson/binary", - server: func(ctx context.Context, c *websocket.Conn) error { - var v interface{} - err2 := wsjson.Read(ctx, c, &v) - return assertErrorContains(err2, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return wspb.Write(ctx, c, ptypes.DurationProto(100)) - }, - }, - { - name: "wsjson/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - var v interface{} - err2 := wsjson.Read(ctx, c, &v) - return assertErrorContains(err2, "failed to unmarshal json") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageText, []byte("notjson")) - }, - }, - { - name: "wsjson/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err2 := c.Read(ctx) - return assertCloseStatus(err2, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, fmt.Println) - return assertErrorContains(err, "failed to encode json") - }, - }, - { - name: "wspb/text", - server: func(ctx context.Context, c *websocket.Conn) error { - var v proto.Message - err := wspb.Read(ctx, c, v) - return assertErrorContains(err, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "hi") - }, - }, - { - name: "wspb/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - var v timestamp.Timestamp - err := wspb.Read(ctx, c, &v) - return assertErrorContains(err, "failed to unmarshal protobuf") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) - }, - }, - { - name: "wspb/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wspb.Write(ctx, c, nil) - return assertErrorIs(proto.ErrNil, err) - }, - }, - { - name: "badClose", - server: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(9999, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "pingTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - err := c.Ping(ctx) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - err1 := assertErrorContains(err, "connection reset") - err2 := assertErrorIs(io.EOF, err) - if err1 != nil || err2 != nil { - return nil - } - return multierr.Combine(err1, err2) - }, - }, - { - name: "writeTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - c.Writer(ctx, websocket.MessageBinary) - - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorIs(io.EOF, err) - }, - }, - { - name: "readTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - _, _, err := c.Read(ctx) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorIs(websocket.CloseError{ - Code: websocket.StatusPolicyViolation, - Reason: "read timed out", - }, err) - }, - }, - { - name: "badOpCode", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, 13, []byte("meow")) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertErrorContains(err, "unknown opcode") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "unknown opcode") - }, - }, - { - name: "noRsv", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, 99, []byte("meow")) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "rsv") { - return fmt.Errorf("expected error that contains rsv: %+v", err) - } - return nil - }, - }, - { - name: "largeControlFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - err := c.WriteHeader(ctx, websocket.Header{ - Fin: true, - OpCode: websocket.OpClose, - PayloadLength: 4096, - }) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "too big") - }, - }, - { - name: "fragmentedControlFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OpPing, []byte(strings.Repeat("x", 32))) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "fragmented") - }, - }, - { - name: "invalidClosePayload", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{0x17, 0x70}) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "invalid status code") - }, - }, - { - name: "doubleReader", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "previous message not read to completion") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11))) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "doubleFragmentedReader", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "previous message not read to completion") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "newMessageInFragmentedMessage", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "received new data message without finishing") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - _, _, err = c.Read(ctx) - return assertErrorContains(err, "received new data message without finishing") - }, - }, - { - name: "continuationFrameWithoutDataFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Reader(ctx) - return assertErrorContains(err, "received continuation frame not after data") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OpContinuation, []byte(strings.Repeat("x", 10))) - return err - }, - }, - { - name: "readBeforeEOF", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - var v interface{} - d := json.NewDecoder(r) - err = d.Decode(&v) - if err != nil { - return err - } - err = assert.Equalf("hi", v, "unexpected JSON") - if err != nil { - return err - } - _, b, err := c.Read(ctx) - if err != nil { - return err - } - return assert.Equalf("hi", string(b), "unexpected JSON") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, "hi") - if err != nil { - return err - } - return c.Write(ctx, websocket.MessageText, []byte("hi")) - }, - }, - { - name: "newMessageInFragmentedMessage2", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 11) - _, err = io.ReadFull(r, p) - return assertErrorContains(err, "received new data message without finishing") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - }, - { - name: "doubleRead", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - _, err = ioutil.ReadAll(r) - if err != nil { - return err - } - _, err = r.Read(make([]byte, 1)) - return assertErrorContains(err, "cannot use EOFed reader") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageBinary, []byte("hi")) - }, - }, - { - name: "eofInPayload", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "failed to read frame payload") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteHalfFrame(ctx) - if err != nil { - return err - } - c.CloseUnderlyingConn() - return nil - }, - }, - { - name: "closeHandshake", - server: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }, - }, - { - // Issue #164 - name: "closeHandshake_concurrentRead", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - errc := make(chan error, 1) - go func() { - _, _, err := c.Read(ctx) - errc <- err - }() - - err := c.Close(websocket.StatusNormalClosure, "") - if err != nil { - return err - } - - err = <-errc - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Run random tests over TLS. - tls := rand.Intn(2) == 1 - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, tc.acceptOpts) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - c.SetLogf(t.Logf) - if tc.server == nil { - return nil - } - return tc.server(r.Context(), c) - }, tls) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - opts := tc.dialOpts - if tls { - if opts == nil { - opts = &websocket.DialOptions{} - } - opts.HTTPClient = s.Client() - } - - c, resp, err := websocket.Dial(ctx, wsURL, opts) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - c.SetLogf(t.Logf) - - if tc.response != nil { - err = tc.response(resp) - if err != nil { - t.Fatalf("response asserter failed: %+v", err) - } - } - - if tc.client != nil { - err = tc.client(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - } - - c.Close(websocket.StatusNormalClosure, "") - }) - } -} - -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err := fn(w, r) - if err != nil { - tb.Errorf("server failed: %+v", err) - } - }) +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) { + h := http.HandlerFunc(fn) if tls { s = httptest.NewTLSServer(h) } else { s = httptest.NewServer(h) } - closeFn2 := wsgrace.Grace(s.Config) + closeFn2 := wsgrace(s.Config) return s, func() { err := closeFn2() if err != nil { @@ -974,1417 +33,112 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e } } -func TestAutobahn(t *testing.T) { - t.Parallel() - - run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) { - run2 := func(t *testing.T, testingClient bool) { - // Run random tests over TLS. - tls := rand.Intn(2) == 1 - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx := r.Context() - if testingClient { - err = wsecho.Loop(ctx, c) - if err != nil { - t.Logf("failed to wsecho: %+v", err) - } - return nil - } - - c.SetReadLimit(1 << 30) - err = fn(ctx, c) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, tls) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - } - if tls { - opts.HTTPClient = s.Client() - } - - c, _, err := websocket.Dial(ctx, wsURL, opts) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - if testingClient { - c.SetReadLimit(1 << 30) - err = fn(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - c.Close(websocket.StatusNormalClosure, "") - return - } - - err = wsecho.Loop(ctx, c) - if err != nil { - t.Logf("failed to wsecho: %+v", err) - } - } - t.Run(name, func(t *testing.T) { - t.Parallel() +// grace wraps s.Handler to gracefully shutdown WebSocket connections. +// The returned function must be used to close the server instead of s.Close. +func wsgrace(s *http.Server) (closeFn func() error) { + h := s.Handler + var conns int64 + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&conns, 1) + defer atomic.AddInt64(&conns, -1) - run2(t, true) - }) - } - - // Section 1. - t.Run("echo", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 0, - 125, - 126, - 127, - 128, - 65535, - 65536, - 65536, - } - run := func(typ websocket.MessageType) { - for i, l := range lengths { - l := l - run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error { - p := randBytes(l) - if i == len(lengths)-1 { - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - for i := 0; i < l; { - j := i + 997 - if j > l { - j = l - } - _, err = w.Write(p[i:j]) - if err != nil { - return err - } + ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) + defer cancel() - i = j - } + r = r.WithContext(ctx) - err = w.Close() - if err != nil { - return err - } - } else { - err := c.Write(ctx, typ, p) - if err != nil { - return err - } - } - actTyp, p2, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(typ, actTyp, "unexpected message type") - if err != nil { - return err - } - return assert.Equalf(p, p2, "unexpected message") - }) - } - } - - run(websocket.MessageText) - run(websocket.MessageBinary) + h.ServeHTTP(w, r) }) - // Section 2. - t.Run("pingPong", func(t *testing.T) { - t.Parallel() - - run(t, "emptyPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - return c.PingWithPayload(ctx, "") - }) - run(t, "smallTextPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - return c.PingWithPayload(ctx, "hi") - }) - run(t, "smallBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 16) - return c.PingWithPayload(ctx, string(p)) - }) - run(t, "largeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 125) - return c.PingWithPayload(ctx, string(p)) - }) - run(t, "tooLargeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 126) - err := c.PingWithPayload(ctx, string(p)) - return assertCloseStatus(err, websocket.StatusProtocolError) - }) - run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error { - err := assertStreamPing(ctx, c, 125) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - t.Run("unsolicitedPong", func(t *testing.T) { - t.Parallel() - - var testCases = []struct { - name string - pongPayload string - ping bool - }{ - { - name: "noPayload", - pongPayload: "", - }, - { - name: "payload", - pongPayload: "hi", - }, - { - name: "pongThenPing", - pongPayload: "hi", - ping: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpPong, []byte(tc.pongPayload)) - if err != nil { - return err - } - if tc.ping { - _, err := c.WriteFrame(ctx, true, websocket.OpPing, []byte("meow")) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, []byte("meow")) - if err != nil { - return err - } - } - return c.Close(websocket.StatusNormalClosure, "") - }) - } - }) - run(t, "tenPings", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - for i := 0; i < 10; i++ { - err := c.Ping(ctx) - if err != nil { - return err - } - } + return func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() - _, err := c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - <-ctx.Done() - - err = c.Ping(context.Background()) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }) - - run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error { - for i := 0; i < 10; i++ { - err := assertStreamPing(ctx, c, 125) - if err != nil { - return err - } - } - - return c.Close(websocket.StatusNormalClosure, "") - }) - }) - - // Section 3. - // We skip the per octet sending as it will add too much complexity. - t.Run("reserved", func(t *testing.T) { - t.Parallel() - - var testCases = []struct { - name string - header websocket.Header - }{ - { - name: "rsv1", - header: websocket.Header{ - Fin: true, - Rsv1: true, - OpCode: websocket.OpClose, - PayloadLength: 0, - }, - }, - { - name: "rsv2", - header: websocket.Header{ - Fin: true, - Rsv2: true, - OpCode: websocket.OpPong, - PayloadLength: 0, - }, - }, - { - name: "rsv3", - header: websocket.Header{ - Fin: true, - Rsv3: true, - OpCode: websocket.OpBinary, - PayloadLength: 0, - }, - }, - { - name: "rsvAll", - header: websocket.Header{ - Fin: true, - Rsv1: true, - Rsv2: true, - Rsv3: true, - OpCode: websocket.OpText, - PayloadLength: 0, - }, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - err := assertEcho(ctx, c, websocket.MessageText, 4096) - if err != nil { - return err - } - err = c.WriteHeader(ctx, tc.header) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - - // Section 4. - t.Run("opcodes", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - opcode websocket.OpCode - payload bool - echo bool - ping bool - }{ - // Section 1. - { - name: "3", - opcode: 3, - }, - { - name: "4", - opcode: 4, - payload: true, - }, - { - name: "5", - opcode: 5, - echo: true, - ping: true, - }, - { - name: "6", - opcode: 6, - payload: true, - echo: true, - ping: true, - }, - { - name: "7", - opcode: 7, - payload: true, - echo: true, - ping: true, - }, - - // Section 2. - { - name: "11", - opcode: 11, - }, - { - name: "12", - opcode: 12, - payload: true, - }, - { - name: "13", - opcode: 13, - payload: true, - echo: true, - ping: true, - }, - { - name: "14", - opcode: 14, - payload: true, - echo: true, - ping: true, - }, - { - name: "15", - opcode: 15, - payload: true, - echo: true, - ping: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - if tc.echo { - err := assertEcho(ctx, c, websocket.MessageText, 4096) - if err != nil { - return err - } - } - - p := []byte(nil) - if tc.payload { - p = randBytes(rand.Intn(4096) + 1) - } - _, err := c.WriteFrame(ctx, true, tc.opcode, p) - if err != nil { - return err - } - if tc.ping { - _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) - if err != nil { - return err - } - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - - // Section 5. - t.Run("fragmentation", func(t *testing.T) { - t.Parallel() - - // 5.1 to 5.8 - testCases := []struct { - name string - opcode websocket.OpCode - success bool - pingInBetween bool - }{ - { - name: "ping", - opcode: websocket.OpPing, - success: false, - }, - { - name: "pong", - opcode: websocket.OpPong, - success: false, - }, - { - name: "text", - opcode: websocket.OpText, - success: true, - }, - { - name: "textPing", - opcode: websocket.OpText, - success: true, - pingInBetween: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(16) - _, err := c.WriteFrame(ctx, false, tc.opcode, p1) - if err != nil { - return err - } - err = c.BW().Flush() - if err != nil { - return err - } - if !tc.success { - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - } - - if tc.pingInBetween { - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - } - - p2 := randBytes(16) - _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2) - if err != nil { - return err - } - - err = assertReadFrame(ctx, c, tc.opcode, p1) - if err != nil { - return err - } - - if tc.pingInBetween { - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - } - - return assertReadFrame(ctx, c, websocket.OpContinuation, p2) - }) + err := s.Shutdown(ctx) + if err != nil { + return fmt.Errorf("server shutdown failed: %v", err) } - t.Run("unexpectedContinuation", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - fin bool - textFirst bool - }{ - { - name: "fin", - fin: true, - }, - { - name: "noFin", - fin: false, - }, - { - name: "echoFirst", - fin: false, - textFirst: true, - }, - // The rest of the tests in this section get complicated and do not inspire much confidence. - } - - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - if tc.textFirst { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - p1 := randBytes(32) - _, err = w.Write(p1) - if err != nil { - return err - } - p2 := randBytes(32) - _, err = w.Write(p2) - if err != nil { - return err - } - err = w.Close() - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) - if err != nil { - return err - } - } - - _, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32)) - if err != nil { - return err - } - err = c.BW().Flush() - if err != nil { - return err - } - - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - - run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32)) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - - run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - p2 := randBytes(32) - p3 := randBytes(32) - p4 := randBytes(32) - p5 := randBytes(32) - - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - - time.Sleep(time.Second) - - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5) - if err != nil { - return err - } - - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p3) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p4) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p5) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - }) - }) - - // Section 7 - t.Run("closeHandling", func(t *testing.T) { - t.Parallel() - - // 1.1 - 1.4 is useless. - run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) - }) - - run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error { - // 262144 bytes. - p1 := randBytes(1 << 18) - err := c.Write(ctx, websocket.MessageText, p1) - if err != nil { - return err - } - _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - err = assertReadMessage(ctx, c, websocket.MessageText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) - }) - - run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, nil) - if err != nil { - return err - } - return assertReadFrame(ctx, c, websocket.OpClose, []byte{}) - }) - - run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1}) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - - run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }) - - run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, randString(16)) - }) - - run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, randString(123)) - }) - - run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, - append([]byte{0x03, 0xE8}, randString(124)...), - ) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - - t.Run("validCloses", func(t *testing.T) { - t.Parallel() - - codes := [...]websocket.StatusCode{ - 1000, - 1001, - 1002, - 1003, - 1007, - 1008, - 1009, - 1010, - 1011, - 3000, - 3999, - 4000, - 4999, - } - for _, code := range codes { - run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { - return c.Close(code, randString(32)) - }) - } - }) - - t.Run("invalidCloseCodes", func(t *testing.T) { - t.Parallel() - - codes := []websocket.StatusCode{ - 0, - 999, - 1004, - 1005, - 1006, - 1016, - 1100, - 2000, - 2999, - 5000, - 65535, - } - for _, code := range codes { - run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { - p := make([]byte, 2) - binary.BigEndian.PutUint16(p, uint16(code)) - p = append(p, randBytes(32)...) - _, err := c.WriteFrame(ctx, true, websocket.OpClose, p) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - }) - - // Section 9. - t.Run("limits", func(t *testing.T) { - t.Parallel() - - t.Run("unfragmentedEcho", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 1 << 16, - 1 << 18, - // Anything higher is completely unnecessary. - } - - for _, l := range lengths { - l := l - run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { - return assertEcho(ctx, c, websocket.MessageBinary, l) - }) - } - }) - - t.Run("fragmentedEcho", func(t *testing.T) { - t.Parallel() - - fragments := []int{ - 64, - 256, - 1 << 10, - 1 << 12, - 1 << 14, - 1 << 16, - } - - for _, l := range fragments { - fragmentLength := l - run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - b := randBytes(1 << 16) - for i := 0; i < len(b); { - j := i + fragmentLength - if j > len(b) { - j = len(b) - } - - _, err = w.Write(b[i:j]) - if err != nil { - return err - } - - i = j - } - err = w.Close() - if err != nil { - return err - } - - err = assertReadMessage(ctx, c, websocket.MessageText, b) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - } - }) - - t.Run("latencyEcho", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 0, - 16, - } - - for _, l := range lengths { - l := l - run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { - for i := 0; i < 1000; i++ { - err := assertEcho(ctx, c, websocket.MessageBinary, l) - if err != nil { - return err - } - } + t := time.NewTicker(time.Millisecond * 10) + defer t.Stop() + for { + select { + case <-t.C: + if atomic.LoadInt64(&conns) == 0 { return nil - }) - } - }) - }) -} - -func assertCloseStatus(err error, code websocket.StatusCode) error { - var cerr websocket.CloseError - if !errors.As(err, &cerr) { - return fmt.Errorf("no websocket close error in error chain: %+v", err) - } - return assert.Equalf(code, cerr.Code, "unexpected status code") -} - -func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { - expType := reflect.TypeOf(exp) - actv := reflect.New(expType.Elem()) - act := actv.Interface().(proto.Message) - err := wspb.Read(ctx, c, act) - if err != nil { - return err - } - - return assert.Equalf(exp, act, "unexpected protobuf") -} - -func assertNetConnRead(r io.Reader, exp string) error { - act := make([]byte, len(exp)) - _, err := r.Read(act) - if err != nil { - return err - } - return assert.Equalf(exp, string(act), "unexpected net conn read") -} - -func assertErrorContains(err error, exp string) error { - if err == nil || !strings.Contains(err.Error(), exp) { - return fmt.Errorf("expected error that contains %q but got: %+v", exp, err) - } - return nil -} - -func assertErrorIs(exp, act error) error { - if !errors.Is(act, exp) { - return fmt.Errorf("expected error %+v to be in %+v", exp, act) - } - return nil -} - -func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.OpCode, p []byte) error { - actOpcode, actP, err := c.ReadFrame(ctx) - if err != nil { - return err - } - err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - return assert.Equalf(p, actP, "unexpected frame %v payload", opcode) -} - -func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error { - actOpcode, actP, err := c.ReadFrame(ctx) - if err != nil { - return err - } - err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - ce, err := websocket.ParseClosePayload(actP) - if err != nil { - return fmt.Errorf("failed to parse close frame payload: %w", err) - } - return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) -} - -func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error { - err := c.WriteHeader(ctx, websocket.Header{ - Fin: true, - OpCode: websocket.OpPing, - PayloadLength: int64(l), - }) - if err != nil { - return err - } - for i := 0; i < l; i++ { - err = c.BW().WriteByte(0xFE) - if err != nil { - return fmt.Errorf("failed to write byte %d: %w", i, err) - } - if i%32 == 0 { - err = c.BW().Flush() - if err != nil { - return fmt.Errorf("failed to flush at byte %d: %w", i, err) + } + case <-ctx.Done(): + return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) } } } - err = c.BW().Flush() - if err != nil { - return fmt.Errorf("failed to flush: %v", err) - } - return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l)) -} - -func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error { - actTyp, actP, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp) -} - -func BenchmarkConn(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } - - b.Run("write", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("stream", func(b *testing.B) { - benchConn(b, false, true, size) - }) - b.Run("buffer", func(b *testing.B) { - benchConn(b, false, false, size) - }) - }) - } - }) - - b.Run("echo", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) - }) - } - }) } -func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - if echo { - wsecho.Loop(r.Context(), c) - } else { - discardLoop(r.Context(), c) - } - return nil - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - - c, _, err := websocket.Dial(ctx, wsURL, nil) - if err != nil { - b.Fatal(err) - } +// echoLoop echos every msg received from c until an error +// occurs or the context expires. +// The read limit is set to 1 << 30. +func echoLoop(ctx context.Context, c *websocket.Conn) error { defer c.Close(websocket.StatusInternalError, "") - msg := []byte(strings.Repeat("2", size)) - readBuf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } - - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } - - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) - } - - _, err = io.ReadFull(r, readBuf) - if err != nil { - b.Fatal(err) - } - } - } - b.StopTimer() - - c.Close(websocket.StatusNormalClosure, "") -} - -func discardLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") + c.SetReadLimit(1 << 30) ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - b := make([]byte, 32768) - echo := func() error { - _, r, err := c.Reader(ctx) + b := make([]byte, 32<<10) + for { + typ, r, err := c.Reader(ctx) if err != nil { return err } - _, err = io.CopyBuffer(ioutil.Discard, r, b) + w, err := c.Writer(ctx, typ) if err != nil { return err } - return nil - } - for { - err := echo() + _, err = io.CopyBuffer(w, r, b) if err != nil { - return + return err } - } -} - -func TestAutobahnPython(t *testing.T) { - // This test contains the old autobahn test suite tests that use the - // python binary. The approach is clunky and slow so new tests - // have been written in pure Go in websocket_test.go. - // These have been kept for correctness purposes and are occasionally ran. - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to run tests against the python autobahn test suite") - } - - t.Run("server", testServerAutobahnPython) - t.Run("client", testClientAutobahnPython) -} - -// https://github.com/crossbario/autobahn-python/tree/master/wstest -func testServerAutobahnPython(t *testing.T) { - t.Parallel() - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) + err = w.Close() if err != nil { - t.Logf("server handshake failed: %+v", err) - return + return err } - wsecho.Loop(r.Context(), c) - })) - defer s.Close() - - spec := map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - // 12.* and 13.* as we do not support compression. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} - wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() - if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) } - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err - } - l.Close() - return l.Addr().String(), nil -} - -// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py -func testClientAutobahnPython(t *testing.T) { +func TestConn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") - } - - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) - } - - wsServerURL := "ws://" + serverAddr - - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), - // Disables some server that runs as part of fuzzingserver mode. - // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 - "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) - err = wstest.Start() - if err != nil { - t.Fatal(err) - } - defer func() { - err := wstest.Process.Kill() - if err != nil { - t.Error(err) - } - }() - - // Let it come up. - time.Sleep(time.Second * 5) - - var cases int - func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) - } - cases, err = strconv.Atoi(string(b)) - if err != nil { - t.Fatal(err) - } - - c.Close(websocket.StatusNormalClosure, "") - }() - - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) - defer cancel() - - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } - wsecho.Loop(ctx, c) - }() - } - - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) - if err != nil { - t.Fatal(err) - } - c.Close(websocket.StatusNormalClosure, "") + t.Run("json", func(t *testing.T) { + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") -} + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() -func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } + wsURL := strings.Replace(s.URL, "http", "ws", 1) - var indexJSON map[string]map[string]struct { - Behavior string `json:"behavior"` - BehaviorClose string `json:"behaviorClose"` - } - err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() - var failed bool - for _, tests := range indexJSON { - for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) - } - } - } - - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) - } - } -} - -func TestWASM(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - err = wsecho.Loop(r.Context(), c) - if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - return err + opts := &websocket.DialOptions{ + Subprotocols: []string{"echo"}, } - return nil - }, false) - defer closeFn() + opts.HTTPClient = s.Client() - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() + c, _, err := websocket.Dial(ctx, wsURL, opts) + assert.Success(t, err) - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) - - b, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("wasm test binary failed: %v:\n%s", err, b) - } + assertJSONEcho(t, ctx, c, 2) + }) } diff --git a/dial.go b/dial.go index 10088681..8fa0f7ab 100644 --- a/dial.go +++ b/dial.go @@ -1,17 +1,19 @@ package websocket import ( + "bufio" "bytes" "context" "crypto/rand" "encoding/base64" + "errors" "fmt" "io" "io/ioutil" "net/http" "net/url" - "nhooyr.io/websocket/internal/bufpool" "strings" + "sync" ) // DialOptions represents the options available to pass to Dial. @@ -50,7 +52,7 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon return c, r, nil } -func (opts *DialOptions) fill() (*DialOptions, error) { +func (opts *DialOptions) ensure() *DialOptions { if opts == nil { opts = &DialOptions{} } else { @@ -60,20 +62,18 @@ func (opts *DialOptions) fill() (*DialOptions, error) { if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } - if opts.HTTPClient.Timeout > 0 { - return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } - return opts, nil + return opts } func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - opts, err = opts.fill() - if err != nil { - return nil, nil, err + opts = opts.ensure() + + if opts.HTTPClient.Timeout > 0 { + return nil, nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } parsedURL, err := url.Parse(u) @@ -104,8 +104,10 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - copts := opts.CompressionMode.opts() - copts.setHeader(req.Header) + if opts.CompressionMode != CompressionDisabled { + copts := opts.CompressionMode.opts() + copts.setHeader(req.Header) + } resp, err := opts.HTTPClient.Do(req) if err != nil { @@ -121,7 +123,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re } }() - copts, err = verifyServerResponse(req, resp, opts) + copts, err := verifyServerResponse(req, resp) if err != nil { return nil, resp, err } @@ -131,18 +133,14 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) } - c := &Conn{ + return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: bufpool.GetReader(rwc), - bw: bufpool.GetWriter(rwc), - closer: rwc, + rwc: rwc, client: true, copts: copts, - } - c.extractBufioWriterBuf(rwc) - c.init() - - return c, resp, nil + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), + }), resp, nil } func secWebSocketKey() (string, error) { @@ -154,7 +152,7 @@ func secWebSocketKey() (string, error) { return base64.StdEncoding.EncodeToString(b), nil } -func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*compressionOptions, error) { +func verifyServerResponse(r *http.Request, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } @@ -178,7 +176,7 @@ func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOption return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } - copts, err := verifyServerExtensions(resp.Header, opts.CompressionMode) + copts, err := verifyServerExtensions(resp.Header) if err != nil { return nil, err } @@ -186,7 +184,7 @@ func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOption return copts, nil } -func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOptions, error) { +func verifyServerExtensions(h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil @@ -201,7 +199,7 @@ func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOp return nil, fmt.Errorf("unexpected extra extensions from server: %+v", exts[1:]) } - copts := mode.opts() + copts := &compressionOptions{} for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -217,3 +215,33 @@ func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOp return copts, nil } + +var readerPool sync.Pool + +func getBufioReader(r io.Reader) *bufio.Reader { + br, ok := readerPool.Get().(*bufio.Reader) + if !ok { + return bufio.NewReader(r) + } + br.Reset(r) + return br +} + +func putBufioReader(br *bufio.Reader) { + readerPool.Put(br) +} + +var writerPool sync.Pool + +func getBufioWriter(w io.Writer) *bufio.Writer { + bw, ok := writerPool.Get().(*bufio.Writer) + if !ok { + return bufio.NewWriter(w) + } + bw.Reset(w) + return bw +} + +func putBufioWriter(bw *bufio.Writer) { + writerPool.Put(bw) +} diff --git a/dial_test.go b/dial_test.go index 391aa1ce..5eeb904a 100644 --- a/dial_test.go +++ b/dial_test.go @@ -140,7 +140,7 @@ func Test_verifyServerHandshake(t *testing.T) { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - _, err = verifyServerResponse(r, resp, &DialOptions{}) + _, err = verifyServerResponse(r, resp) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) } diff --git a/example_echo_test.go b/example_echo_test.go index ecc9b97c..16d003d9 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -4,6 +4,7 @@ package websocket_test import ( "context" + "errors" "fmt" "io" "log" @@ -77,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") - return fmt.Errorf("client does not speak echo sub protocol") + return errors.New("client does not speak echo sub protocol") } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) diff --git a/internal/wsframe/mask.go b/frame.go similarity index 57% rename from internal/wsframe/mask.go rename to frame.go index 2da4c11d..0f10d553 100644 --- a/internal/wsframe/mask.go +++ b/frame.go @@ -1,11 +1,167 @@ -package wsframe +package websocket import ( + "bufio" "encoding/binary" + "math" "math/bits" + "nhooyr.io/websocket/internal/errd" ) -// Mask applies the WebSocket masking algorithm to p +// opcode represents a WebSocket opcode. +type opcode int + +// List at https://tools.ietf.org/html/rfc6455#section-11.8. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +type header struct { + fin bool + rsv1 bool + rsv2 bool + rsv3 bool + opcode opcode + + payloadLength int64 + + masked bool + maskKey uint32 +} + +// readFrameHeader reads a header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +func readFrameHeader(r *bufio.Reader) (_ header, err error) { + defer errd.Wrap(&err, "failed to read frame header") + + b, err := r.ReadByte() + if err != nil { + return header{}, err + } + + var h header + h.fin = b&(1<<7) != 0 + h.rsv1 = b&(1<<6) != 0 + h.rsv2 = b&(1<<5) != 0 + h.rsv3 = b&(1<<4) != 0 + + h.opcode = opcode(b & 0xf) + + b, err = r.ReadByte() + if err != nil { + return header{}, err + } + + h.masked = b&(1<<7) != 0 + + payloadLength := b &^ (1 << 7) + switch { + case payloadLength < 126: + h.payloadLength = int64(payloadLength) + case payloadLength == 126: + var pl uint16 + err = binary.Read(r, binary.BigEndian, &pl) + h.payloadLength = int64(pl) + case payloadLength == 127: + err = binary.Read(r, binary.BigEndian, &h.payloadLength) + } + if err != nil { + return header{}, err + } + + if h.masked { + err = binary.Read(r, binary.LittleEndian, &h.maskKey) + if err != nil { + return header{}, err + } + } + + return h, nil +} + +// maxControlPayload is the maximum length of a control frame payload. +// See https://tools.ietf.org/html/rfc6455#section-5.5. +const maxControlPayload = 125 + +// writeFrameHeader writes the bytes of the header to w. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func writeFrameHeader(h header, w *bufio.Writer) (err error) { + defer errd.Wrap(&err, "failed to write frame header") + + var b byte + if h.fin { + b |= 1 << 7 + } + if h.rsv1 { + b |= 1 << 6 + } + if h.rsv2 { + b |= 1 << 5 + } + if h.rsv3 { + b |= 1 << 4 + } + + b |= byte(h.opcode) + + err = w.WriteByte(b) + if err != nil { + return err + } + + lengthByte := byte(0) + if h.masked { + lengthByte |= 1 << 7 + } + + switch { + case h.payloadLength > math.MaxUint16: + lengthByte |= 127 + case h.payloadLength > 125: + lengthByte |= 126 + case h.payloadLength >= 0: + lengthByte |= byte(h.payloadLength) + } + err = w.WriteByte(lengthByte) + if err != nil { + return err + } + + switch { + case h.payloadLength > math.MaxUint16: + err = binary.Write(w, binary.BigEndian, h.payloadLength) + case h.payloadLength > 125: + err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength)) + } + if err != nil { + return err + } + + if h.masked { + err = binary.Write(w, binary.LittleEndian, h.maskKey) + if err != nil { + return err + } + } + + return nil +} + +// mask applies the WebSocket masking algorithm to p // with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // @@ -16,7 +172,7 @@ import ( // to be in little endian. // // See https://github.com/golang/go/issues/31586 -func Mask(key uint32, b []byte) uint32 { +func mask(key uint32, b []byte) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) diff --git a/internal/wsframe/mask_test.go b/frame_test.go similarity index 51% rename from internal/wsframe/mask_test.go rename to frame_test.go index fbd29892..0ed14aef 100644 --- a/internal/wsframe/mask_test.go +++ b/frame_test.go @@ -1,32 +1,108 @@ -package wsframe_test +// +build !js + +package websocket import ( - "crypto/rand" + "bufio" + "bytes" "encoding/binary" - "github.com/gobwas/ws" - "github.com/google/go-cmp/cmp" "math/bits" - "nhooyr.io/websocket/internal/wsframe" + "nhooyr.io/websocket/internal/assert" "strconv" "testing" + "time" _ "unsafe" + + "github.com/gobwas/ws" + _ "github.com/gorilla/websocket" + "math/rand" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func TestHeader(t *testing.T) { + t.Parallel() + + t.Run("lengths", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 124, + 125, + 126, + 127, + + 65534, + 65535, + 65536, + 65537, + } + + for _, n := range lengths { + n := n + t.Run(strconv.Itoa(n), func(t *testing.T) { + t.Parallel() + + testHeader(t, header{ + payloadLength: int64(n), + }) + }) + } + }) + + t.Run("fuzz", func(t *testing.T) { + t.Parallel() + + randBool := func() bool { + return rand.Intn(1) == 0 + } + + for i := 0; i < 10000; i++ { + h := header{ + fin: randBool(), + rsv1: randBool(), + rsv2: randBool(), + rsv3: randBool(), + opcode: opcode(rand.Intn(16)), + + masked: randBool(), + maskKey: rand.Uint32(), + payloadLength: rand.Int63(), + } + + testHeader(t, h) + } + }) +} + +func testHeader(t *testing.T, h header) { + b := &bytes.Buffer{} + w := bufio.NewWriter(b) + r := bufio.NewReader(b) + + err := writeFrameHeader(h, w) + assert.Success(t, err) + err = w.Flush() + assert.Success(t, err) + + h2, err := readFrameHeader(r) + assert.Success(t, err) + + assert.Equalf(t, h, h2, "written and read headers differ") +} + func Test_mask(t *testing.T) { t.Parallel() key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := wsframe.Mask(key32, p) + gotKey32 := mask(key32, p) - if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { - t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) - } - - if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { - t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) - } + assert.Equalf(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "unexpected mask") + assert.Equalf(t, bits.RotateLeft32(key32, -8), gotKey32, "unexpected mask key") } func basicMask(maskKey [4]byte, pos int, b []byte) int { @@ -74,7 +150,7 @@ func Benchmark_mask(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - wsframe.Mask(key32, p) + mask(key32, p) } }, }, @@ -98,9 +174,7 @@ func Benchmark_mask(b *testing.B) { var key [4]byte _, err := rand.Read(key[:]) - if err != nil { - b.Fatalf("failed to populate mask key: %v", err) - } + assert.Success(b, err) for _, size := range sizes { p := make([]byte, size) diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 372d5465..1d9aeced 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -2,6 +2,7 @@ package assert import ( "reflect" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -53,7 +54,7 @@ func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { } } -func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) { +func Equalf(t testing.TB, exp, act interface{}, f string, v ...interface{}) { t.Helper() diff := cmpDiff(exp, act) if diff != "" { @@ -61,7 +62,40 @@ func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) { } } -func Success(t *testing.T, err error) { +func NotEqualf(t testing.TB, exp, act interface{}, f string, v ...interface{}) { t.Helper() - Equalf(t, error(nil), err, "unexpected failure") + diff := cmpDiff(exp, act) + if diff == "" { + t.Fatalf(f+": %v", append(v, diff)...) + } +} + +func Success(t testing.TB, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } +} + +func Error(t testing.TB, err error) { + t.Helper() + if err == nil { + t.Fatal("expected error") + } +} + +func ErrorContains(t testing.TB, err error, sub string) { + t.Helper() + Error(t, err) + errs := err.Error() + if !strings.Contains(errs, sub) { + t.Fatalf("error string %q does not contain %q", errs, sub) + } +} + +func Panicf(t testing.TB, f string, v ...interface{}) { + r := recover() + if r == nil { + t.Fatalf(f, v...) + } } diff --git a/internal/atomicint/atomicint.go b/internal/atomicint/atomicint.go deleted file mode 100644 index 668b3b4b..00000000 --- a/internal/atomicint/atomicint.go +++ /dev/null @@ -1,32 +0,0 @@ -package atomicint - -import ( - "fmt" - "sync/atomic" -) - -// See https://github.com/nhooyr/websocket/issues/153 -type Int64 struct { - v int64 -} - -func (v *Int64) Load() int64 { - return atomic.LoadInt64(&v.v) -} - -func (v *Int64) Store(i int64) { - atomic.StoreInt64(&v.v, i) -} - -func (v *Int64) String() string { - return fmt.Sprint(v.Load()) -} - -// Increment increments the value and returns the new value. -func (v *Int64) Increment(delta int64) int64 { - return atomic.AddInt64(&v.v, delta) -} - -func (v *Int64) CAS(old, new int64) (swapped bool) { - return atomic.CompareAndSwapInt64(&v.v, old, new) -} diff --git a/internal/bufpool/buf.go b/internal/bufpool/buf.go index 324a17e1..0f7d9765 100644 --- a/internal/bufpool/buf.go +++ b/internal/bufpool/buf.go @@ -5,12 +5,12 @@ import ( "sync" ) -var bpool sync.Pool +var pool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { - b, ok := bpool.Get().(*bytes.Buffer) + b, ok := pool.Get().(*bytes.Buffer) if !ok { b = &bytes.Buffer{} } @@ -20,5 +20,5 @@ func Get() *bytes.Buffer { // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() - bpool.Put(b) + pool.Put(b) } diff --git a/internal/bufpool/bufio.go b/internal/bufpool/bufio.go deleted file mode 100644 index 875bbf4b..00000000 --- a/internal/bufpool/bufio.go +++ /dev/null @@ -1,40 +0,0 @@ -package bufpool - -import ( - "bufio" - "io" - "sync" -) - -var readerPool = sync.Pool{ - New: func() interface{} { - return bufio.NewReader(nil) - }, -} - -func GetReader(r io.Reader) *bufio.Reader { - br := readerPool.Get().(*bufio.Reader) - br.Reset(r) - return br -} - -func PutReader(br *bufio.Reader) { - readerPool.Put(br) -} - -var writerPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriter(nil) - }, -} - -func GetWriter(w io.Writer) *bufio.Writer { - bw := writerPool.Get().(*bufio.Writer) - bw.Reset(w) - return bw -} - -func PutWriter(bw *bufio.Writer) { - writerPool.Put(bw) -} - diff --git a/internal/errd/errd.go b/internal/errd/errd.go new file mode 100644 index 00000000..51b7b4f6 --- /dev/null +++ b/internal/errd/errd.go @@ -0,0 +1,11 @@ +package errd + +import ( + "fmt" +) + +func Wrap(err *error, f string, v ...interface{}) { + if *err != nil { + *err = fmt.Errorf(f+ ": %w", append(v, *err)...) + } +} diff --git a/internal/wsecho/wsecho.go b/internal/wsecho/wsecho.go deleted file mode 100644 index c408f07f..00000000 --- a/internal/wsecho/wsecho.go +++ /dev/null @@ -1,55 +0,0 @@ -// +build !js - -package wsecho - -import ( - "context" - "io" - "time" - - "nhooyr.io/websocket" -) - -// Loop echos every msg received from c until an error -// occurs or the context expires. -// The read limit is set to 1 << 30. -func Loop(ctx context.Context, c *websocket.Conn) error { - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32<<10) - echo := func() error { - typ, r, err := c.Reader(ctx) - if err != nil { - return err - } - - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - - _, err = io.CopyBuffer(w, r, b) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - - return nil - } - - for { - err := echo() - if err != nil { - return err - } - } -} diff --git a/internal/wsframe/frame.go b/internal/wsframe/frame.go deleted file mode 100644 index 50ff8c11..00000000 --- a/internal/wsframe/frame.go +++ /dev/null @@ -1,194 +0,0 @@ -package wsframe - -import ( - "encoding/binary" - "fmt" - "io" - "math" -) - -// Opcode represents a WebSocket Opcode. -type Opcode int - -// Opcode constants. -const ( - OpContinuation Opcode = iota - OpText - OpBinary - // 3 - 7 are reserved for further non-control frames. - _ - _ - _ - _ - _ - OpClose - OpPing - OpPong - // 11-16 are reserved for further control frames. -) - -func (o Opcode) Control() bool { - switch o { - case OpClose, OpPing, OpPong: - return true - } - return false -} - -func (o Opcode) Data() bool { - switch o { - case OpText, OpBinary: - return true - } - return false -} - -// First byte contains fin, rsv1, rsv2, rsv3. -// Second byte contains mask flag and payload len. -// Next 8 bytes are the maximum extended payload length. -// Last 4 bytes are the mask key. -// https://tools.ietf.org/html/rfc6455#section-5.2 -const maxHeaderSize = 1 + 1 + 8 + 4 - -// Header represents a WebSocket frame Header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type Header struct { - Fin bool - RSV1 bool - RSV2 bool - RSV3 bool - Opcode Opcode - - PayloadLength int64 - - Masked bool - MaskKey uint32 -} - -// bytes returns the bytes of the Header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func (h Header) Bytes(b []byte) []byte { - if b == nil { - b = make([]byte, maxHeaderSize) - } - - b = b[:2] - b[0] = 0 - - if h.Fin { - b[0] |= 1 << 7 - } - if h.RSV1 { - b[0] |= 1 << 6 - } - if h.RSV2 { - b[0] |= 1 << 5 - } - if h.RSV3 { - b[0] |= 1 << 4 - } - - b[0] |= byte(h.Opcode) - - switch { - case h.PayloadLength < 0: - panic(fmt.Sprintf("websocket: invalid Header: negative length: %v", h.PayloadLength)) - case h.PayloadLength <= 125: - b[1] = byte(h.PayloadLength) - case h.PayloadLength <= math.MaxUint16: - b[1] = 126 - b = b[:len(b)+2] - binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.PayloadLength)) - default: - b[1] = 127 - b = b[:len(b)+8] - binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.PayloadLength)) - } - - if h.Masked { - b[1] |= 1 << 7 - b = b[:len(b)+4] - binary.LittleEndian.PutUint32(b[len(b)-4:], h.MaskKey) - } - - return b -} - -func MakeReadHeaderBuf() []byte { - return make([]byte, maxHeaderSize-2) -} - -// ReadHeader reads a Header from the reader. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func ReadHeader(r io.Reader, b []byte) (Header, error) { - // We read the first two bytes first so that we know - // exactly how long the Header is. - b = b[:2] - _, err := io.ReadFull(r, b) - if err != nil { - return Header{}, err - } - - var h Header - h.Fin = b[0]&(1<<7) != 0 - h.RSV1 = b[0]&(1<<6) != 0 - h.RSV2 = b[0]&(1<<5) != 0 - h.RSV3 = b[0]&(1<<4) != 0 - - h.Opcode = Opcode(b[0] & 0xf) - - var extra int - - h.Masked = b[1]&(1<<7) != 0 - if h.Masked { - extra += 4 - } - - payloadLength := b[1] &^ (1 << 7) - switch { - case payloadLength < 126: - h.PayloadLength = int64(payloadLength) - case payloadLength == 126: - extra += 2 - case payloadLength == 127: - extra += 8 - } - - if extra == 0 { - return h, nil - } - - b = b[:extra] - _, err = io.ReadFull(r, b) - if err != nil { - return Header{}, err - } - - switch { - case payloadLength == 126: - h.PayloadLength = int64(binary.BigEndian.Uint16(b)) - b = b[2:] - case payloadLength == 127: - h.PayloadLength = int64(binary.BigEndian.Uint64(b)) - if h.PayloadLength < 0 { - return Header{}, fmt.Errorf("Header with negative payload length: %v", h.PayloadLength) - } - b = b[8:] - } - - if h.Masked { - h.MaskKey = binary.LittleEndian.Uint32(b) - } - - return h, nil -} - -const MaxControlFramePayload = 125 - -func ParseClosePayload(p []byte) (uint16, string, error) { - if len(p) < 2 { - return 0, "", fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - return binary.BigEndian.Uint16(p), string(p[2:]), nil -} diff --git a/internal/wsframe/frame_stringer.go b/internal/wsframe/frame_stringer.go deleted file mode 100644 index b2e7f423..00000000 --- a/internal/wsframe/frame_stringer.go +++ /dev/null @@ -1,91 +0,0 @@ -// Code generated by "stringer -type=Opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. - -package wsframe - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[OpContinuation-0] - _ = x[OpText-1] - _ = x[OpBinary-2] - _ = x[OpClose-8] - _ = x[OpPing-9] - _ = x[OpPong-10] -} - -const ( - _opcode_name_0 = "opContinuationopTextopBinary" - _opcode_name_1 = "opCloseopPingopPong" -) - -var ( - _opcode_index_0 = [...]uint8{0, 14, 20, 28} - _opcode_index_1 = [...]uint8{0, 7, 13, 19} -) - -func (i Opcode) String() string { - switch { - case 0 <= i && i <= 2: - return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] - case 8 <= i && i <= 10: - i -= 8 - return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] - default: - return "Opcode(" + strconv.FormatInt(int64(i), 10) + ")" - } -} -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[MessageText-1] - _ = x[MessageBinary-2] -} - -const _MessageType_name = "MessageTextMessageBinary" - -var _MessageType_index = [...]uint8{0, 11, 24} - -func (i MessageType) String() string { - i -= 1 - if i < 0 || i >= MessageType(len(_MessageType_index)-1) { - return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] -} -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[StatusNormalClosure-1000] - _ = x[StatusGoingAway-1001] - _ = x[StatusProtocolError-1002] - _ = x[StatusUnsupportedData-1003] - _ = x[statusReserved-1004] - _ = x[StatusNoStatusRcvd-1005] - _ = x[StatusAbnormalClosure-1006] - _ = x[StatusInvalidFramePayloadData-1007] - _ = x[StatusPolicyViolation-1008] - _ = x[StatusMessageTooBig-1009] - _ = x[StatusMandatoryExtension-1010] - _ = x[StatusInternalError-1011] - _ = x[StatusServiceRestart-1012] - _ = x[StatusTryAgainLater-1013] - _ = x[StatusBadGateway-1014] - _ = x[StatusTLSHandshake-1015] -} - -const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" - -var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} - -func (i StatusCode) String() string { - i -= 1000 - if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { - return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" - } - return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] -} diff --git a/internal/wsframe/frame_test.go b/internal/wsframe/frame_test.go deleted file mode 100644 index d6b66e7e..00000000 --- a/internal/wsframe/frame_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// +build !js - -package wsframe - -import ( - "bytes" - "io" - "math/rand" - "strconv" - "testing" - "time" - _ "unsafe" - - "github.com/google/go-cmp/cmp" - _ "github.com/gorilla/websocket" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBool() bool { - return rand.Intn(1) == 0 -} - -func TestHeader(t *testing.T) { - t.Parallel() - - t.Run("eof", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - bytes []byte - }{ - { - "start", - []byte{0xff}, - }, - { - "middle", - []byte{0xff, 0xff, 0xff}, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - b := bytes.NewBuffer(tc.bytes) - _, err := ReadHeader(nil, b) - if io.ErrUnexpectedEOF != err { - t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) - } - }) - } - }) - - t.Run("writeNegativeLength", func(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r == nil { - t.Fatal("failed to induce panic in writeHeader with negative payload length") - } - }() - - Header{ - PayloadLength: -1, - }.Bytes(nil) - }) - - t.Run("readNegativeLength", func(t *testing.T) { - t.Parallel() - - b := Header{ - PayloadLength: 1<<16 + 1, - }.Bytes(nil) - - // Make length negative - b[2] |= 1 << 7 - - r := bytes.NewReader(b) - _, err := ReadHeader(nil, r) - if err == nil { - t.Fatalf("unexpected error value: %+v", err) - } - }) - - t.Run("lengths", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 124, - 125, - 126, - 4096, - 16384, - 65535, - 65536, - 65537, - 131072, - } - - for _, n := range lengths { - n := n - t.Run(strconv.Itoa(n), func(t *testing.T) { - t.Parallel() - - testHeader(t, Header{ - PayloadLength: int64(n), - }) - }) - } - }) - - t.Run("fuzz", func(t *testing.T) { - t.Parallel() - - for i := 0; i < 10000; i++ { - h := Header{ - Fin: randBool(), - RSV1: randBool(), - RSV2: randBool(), - RSV3: randBool(), - Opcode: Opcode(rand.Intn(1 << 4)), - - Masked: randBool(), - PayloadLength: rand.Int63(), - } - - if h.Masked { - h.MaskKey = rand.Uint32() - } - - testHeader(t, h) - } - }) -} - -func testHeader(t *testing.T, h Header) { - b := h.Bytes(nil) - r := bytes.NewReader(b) - h2, err := ReadHeader(r, nil) - if err != nil { - t.Logf("Header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read Header: %v", err) - } - - if !cmp.Equal(h, h2, cmp.AllowUnexported(Header{})) { - t.Logf("Header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read Header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(Header{}))) - } -} diff --git a/internal/wsgrace/wsgrace.go b/internal/wsgrace/wsgrace.go deleted file mode 100644 index 513af1fe..00000000 --- a/internal/wsgrace/wsgrace.go +++ /dev/null @@ -1,50 +0,0 @@ -package wsgrace - -import ( - "context" - "fmt" - "net/http" - "sync/atomic" - "time" -) - -// Grace wraps s.Handler to gracefully shutdown WebSocket connections. -// The returned function must be used to close the server instead of s.Close. -func Grace(s *http.Server) (closeFn func() error) { - h := s.Handler - var conns int64 - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) - - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - - r = r.WithContext(ctx) - - h.ServeHTTP(w, r) - }) - - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - err := s.Shutdown(ctx) - if err != nil { - return fmt.Errorf("server shutdown failed: %v", err) - } - - t := time.NewTicker(time.Millisecond * 10) - defer t.Stop() - for { - select { - case <-t.C: - if atomic.LoadInt64(&conns) == 0 { - return nil - } - case <-ctx.Done(): - return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) - } - } - } -} diff --git a/js_test.go b/js_test.go deleted file mode 100644 index 80af7896..00000000 --- a/js_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package websocket_test - -import ( - "context" - "fmt" - "net/http" - "nhooyr.io/websocket/internal/wsecho" - "os" - "os/exec" - "strings" - "testing" - "time" - - "nhooyr.io/websocket" -) - -func TestJS(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - err = wsecho.Loop(r.Context(), c) - if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - return err - } - return nil - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) - - b, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("wasm test binary failed: %v:\n%s", err, b) - } -} diff --git a/read.go b/read.go new file mode 100644 index 00000000..97096f74 --- /dev/null +++ b/read.go @@ -0,0 +1,479 @@ +package websocket + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "nhooyr.io/websocket/internal/errd" + "strings" + "sync/atomic" + "time" +) + +// Reader waits until there is a WebSocket data message to read +// from the connection. +// It returns the type of the message and a reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. +// +// All returned errors will cause the connection +// to be closed so you do not need to write your own error message. +// This applies to the Read methods in the wsjson/wspb subpackages as well. +// +// You must read from the connection for control frames to be handled. +// Thus if you expect messages to take a long time to be responded to, +// you should handle such messages async to reading from the connection +// to ensure control frames are promptly handled. +// +// If you do not expect any data messages from the peer, call CloseRead. +// +// Only one Reader may be open at a time. +// +// If you need a separate timeout on the Reader call and then the message +// Read, use time.AfterFunc to cancel the context passed in early. +// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 +// Most users should not need this. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, r, err := c.cr.reader(ctx) + if err != nil { + return 0, nil, fmt.Errorf("failed to get reader: %w", err) + } + return typ, r, nil +} + +// Read is a convenience method to read a single message from the connection. +// +// See the Reader method to reuse buffers or for streaming. +// The docs on Reader apply to this method as well. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err + } + + b, err := ioutil.ReadAll(r) + return typ, b, err +} + +// CloseRead will start a goroutine to read from the connection until it is closed or a data message +// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. +// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. +// After calling this method, you cannot read any data messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// Use this when you do not want to read data messages from the connection anymore but will +// want to write messages to it. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.Reader(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + c.cr.mr.lr.limit.Store(n) +} + +type connReader struct { + c *Conn + br *bufio.Reader + timeout chan context.Context + + mu mu + controlPayloadBuf [maxControlPayload]byte + mr *msgReader +} + +func (cr *connReader) init(c *Conn, br *bufio.Reader) { + cr.c = c + cr.br = br + cr.timeout = make(chan context.Context) + + cr.mr = &msgReader{ + cr: cr, + fin: true, + } + + cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768) + if c.deflateNegotiated() && cr.contextTakeover() { + cr.ensureFlateReader() + } +} + +func (cr *connReader) ensureFlateReader() { + cr.mr.fr = getFlateReader(readerFunc(cr.mr.read)) + cr.mr.lr.reset(cr.mr.fr) +} + +func (cr *connReader) close() { + cr.mu.Lock(context.Background()) + if cr.c.client { + putBufioReader(cr.br) + } + if cr.c.deflateNegotiated() && cr.contextTakeover() { + putFlateReader(cr.mr.fr) + } +} + +func (cr *connReader) contextTakeover() bool { + if cr.c.client { + return cr.c.copts.serverNoContextTakeover + } + return cr.c.copts.clientNoContextTakeover +} + +func (cr *connReader) rsv1Illegal(h header) bool { + // If compression is enabled, rsv1 is always illegal. + if !cr.c.deflateNegotiated() { + return true + } + // rsv1 is only allowed on data frames beginning messages. + if h.opcode != opText && h.opcode != opBinary { + return true + } + return false +} + +func (cr *connReader) loop(ctx context.Context) (header, error) { + for { + h, err := cr.frameHeader(ctx) + if err != nil { + return header{}, err + } + + if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 { + err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + cr.c.cw.error(StatusProtocolError, err) + return header{}, err + } + + if !cr.c.client && !h.masked { + return header{}, errors.New("received unmasked frame from client") + } + + switch h.opcode { + case opClose, opPing, opPong: + err = cr.control(ctx, h) + if err != nil { + // Pass through CloseErrors when receiving a close frame. + if h.opcode == opClose && CloseStatus(err) != -1 { + return header{}, err + } + return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) + } + case opContinuation, opText, opBinary: + return h, nil + default: + err := fmt.Errorf("received unknown opcode %v", h.opcode) + cr.c.cw.error(StatusProtocolError, err) + return header{}, err + } + } +} + +func (cr *connReader) frameHeader(ctx context.Context) (header, error) { + select { + case <-cr.c.closed: + return header{}, cr.c.closeErr + case cr.timeout <- ctx: + } + + h, err := readFrameHeader(cr.br) + if err != nil { + select { + case <-cr.c.closed: + return header{}, cr.c.closeErr + case <-ctx.Done(): + return header{}, ctx.Err() + default: + cr.c.close(err) + return header{}, err + } + } + + select { + case <-cr.c.closed: + return header{}, cr.c.closeErr + case cr.timeout <- context.Background(): + } + + return h, nil +} + +func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) { + select { + case <-cr.c.closed: + return 0, cr.c.closeErr + case cr.timeout <- ctx: + } + + n, err := io.ReadFull(cr.br, p) + if err != nil { + select { + case <-cr.c.closed: + return n, cr.c.closeErr + case <-ctx.Done(): + return n, ctx.Err() + default: + err = fmt.Errorf("failed to read frame payload: %w", err) + cr.c.close(err) + return n, err + } + } + + select { + case <-cr.c.closed: + return n, cr.c.closeErr + case cr.timeout <- context.Background(): + } + + return n, err +} + +func (cr *connReader) control(ctx context.Context, h header) error { + if h.payloadLength < 0 { + err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength) + cr.c.cw.error(StatusProtocolError, err) + return err + } + + if h.payloadLength > maxControlPayload { + err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) + cr.c.cw.error(StatusProtocolError, err) + return err + } + + if !h.fin { + err := errors.New("received fragmented control frame") + cr.c.cw.error(StatusProtocolError, err) + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + b := cr.controlPayloadBuf[:h.payloadLength] + _, err := cr.framePayload(ctx, b) + if err != nil { + return err + } + + if h.masked { + mask(h.maskKey, b) + } + + switch h.opcode { + case opPing: + return cr.c.cw.control(ctx, opPong, b) + case opPong: + cr.c.activePingsMu.Lock() + pong, ok := cr.c.activePings[string(b)] + cr.c.activePingsMu.Unlock() + if ok { + close(pong) + } + return nil + } + + ce, err := parseClosePayload(b) + if err != nil { + err = fmt.Errorf("received invalid close payload: %w", err) + cr.c.cw.error(StatusProtocolError, err) + return err + } + + err = fmt.Errorf("received close frame: %w", ce) + cr.c.setCloseErr(err) + cr.c.cw.control(context.Background(), opClose, ce.bytes()) + return err +} + +func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) { + err := cr.mu.Lock(ctx) + if err != nil { + return 0, nil, err + } + defer cr.mu.Unlock() + + if !cr.mr.fin { + return 0, nil, errors.New("previous message not read to completion") + } + + h, err := cr.loop(ctx) + if err != nil { + return 0, nil, err + } + + if h.opcode == opContinuation { + err := errors.New("received continuation frame without text or binary frame") + cr.c.cw.error(StatusProtocolError, err) + return 0, nil, err + } + + cr.mr.reset(ctx, h) + + return MessageType(h.opcode), cr.mr, nil +} + +type msgReader struct { + cr *connReader + fr io.Reader + lr *limitReader + + ctx context.Context + + deflate bool + deflateTail strings.Reader + + payloadLength int64 + maskKey uint32 + fin bool +} + +func (mr *msgReader) reset(ctx context.Context, h header) { + mr.ctx = ctx + mr.deflate = h.rsv1 + if mr.deflate { + mr.deflateTail.Reset(deflateMessageTail) + if !mr.cr.contextTakeover() { + mr.cr.ensureFlateReader() + } + } + mr.setFrame(h) + mr.fin = false +} + +func (mr *msgReader) setFrame(h header) { + mr.payloadLength = h.payloadLength + mr.maskKey = h.maskKey + mr.fin = h.fin +} + +func (mr *msgReader) Read(p []byte) (_ int, err error) { + defer func() { + errd.Wrap(&err, "failed to read") + if errors.Is(err, io.EOF) { + err = io.EOF + } + }() + + err = mr.cr.mu.Lock(mr.ctx) + if err != nil { + return 0, err + } + defer mr.cr.mu.Unlock() + + if mr.payloadLength == 0 && mr.fin { + if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() { + if mr.fr != nil { + putFlateReader(mr.fr) + mr.fr = nil + } + } + return 0, io.EOF + } + + return mr.lr.Read(p) +} + +func (mr *msgReader) read(p []byte) (int, error) { + log.Println("compress", mr.deflate) + + if mr.payloadLength == 0 { + h, err := mr.cr.loop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := errors.New("received new data message without finishing the previous message") + mr.cr.c.cw.error(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + } + + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } + + n, err := mr.cr.framePayload(mr.ctx, p) + if err != nil { + return n, err + } + + mr.payloadLength -= int64(n) + + if !mr.cr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } + + return n, nil +} + +type limitReader struct { + c *Conn + r io.Reader + limit atomicInt64 + n int64 +} + +func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { + lr := &limitReader{ + c: c, + } + lr.limit.Store(limit) + lr.reset(r) + return lr +} + +func (lr *limitReader) reset(r io.Reader) { + lr.n = lr.limit.Load() + lr.r = r +} + +func (lr *limitReader) Read(p []byte) (int, error) { + if lr.n <= 0 { + err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) + lr.c.cw.error(StatusMessageTooBig, err) + return 0, err + } + + if int64(len(p)) > lr.n { + p = p[:lr.n] + } + n, err := lr.r.Read(p) + lr.n -= int64(n) + return n, err +} + +type atomicInt64 struct { + i atomic.Value +} + +func (v *atomicInt64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +func (v *atomicInt64) Store(i int64) { + v.i.Store(i) +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/reader.go b/reader.go deleted file mode 100644 index fe716569..00000000 --- a/reader.go +++ /dev/null @@ -1,31 +0,0 @@ -package websocket - -import ( - "bufio" - "context" - "io" - "nhooyr.io/websocket/internal/atomicint" - "nhooyr.io/websocket/internal/wsframe" - "strings" -) - -type reader struct { - // Acquired before performing any sort of read operation. - readLock chan struct{} - - c *Conn - - deflateReader io.Reader - br *bufio.Reader - - readClosed *atomicint.Int64 - readHeaderBuf []byte - controlPayloadBuf []byte - - msgCtx context.Context - msgCompressed bool - frameHeader wsframe.Header - frameMaskKey uint32 - frameEOF bool - deflateTail strings.Reader -} diff --git a/write.go b/write.go new file mode 100644 index 00000000..5bb489b4 --- /dev/null +++ b/write.go @@ -0,0 +1,348 @@ +package websocket + +import ( + "bufio" + "compress/flate" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "nhooyr.io/websocket/internal/errd" + "time" +) + +// Writer returns a writer bounded by the context that will write +// a WebSocket message of type dataType to the connection. +// +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. +// +// Never close the returned writer twice. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + w, err := c.cw.writer(ctx, typ) + if err != nil { + return nil, fmt.Errorf("failed to get writer: %w", err) + } + return w, nil +} + +// Write writes a message to the connection. +// +// See the Writer method if you want to stream a message. +// +// If compression is disabled, then it is guaranteed to write the message +// in a single frame. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + _, err := c.cw.write(ctx, typ, p) + if err != nil { + return fmt.Errorf("failed to write msg: %w", err) + } + return nil +} + +type connWriter struct { + c *Conn + bw *bufio.Writer + + writeBuf []byte + + mw *messageWriter + frameMu mu + h header + + timeout chan context.Context +} + +func (cw *connWriter) init(c *Conn, bw *bufio.Writer) { + cw.c = c + cw.bw = bw + + if cw.c.client { + cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc) + } + + cw.timeout = make(chan context.Context) + + cw.mw = &messageWriter{ + cw: cw, + } + cw.mw.tw = &trimLastFourBytesWriter{ + w: writerFunc(cw.mw.write), + } + if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { + cw.mw.ensureFlateWriter() + } +} + +func (mw *messageWriter) ensureFlateWriter() { + mw.fw = getFlateWriter(mw.tw) +} + +func (cw *connWriter) close() { + if cw.c.client { + cw.frameMu.Lock(context.Background()) + putBufioWriter(cw.bw) + } + if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { + cw.mw.mu.Lock(context.Background()) + putFlateWriter(cw.mw.fw) + } +} + +func (mw *messageWriter) contextTakeover() bool { + if mw.cw.c.client { + return mw.cw.c.copts.clientNoContextTakeover + } + return mw.cw.c.copts.serverNoContextTakeover +} + +func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := cw.mw.reset(ctx, typ) + if err != nil { + return nil, err + } + return cw.mw, nil +} + +func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + ww, err := cw.writer(ctx, typ) + if err != nil { + return 0, err + } + + if !cw.c.deflateNegotiated() { + // Fast single frame path. + defer cw.mw.mu.Unlock() + return cw.frame(ctx, true, cw.mw.opcode, p) + } + + n, err := ww.Write(p) + if err != nil { + return n, err + } + + err = ww.Close() + return n, err +} + +type messageWriter struct { + cw *connWriter + + mu mu + compress bool + tw *trimLastFourBytesWriter + fw *flate.Writer + ctx context.Context + opcode opcode + closed bool +} + +func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error { + err := mw.mu.Lock(ctx) + if err != nil { + return err + } + + mw.closed = false + mw.ctx = ctx + mw.opcode = opcode(typ) + return nil +} + +// Write writes the given bytes to the WebSocket connection. +func (mw *messageWriter) Write(p []byte) (_ int, err error) { + defer errd.Wrap(&err, "failed to write") + + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + + if mw.cw.c.deflateNegotiated() { + if !mw.compress { + if !mw.contextTakeover() { + mw.ensureFlateWriter() + } + mw.tw.reset() + mw.compress = true + } + + return mw.fw.Write(p) + } + + return mw.write(p) +} + +func (mw *messageWriter) write(p []byte) (int, error) { + n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p) + if err != nil { + return n, fmt.Errorf("failed to write data frame: %w", err) + } + mw.opcode = opContinuation + return n, nil +} + +// Close flushes the frame to the connection. +// This must be called for every messageWriter. +func (mw *messageWriter) Close() (err error) { + defer errd.Wrap(&err, "failed to close writer") + + if mw.closed { + return errors.New("cannot use closed writer") + } + mw.closed = true + + if mw.cw.c.deflateNegotiated() { + err = mw.fw.Flush() + if err != nil { + return fmt.Errorf("failed to flush flate writer: %w", err) + } + } + + _, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil) + if err != nil { + return fmt.Errorf("failed to write fin frame: %w", err) + } + + if mw.compress && !mw.contextTakeover() { + putFlateWriter(mw.fw) + mw.compress = false + } + + mw.mu.Unlock() + return nil +} + +func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + _, err := cw.frame(ctx, true, opcode, p) + if err != nil { + return fmt.Errorf("failed to write control frame %v: %w", opcode, err) + } + return nil +} + +// frame handles all writes to the connection. +func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + err := cw.frameMu.Lock(ctx) + if err != nil { + return 0, err + } + defer cw.frameMu.Unlock() + + select { + case <-cw.c.closed: + return 0, cw.c.closeErr + case cw.timeout <- ctx: + } + + cw.h.fin = fin + cw.h.opcode = opcode + cw.h.masked = cw.c.client + cw.h.payloadLength = int64(len(p)) + + cw.h.rsv1 = false + if cw.mw.compress && (opcode == opText || opcode == opBinary) { + cw.h.rsv1 = true + } + + if cw.h.masked { + err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey) + if err != nil { + return 0, fmt.Errorf("failed to generate masking key: %w", err) + } + } + + err = writeFrameHeader(cw.h, cw.bw) + if err != nil { + return 0, err + } + + n, err := cw.framePayload(p) + if err != nil { + return n, err + } + + if cw.h.fin { + err = cw.bw.Flush() + if err != nil { + return n, fmt.Errorf("failed to flush: %w", err) + } + } + + select { + case <-cw.c.closed: + return n, cw.c.closeErr + case cw.timeout <- context.Background(): + } + + return n, nil +} + +func (cw *connWriter) framePayload(p []byte) (_ int, err error) { + defer errd.Wrap(&err, "failed to write frame payload") + + if !cw.h.masked { + return cw.bw.Write(p) + } + + var n int + maskKey := cw.h.maskKey + for len(p) > 0 { + // If the buffer is full, we need to flush. + if cw.bw.Available() == 0 { + err = cw.bw.Flush() + if err != nil { + return n, err + } + } + + // Start of next write in the buffer. + i := cw.bw.Buffered() + + j := len(p) + if j > cw.bw.Available() { + j = cw.bw.Available() + } + + _, err := cw.bw.Write(p[:j]) + if err != nil { + return n, err + } + + maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()]) + + p = p[j:] + n += j + } + + return n, nil +} + +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer +// and returns it. +func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { + var writeBuf []byte + bw.Reset(writerFunc(func(p2 []byte) (int, error) { + writeBuf = p2[:cap(p2)] + return len(p2), nil + })) + + bw.WriteByte(0) + bw.Flush() + + bw.Reset(w) + + return writeBuf +} diff --git a/writer.go b/writer.go deleted file mode 100644 index b31d57ad..00000000 --- a/writer.go +++ /dev/null @@ -1,5 +0,0 @@ -package websocket - -type writer struct { - -} diff --git a/ws_js.go b/ws_js.go index 4c067430..10ce0da8 100644 --- a/ws_js.go +++ b/ws_js.go @@ -9,7 +9,7 @@ import ( "fmt" "io" "net/http" - "nhooyr.io/websocket/internal/atomicint" + "nhooyr.io/websocket/internal/wssync" "reflect" "runtime" "sync" @@ -24,10 +24,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit *atomicint.Int64 + msgReadLimit *wssync.Int64 closingMu sync.Mutex - isReadClosed *atomicint.Int64 + isReadClosed *wssync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -59,10 +59,10 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &atomicint.Int64{} + c.msgReadLimit = &wssync.Int64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &atomicint.Int64{} + c.isReadClosed = &wssync.Int64{} c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ @@ -105,7 +105,7 @@ func (c *Conn) closeWithInternal() { // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, fmt.Errorf("websocket connection read closed") + return 0, nil, errors.New("websocket connection read closed") } typ, p, err := c.read(ctx) diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 9fa8b54c..e8188051 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "log" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bufpool" ) @@ -41,6 +42,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") + log.Printf("%X", b.Bytes()) return fmt.Errorf("failed to unmarshal json: %w", err) } From dd107dd12713665b37436d2af3302f9e83409240 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 14:59:31 -0500 Subject: [PATCH 09/55] Update CI --- .github/CODEOWNERS | 1 + .github/workflows/ci.yml | 38 +++++++++++++++++---- Makefile | 4 --- ci/{ => image}/Dockerfile | 13 -------- conn_test.go | 69 ++++++++++++++++++++------------------- 5 files changed, 68 insertions(+), 57 deletions(-) create mode 100644 .github/CODEOWNERS rename ci/{ => image}/Dockerfile (52%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..d2eae33e --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @nhooyr diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2cc69828..865c67f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,22 +4,48 @@ on: [push, pull_request] jobs: fmt: runs-on: ubuntu-latest - container: nhooyr/websocket-ci@sha256:8a8fd73fdea33585d50a33619c4936adfd016246a2ed6bbfbf06def24b518a6a steps: - uses: actions/checkout@v1 - - run: make fmt + - uses: actions/cache@v1 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: make fmt + uses: ./ci/image + with: + args: make fmt + lint: runs-on: ubuntu-latest - container: nhooyr/websocket-ci@sha256:8a8fd73fdea33585d50a33619c4936adfd016246a2ed6bbfbf06def24b518a6a steps: - uses: actions/checkout@v1 - - run: make lint + - uses: actions/cache@v1 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: make lint + uses: ./ci/image + with: + args: make lint + test: runs-on: ubuntu-latest - container: nhooyr/websocket-ci@sha256:8a8fd73fdea33585d50a33619c4936adfd016246a2ed6bbfbf06def24b518a6a steps: - uses: actions/checkout@v1 - - run: make test + - uses: actions/cache@v1 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: make test + uses: ./ci/image + with: + args: make test env: COVERALLS_TOKEN: ${{ secrets.github_token }} - name: Upload coverage.html diff --git a/Makefile b/Makefile index 8c8e1a08..ad1ba257 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,3 @@ SHELL = bash include ci/fmt.mk include ci/lint.mk include ci/test.mk - -ci-image: - docker build -f ./ci/Dockerfile -t nhooyr/websocket-ci . - docker push nhooyr/websocket-ci diff --git a/ci/Dockerfile b/ci/image/Dockerfile similarity index 52% rename from ci/Dockerfile rename to ci/image/Dockerfile index 0f0fc7d9..ccfac109 100644 --- a/ci/Dockerfile +++ b/ci/image/Dockerfile @@ -5,8 +5,6 @@ RUN apt-get install -y chromium RUN apt-get install -y npm RUN apt-get install -y jq -ENV GOPATH=/root/gopath -ENV PATH=$GOPATH/bin:$PATH ENV GOFLAGS="-mod=readonly" ENV PAGER=cat ENV CI=true @@ -18,14 +16,3 @@ RUN go get golang.org/x/tools/cmd/goimports RUN go get golang.org/x/lint/golint RUN go get github.com/agnivade/wasmbrowsertest RUN go get github.com/mattn/goveralls - -# Cache go modules and build cache. -COPY . /tmp/websocket -RUN cd /tmp/websocket && \ - CI= make && \ - rm -rf /tmp/websocket - -# GitHub actions tries to override HOME to /github/home and then -# mounts a temp directory into there. We do not want this behaviour. -# I assume it is so that $HOME is preserved between steps in a job. -ENTRYPOINT ["env", "HOME=/root"] diff --git a/conn_test.go b/conn_test.go index 992c8861..1014dbf3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -17,6 +17,41 @@ import ( "nhooyr.io/websocket" ) +func TestConn(t *testing.T) { + t.Parallel() + + t.Run("json", func(t *testing.T) { + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") + + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + opts := &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + } + opts.HTTPClient = s.Client() + + c, _, err := websocket.Dial(ctx, wsURL, opts) + assert.Success(t, err) + + assertJSONEcho(t, ctx, c, 2) + }) +} + + func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) { h := http.HandlerFunc(fn) if tls { @@ -108,37 +143,3 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } } - -func TestConn(t *testing.T) { - t.Parallel() - - t.Run("json", func(t *testing.T) { - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, - }) - assert.Success(t, err) - defer c.Close(websocket.StatusInternalError, "") - - err = echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - } - opts.HTTPClient = s.Client() - - c, _, err := websocket.Dial(ctx, wsURL, opts) - assert.Success(t, err) - - assertJSONEcho(t, ctx, c, 2) - }) -} From 6c6b8e9af2030e9ce4352ae006092255ca62fef5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 15:58:41 -0500 Subject: [PATCH 10/55] Cleanup wspb and wsjson --- .github/CONTRIBUTING.md | 45 -------------- .github/ISSUE_TEMPLATE.md | 1 - .github/PULL_REQUEST_TEMPLATE.md | 4 -- README.md | 4 -- assert_test.go | 12 ++-- ci/test.mk | 4 +- close.go | 6 +- close_test.go | 2 +- frame.go | 1 - frame_test.go | 6 +- go.mod | 4 +- internal/assert/assert.go | 69 ++------------------- internal/assert/cmp.go | 52 ++++++++++++++++ internal/{bufpool/buf.go => bpool/bpool.go} | 10 +-- internal/bufpool/buf_test.go | 46 -------------- internal/errd/errd.go | 11 ---- internal/errd/wrap.go | 14 +++++ read.go | 3 - write.go | 1 - ws_js.go | 6 +- wsjson/wsjson.go | 60 ++++++++---------- wspb/wspb.go | 53 ++++++++-------- 22 files changed, 146 insertions(+), 268 deletions(-) delete mode 100644 .github/CONTRIBUTING.md delete mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 internal/assert/cmp.go rename internal/{bufpool/buf.go => bpool/bpool.go} (72%) delete mode 100644 internal/bufpool/buf_test.go delete mode 100644 internal/errd/errd.go create mode 100644 internal/errd/wrap.go diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md deleted file mode 100644 index 357c314a..00000000 --- a/.github/CONTRIBUTING.md +++ /dev/null @@ -1,45 +0,0 @@ -# Contributing - -## Issues - -Please be as descriptive as possible. - -Reproducible examples are key to finding and fixing bugs. - -## Pull requests - -Good issues for first time contributors are marked as such. Feel free to -reach out for clarification on what needs to be done. - -Split up large changes into several small descriptive commits. - -Capitalize the first word in the commit message title. - -The commit message title should use the verb tense + phrase that completes the blank in - -> This change modifies websocket to \_\_\_\_\_\_\_\_\_ - -Be sure to [correctly link](https://help.github.com/en/articles/closing-issues-using-keywords) -to an existing issue if one exists. In general, create an issue before a PR to get some -discussion going and to make sure you do not spend time on a PR that may be rejected. - -CI must pass on your changes for them to be merged. - -### CI - -CI will ensure your code is formatted, lints and passes tests. -It will collect coverage and report it to [coveralls](https://coveralls.io/github/nhooyr/websocket) -and also upload a html `coverage` artifact that you can download to browse coverage. - -You can run CI locally. - -See [ci/image/Dockerfile](../ci/image/Dockerfile) for the installation of the CI dependencies on Ubuntu. - -1. `make fmt` performs code generation and formatting. -1. `make lint` performs linting. -1. `make test` runs tests. -1. `make` runs the above targets. - -For coverage details locally, see `ci/out/coverage.html` after running `make test`. - -You can run tests normally with `go test`. `make test` wraps around `go test` to collect coverage. diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index fce01709..7b580937 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -1,4 +1,3 @@ diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index 901c994a..00000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,4 +0,0 @@ - diff --git a/README.md b/README.md index 17c7c838..8ac418a7 100644 --- a/README.md +++ b/README.md @@ -165,10 +165,6 @@ faster, the compression extensions are fully supported and as much as possible i See the gorilla/websocket comparison for more performance details. -## Contributing - -See [.github/CONTRIBUTING.md](.github/CONTRIBUTING.md). - ## Users If your company or project is using this library, feel free to open an issue or PR to amend this list. diff --git a/assert_test.go b/assert_test.go index 0cc9dfe3..6e4e75e6 100644 --- a/assert_test.go +++ b/assert_test.go @@ -33,7 +33,7 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) err = wsjson.Read(ctx, c, &act) assert.Success(t, err) - assert.Equalf(t, exp, act, "unexpected JSON") + assert.Equal(t, exp, act, "unexpected JSON") } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { @@ -43,7 +43,7 @@ func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp in err := wsjson.Read(ctx, c, &act) assert.Success(t, err) - assert.Equalf(t, exp, act, "unexpected JSON") + assert.Equal(t, exp, act, "unexpected JSON") } func randString(n int) string { @@ -69,18 +69,18 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc typ2, p2, err := c.Read(ctx) assert.Success(t, err) - assert.Equalf(t, typ, typ2, "unexpected data type") - assert.Equalf(t, p, p2, "unexpected payload") + assert.Equal(t, typ, typ2, "unexpected data type") + assert.Equal(t, p, p2, "unexpected payload") } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { t.Helper() - assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol") + assert.Equal(t, exp, c.Subprotocol(), "unexpected subprotocol") } func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { t.Helper() - assert.Equalf(t, exp, websocket.CloseStatus(err), "unexpected status code") + assert.Equal(t, exp, websocket.CloseStatus(err), "unexpected status code") } diff --git a/ci/test.mk b/ci/test.mk index 3183552e..9e4e0803 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -20,6 +20,4 @@ coveralls: gotest gotest: go test -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/_stringer\.go/d' ci/out/coverage.prof - sed -i '/wsecho\.go/d' ci/out/coverage.prof - sed -i '/assert\.go/d' ci/out/coverage.prof - sed -i '/wsgrace\.go/d' ci/out/coverage.prof + sed -i '/assert/d' ci/out/coverage.prof diff --git a/close.go b/close.go index b1bc50e9..57d69a37 100644 --- a/close.go +++ b/close.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" "log" - "nhooyr.io/websocket/internal/bufpool" + "nhooyr.io/websocket/internal/bpool" "time" ) @@ -146,10 +146,10 @@ func (cr *connReader) waitClose() error { } defer cr.mu.Unlock() - b := bufpool.Get() + b := bpool.Get() buf := b.Bytes() buf = buf[:cap(buf)] - defer bufpool.Put(b) + defer bpool.Put(b) for { // TODO diff --git a/close_test.go b/close_test.go index ee10cd3f..ca51a298 100644 --- a/close_test.go +++ b/close_test.go @@ -189,7 +189,7 @@ func TestCloseStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equalf(t, tc.exp, CloseStatus(tc.in), "unexpected close status") + assert.Equal(t, tc.exp, CloseStatus(tc.in), "unexpected close status") }) } } diff --git a/frame.go b/frame.go index 0f10d553..f36334c2 100644 --- a/frame.go +++ b/frame.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "math" "math/bits" - "nhooyr.io/websocket/internal/errd" ) // opcode represents a WebSocket opcode. diff --git a/frame_test.go b/frame_test.go index 0ed14aef..a4a1f5a8 100644 --- a/frame_test.go +++ b/frame_test.go @@ -90,7 +90,7 @@ func testHeader(t *testing.T, h header) { h2, err := readFrameHeader(r) assert.Success(t, err) - assert.Equalf(t, h, h2, "written and read headers differ") + assert.Equal(t, h, h2, "written and read headers differ") } func Test_mask(t *testing.T) { @@ -101,8 +101,8 @@ func Test_mask(t *testing.T) { p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(key32, p) - assert.Equalf(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "unexpected mask") - assert.Equalf(t, bits.RotateLeft32(key32, -8), gotKey32, "unexpected mask key") + assert.Equal(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "unexpected mask") + assert.Equal(t, bits.RotateLeft32(key32, -8), gotKey32, "unexpected mask key") } func basicMask(maskKey [4]byte, pos int, b []byte) int { diff --git a/go.mod b/go.mod index e6ef0014..3108c020 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,13 @@ require ( github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 - github.com/golang/protobuf v1.3.2 + github.com/golang/protobuf v1.3.2 // indirect github.com/google/go-cmp v0.3.1 github.com/gorilla/websocket v1.4.1 github.com/kr/pretty v0.1.0 // indirect github.com/stretchr/testify v1.4.0 // indirect go.uber.org/atomic v1.4.0 // indirect - go.uber.org/multierr v1.1.0 + go.uber.org/multierr v1.1.0 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 1d9aeced..4ebdb511 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -1,79 +1,29 @@ package assert import ( - "reflect" "strings" "testing" - - "github.com/google/go-cmp/cmp" ) -// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 -func cmpDiff(exp, act interface{}) string { - return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) -} - -func deepAllowUnexported(vs ...interface{}) cmp.Option { - m := make(map[reflect.Type]struct{}) - for _, v := range vs { - structTypes(reflect.ValueOf(v), m) - } - var typs []interface{} - for t := range m { - typs = append(typs, reflect.New(t).Elem().Interface()) - } - return cmp.AllowUnexported(typs...) -} - -func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { - if !v.IsValid() { - return - } - switch v.Kind() { - case reflect.Ptr: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Interface: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Slice, reflect.Array: - for i := 0; i < v.Len(); i++ { - structTypes(v.Index(i), m) - } - case reflect.Map: - for _, k := range v.MapKeys() { - structTypes(v.MapIndex(k), m) - } - case reflect.Struct: - m[v.Type()] = struct{}{} - for i := 0; i < v.NumField(); i++ { - structTypes(v.Field(i), m) - } - } -} - -func Equalf(t testing.TB, exp, act interface{}, f string, v ...interface{}) { +func Equal(t testing.TB, exp, act interface{}, name string) { t.Helper() diff := cmpDiff(exp, act) if diff != "" { - t.Fatalf(f+": %v", append(v, diff)...) + t.Fatalf("unexpected %v: %v", name, diff) } } -func NotEqualf(t testing.TB, exp, act interface{}, f string, v ...interface{}) { +func NotEqual(t testing.TB, exp, act interface{}, name string) { t.Helper() - diff := cmpDiff(exp, act) - if diff == "" { - t.Fatalf(f+": %v", append(v, diff)...) + if cmpDiff(exp, act) == "" { + t.Fatalf("expected different %v: %+v", name, act) } } func Success(t testing.TB, err error) { t.Helper() if err != nil { - t.Fatalf("unexpected error: %+v", err) + t.Fatalf("unexpected error : %+v", err) } } @@ -92,10 +42,3 @@ func ErrorContains(t testing.TB, err error, sub string) { t.Fatalf("error string %q does not contain %q", errs, sub) } } - -func Panicf(t testing.TB, f string, v ...interface{}) { - r := recover() - if r == nil { - t.Fatalf(f, v...) - } -} diff --git a/internal/assert/cmp.go b/internal/assert/cmp.go new file mode 100644 index 00000000..0edcf2cd --- /dev/null +++ b/internal/assert/cmp.go @@ -0,0 +1,52 @@ +package assert + +import ( + "github.com/google/go-cmp/cmp" + "reflect" +) + +// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 +func cmpDiff(exp, act interface{}) string { + return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) +} + +func deepAllowUnexported(vs ...interface{}) cmp.Option { + m := make(map[reflect.Type]struct{}) + for _, v := range vs { + structTypes(reflect.ValueOf(v), m) + } + var typs []interface{} + for t := range m { + typs = append(typs, reflect.New(t).Elem().Interface()) + } + return cmp.AllowUnexported(typs...) +} + +func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { + if !v.IsValid() { + return + } + switch v.Kind() { + case reflect.Ptr: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Interface: + if !v.IsNil() { + structTypes(v.Elem(), m) + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + structTypes(v.Index(i), m) + } + case reflect.Map: + for _, k := range v.MapKeys() { + structTypes(v.MapIndex(k), m) + } + case reflect.Struct: + m[v.Type()] = struct{}{} + for i := 0; i < v.NumField(); i++ { + structTypes(v.Field(i), m) + } + } +} diff --git a/internal/bufpool/buf.go b/internal/bpool/bpool.go similarity index 72% rename from internal/bufpool/buf.go rename to internal/bpool/bpool.go index 0f7d9765..e2c5f76a 100644 --- a/internal/bufpool/buf.go +++ b/internal/bpool/bpool.go @@ -1,4 +1,4 @@ -package bufpool +package bpool import ( "bytes" @@ -10,11 +10,11 @@ var pool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { - b, ok := pool.Get().(*bytes.Buffer) - if !ok { - b = &bytes.Buffer{} + b := pool.Get() + if b == nil { + return &bytes.Buffer{} } - return b + return b.(*bytes.Buffer) } // Put returns a buffer into the pool. diff --git a/internal/bufpool/buf_test.go b/internal/bufpool/buf_test.go deleted file mode 100644 index 42a2fea7..00000000 --- a/internal/bufpool/buf_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package bufpool - -import ( - "strconv" - "sync" - "testing" -) - -func BenchmarkSyncPool(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 64, - 128, - 256, - 512, - 4096, - 16384, - } - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("allocate", func(b *testing.B) { - b.ReportAllocs() - for i := 0; i < b.N; i++ { - buf := make([]byte, size) - _ = buf - } - }) - b.Run("pool", func(b *testing.B) { - b.ReportAllocs() - - p := sync.Pool{} - - for i := 0; i < b.N; i++ { - buf := p.Get() - if buf == nil { - buf = make([]byte, size) - } - - p.Put(buf) - } - }) - }) - } -} diff --git a/internal/errd/errd.go b/internal/errd/errd.go deleted file mode 100644 index 51b7b4f6..00000000 --- a/internal/errd/errd.go +++ /dev/null @@ -1,11 +0,0 @@ -package errd - -import ( - "fmt" -) - -func Wrap(err *error, f string, v ...interface{}) { - if *err != nil { - *err = fmt.Errorf(f+ ": %w", append(v, *err)...) - } -} diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go new file mode 100644 index 00000000..849335c9 --- /dev/null +++ b/internal/errd/wrap.go @@ -0,0 +1,14 @@ +package errd + +import ( + "fmt" +) + +// Wrap wraps err with fmt.Errorf if err is non nil. +// Intended for use with defer and a named error return. +// Inspired by https://github.com/golang/go/issues/32676. +func Wrap(err *error, f string, v ...interface{}) { + if *err != nil { + *err = fmt.Errorf(f+ ": %w", append(v, *err)...) + } +} diff --git a/read.go b/read.go index 97096f74..1f5a88ad 100644 --- a/read.go +++ b/read.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/ioutil" - "log" "nhooyr.io/websocket/internal/errd" "strings" "sync/atomic" @@ -390,8 +389,6 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { } func (mr *msgReader) read(p []byte) (int, error) { - log.Println("compress", mr.deflate) - if mr.payloadLength == 0 { h, err := mr.cr.loop(mr.ctx) if err != nil { diff --git a/write.go b/write.go index 5bb489b4..e1ea007e 100644 --- a/write.go +++ b/write.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "nhooyr.io/websocket/internal/errd" "time" ) diff --git a/ws_js.go b/ws_js.go index 10ce0da8..882535b1 100644 --- a/ws_js.go +++ b/ws_js.go @@ -15,7 +15,7 @@ import ( "sync" "syscall/js" - "nhooyr.io/websocket/internal/bufpool" + "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" ) @@ -302,7 +302,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err c: c, ctx: ctx, typ: typ, - b: bufpool.Get(), + b: bpool.Get(), }, nil } @@ -332,7 +332,7 @@ func (w writer) Close() error { return errors.New("cannot close closed writer") } w.closed = true - defer bufpool.Put(w.b) + defer bpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index e8188051..36dd2dfd 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,38 +1,36 @@ -// Package wsjson provides websocket helpers for JSON messages. +// Package wsjson provides helpers for reading and writing JSON messages. package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" "fmt" - "log" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bufpool" + "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/errd" ) -// Read reads a json message from c into v. -// It will reuse buffers to avoid allocations. +// Read reads a JSON message from c into v. +// It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { - err := read(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to read json: %w", err) - } - return nil + return read(ctx, c, v) } -func read(ctx context.Context, c *websocket.Conn, v interface{}) error { +func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { + defer errd.Wrap(&err, "failed to read JSON message") + typ, r, err := c.Reader(ctx) if err != nil { return err } if typ != websocket.MessageText { - c.Close(websocket.StatusUnsupportedData, "can only accept text messages") - return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) + c.Close(websocket.StatusUnsupportedData, "expected text message") + return fmt.Errorf("expected text message for JSON but got: %v", typ) } - b := bufpool.Get() - defer bufpool.Put(b) + b := bpool.Get() + defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { @@ -42,40 +40,32 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") - log.Printf("%X", b.Bytes()) - return fmt.Errorf("failed to unmarshal json: %w", err) + return fmt.Errorf("failed to unmarshal JSON: %w", err) } return nil } -// Write writes the json message v to c. -// It will reuse buffers to avoid allocations. +// Write writes the JSON message v to c. +// It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { - err := write(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to write json: %w", err) - } - return nil + return write(ctx, c, v) } -func write(ctx context.Context, c *websocket.Conn, v interface{}) error { +func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { + defer errd.Wrap(&err, "failed to write JSON message") + w, err := c.Writer(ctx, websocket.MessageText) if err != nil { return err } - // We use Encode because it automatically enables buffer reuse without us - // needing to do anything. Though see https://github.com/golang/go/issues/27735 - e := json.NewEncoder(w) - err = e.Encode(v) + // json.Marshal cannot reuse buffers between calls as it has to return + // a copy of the byte slice but Encoder does as it directly writes to w. + err = json.NewEncoder(w).Encode(v) if err != nil { - return fmt.Errorf("failed to encode json: %w", err) + return fmt.Errorf("failed to marshal JSON: %w", err) } - err = w.Close() - if err != nil { - return err - } - return nil + return w.Close() } diff --git a/wspb/wspb.go b/wspb/wspb.go index 52ddcd57..f4b7c1c5 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -1,40 +1,39 @@ -// Package wspb provides websocket helpers for protobuf messages. +// Package wspb provides helpers for reading and writing protobuf messages. package wspb // import "nhooyr.io/websocket/wspb" import ( "bytes" "context" "fmt" + "nhooyr.io/websocket/internal/errd" "github.com/golang/protobuf/proto" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bufpool" + "nhooyr.io/websocket/internal/bpool" ) -// Read reads a protobuf message from c into v. -// It will reuse buffers to avoid allocations. +// Read reads a Protobuf message from c into v. +// It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - err := read(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to read protobuf: %w", err) - } - return nil + return read(ctx, c, v) } -func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { +func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { + defer errd.Wrap(&err, "failed to read Protobuf message") + typ, r, err := c.Reader(ctx) if err != nil { return err } if typ != websocket.MessageBinary { - c.Close(websocket.StatusUnsupportedData, "can only accept binary messages") - return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) + c.Close(websocket.StatusUnsupportedData, "expected binary message") + return fmt.Errorf("expected binary message for Protobuf but got: %v", typ) } - b := bufpool.Get() - defer bufpool.Put(b) + b := bpool.Get() + defer bpool.Put(b) _, err = b.ReadFrom(r) if err != nil { @@ -43,33 +42,31 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err = proto.Unmarshal(b.Bytes(), v) if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return fmt.Errorf("failed to unmarshal protobuf: %w", err) + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal Protobuf") + return fmt.Errorf("failed to unmarshal Protobuf: %w", err) } return nil } -// Write writes the protobuf message v to c. -// It will reuse buffers to avoid allocations. +// Write writes the Protobuf message v to c. +// It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - err := write(ctx, c, v) - if err != nil { - return fmt.Errorf("failed to write protobuf: %w", err) - } - return nil + return write(ctx, c, v) } -func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - b := bufpool.Get() +func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { + defer errd.Wrap(&err, "failed to write Protobuf message") + + b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) defer func() { - bufpool.Put(bytes.NewBuffer(pb.Bytes())) + bpool.Put(bytes.NewBuffer(pb.Bytes())) }() - err := pb.Marshal(v) + err = pb.Marshal(v) if err != nil { - return fmt.Errorf("failed to marshal protobuf: %w", err) + return fmt.Errorf("failed to marshal Protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) From 6b782a3359d2055dba2a975c828047e6d36cdad4 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 16:00:02 -0500 Subject: [PATCH 11/55] Run make fmt --- autobahn_test.go | 3 +- ci/fmt.mk | 2 +- close.go | 3 +- close_test.go | 6 ++- conn.go | 2 +- conn_test.go | 3 +- frame.go | 2 + frame_test.go | 5 ++- internal/assert/cmp.go | 3 +- internal/errd/wrap.go | 2 +- netconn.go | 1 - read.go | 3 +- websocket_stringer.go | 91 ++++++++++++++++++++++++++++++++++++++++++ write.go | 2 + ws_js.go | 2 +- ws_js_test.go | 10 +++++ wsjson/wsjson.go | 1 + wspb/wspb.go | 2 +- 18 files changed, 127 insertions(+), 16 deletions(-) create mode 100644 websocket_stringer.go diff --git a/autobahn_test.go b/autobahn_test.go index 27f8a1b4..21a30b4f 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -8,13 +8,14 @@ import ( "net" "net/http" "net/http/httptest" - "nhooyr.io/websocket" "os" "os/exec" "strconv" "strings" "testing" "time" + + "nhooyr.io/websocket" ) func TestAutobahn(t *testing.T) { diff --git a/ci/fmt.mk b/ci/fmt.mk index 3637c1ac..f82d74dd 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -22,4 +22,4 @@ prettier: prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md") gen: - stringer -type=Opcode,MessageType,StatusCode -output=websocket_stringer.go + stringer -type=opcode,MessageType,StatusCode -output=websocket_stringer.go diff --git a/close.go b/close.go index 57d69a37..432019c6 100644 --- a/close.go +++ b/close.go @@ -6,8 +6,9 @@ import ( "errors" "fmt" "log" - "nhooyr.io/websocket/internal/bpool" "time" + + "nhooyr.io/websocket/internal/bpool" ) // StatusCode represents a WebSocket status code. diff --git a/close_test.go b/close_test.go index ca51a298..c2d11bb8 100644 --- a/close_test.go +++ b/close_test.go @@ -1,12 +1,14 @@ package websocket import ( - "github.com/google/go-cmp/cmp" "io" "math" - "nhooyr.io/websocket/internal/assert" "strings" "testing" + + "github.com/google/go-cmp/cmp" + + "nhooyr.io/websocket/internal/assert" ) func TestCloseError(t *testing.T) { diff --git a/conn.go b/conn.go index e3f24171..5c041b8d 100644 --- a/conn.go +++ b/conn.go @@ -194,7 +194,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { type mu struct { once sync.Once - ch chan struct{} + ch chan struct{} } func (m *mu) init() { diff --git a/conn_test.go b/conn_test.go index 1014dbf3..6b8a778b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,13 +8,13 @@ import ( "io" "net/http" "net/http/httptest" - "nhooyr.io/websocket/internal/assert" "strings" "sync/atomic" "testing" "time" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/assert" ) func TestConn(t *testing.T) { @@ -51,7 +51,6 @@ func TestConn(t *testing.T) { }) } - func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) { h := http.HandlerFunc(fn) if tls { diff --git a/frame.go b/frame.go index f36334c2..e55c8f2c 100644 --- a/frame.go +++ b/frame.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "math" "math/bits" + + "nhooyr.io/websocket/internal/errd" ) // opcode represents a WebSocket opcode. diff --git a/frame_test.go b/frame_test.go index a4a1f5a8..fa231c57 100644 --- a/frame_test.go +++ b/frame_test.go @@ -7,7 +7,7 @@ import ( "bytes" "encoding/binary" "math/bits" - "nhooyr.io/websocket/internal/assert" + "math/rand" "strconv" "testing" "time" @@ -15,7 +15,8 @@ import ( "github.com/gobwas/ws" _ "github.com/gorilla/websocket" - "math/rand" + + "nhooyr.io/websocket/internal/assert" ) func init() { diff --git a/internal/assert/cmp.go b/internal/assert/cmp.go index 0edcf2cd..39be1f4a 100644 --- a/internal/assert/cmp.go +++ b/internal/assert/cmp.go @@ -1,8 +1,9 @@ package assert import ( - "github.com/google/go-cmp/cmp" "reflect" + + "github.com/google/go-cmp/cmp" ) // https://github.com/google/go-cmp/issues/40#issuecomment-328615283 diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go index 849335c9..6e779131 100644 --- a/internal/errd/wrap.go +++ b/internal/errd/wrap.go @@ -9,6 +9,6 @@ import ( // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { - *err = fmt.Errorf(f+ ": %w", append(v, *err)...) + *err = fmt.Errorf(f+": %w", append(v, *err)...) } } diff --git a/netconn.go b/netconn.go index 74a2c7c1..64aadf0b 100644 --- a/netconn.go +++ b/netconn.go @@ -164,4 +164,3 @@ func (c *netConn) SetReadDeadline(t time.Time) error { } return nil } - diff --git a/read.go b/read.go index 1f5a88ad..13c8d703 100644 --- a/read.go +++ b/read.go @@ -7,10 +7,11 @@ import ( "fmt" "io" "io/ioutil" - "nhooyr.io/websocket/internal/errd" "strings" "sync/atomic" "time" + + "nhooyr.io/websocket/internal/errd" ) // Reader waits until there is a WebSocket data message to read diff --git a/websocket_stringer.go b/websocket_stringer.go new file mode 100644 index 00000000..571e505f --- /dev/null +++ b/websocket_stringer.go @@ -0,0 +1,91 @@ +// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=websocket_stringer.go"; DO NOT EDIT. + +package websocket + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[opContinuation-0] + _ = x[opText-1] + _ = x[opBinary-2] + _ = x[opClose-8] + _ = x[opPing-9] + _ = x[opPong-10] +} + +const ( + _opcode_name_0 = "opContinuationopTextopBinary" + _opcode_name_1 = "opCloseopPingopPong" +) + +var ( + _opcode_index_0 = [...]uint8{0, 14, 20, 28} + _opcode_index_1 = [...]uint8{0, 7, 13, 19} +) + +func (i opcode) String() string { + switch { + case 0 <= i && i <= 2: + return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] + case 8 <= i && i <= 10: + i -= 8 + return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] + default: + return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[MessageText-1] + _ = x[MessageBinary-2] +} + +const _MessageType_name = "MessageTextMessageBinary" + +var _MessageType_index = [...]uint8{0, 11, 24} + +func (i MessageType) String() string { + i -= 1 + if i < 0 || i >= MessageType(len(_MessageType_index)-1) { + return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[StatusNormalClosure-1000] + _ = x[StatusGoingAway-1001] + _ = x[StatusProtocolError-1002] + _ = x[StatusUnsupportedData-1003] + _ = x[statusReserved-1004] + _ = x[StatusNoStatusRcvd-1005] + _ = x[StatusAbnormalClosure-1006] + _ = x[StatusInvalidFramePayloadData-1007] + _ = x[StatusPolicyViolation-1008] + _ = x[StatusMessageTooBig-1009] + _ = x[StatusMandatoryExtension-1010] + _ = x[StatusInternalError-1011] + _ = x[StatusServiceRestart-1012] + _ = x[StatusTryAgainLater-1013] + _ = x[StatusBadGateway-1014] + _ = x[StatusTLSHandshake-1015] +} + +const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" + +var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} + +func (i StatusCode) String() string { + i -= 1000 + if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { + return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" + } + return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] +} diff --git a/write.go b/write.go index e1ea007e..9cafc5c5 100644 --- a/write.go +++ b/write.go @@ -10,6 +10,8 @@ import ( "fmt" "io" "time" + + "nhooyr.io/websocket/internal/errd" ) // Writer returns a writer bounded by the context that will write diff --git a/ws_js.go b/ws_js.go index 882535b1..2e654feb 100644 --- a/ws_js.go +++ b/ws_js.go @@ -9,7 +9,6 @@ import ( "fmt" "io" "net/http" - "nhooyr.io/websocket/internal/wssync" "reflect" "runtime" "sync" @@ -17,6 +16,7 @@ import ( "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" + "nhooyr.io/websocket/internal/wssync" ) // Conn provides a wrapper around the browser WebSocket API. diff --git a/ws_js_test.go b/ws_js_test.go index abd950c7..9330b411 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -1,5 +1,15 @@ package websocket +import ( + "context" + "net/http" + "os" + "testing" + "time" + + "nhooyr.io/websocket" +) + func TestEcho(t *testing.T) { t.Parallel() diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 36dd2dfd..99996a69 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" diff --git a/wspb/wspb.go b/wspb/wspb.go index f4b7c1c5..666c6fa5 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -5,12 +5,12 @@ import ( "bytes" "context" "fmt" - "nhooyr.io/websocket/internal/errd" "github.com/golang/protobuf/proto" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/errd" ) // Read reads a Protobuf message from c into v. From 989ba2f7ae8912e0bf586be4a9e1bd2c6b7b3fcc Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 16:37:18 -0500 Subject: [PATCH 12/55] Change websocket to WebSocket in docs/errors --- README.md | 2 +- accept.go | 126 ++++++++++----------------- accept_test.go | 3 +- ci/fmt.mk | 2 +- ci/test.mk | 2 +- close.go | 4 +- dial.go | 10 +-- example_echo_test.go | 2 +- internal/wsjs/wsjs_js.go | 4 +- websocket_stringer.go => stringer.go | 2 +- ws_js.go | 10 +-- 11 files changed, 68 insertions(+), 99 deletions(-) rename websocket_stringer.go => stringer.go (98%) diff --git a/README.md b/README.md index 8ac418a7..c927e8c1 100644 --- a/README.md +++ b/README.md @@ -170,4 +170,4 @@ See the gorilla/websocket comparison for more performance details. If your company or project is using this library, feel free to open an issue or PR to amend this list. - [Coder](https://github.com/cdr) -- [Tatsu Works](https://github.com/tatsuworks) - Ingresses 20 TB in websocket data every month on their Discord bot. +- [Tatsu Works](https://github.com/tatsuworks) - Ingresses 20 TB in WebSocket data every month on their Discord bot. diff --git a/accept.go b/accept.go index 2028d4b2..dbfb2c30 100644 --- a/accept.go +++ b/accept.go @@ -10,68 +10,56 @@ import ( "net/http" "net/textproto" "net/url" + "nhooyr.io/websocket/internal/errd" "strings" ) // AcceptOptions represents the options available to pass to Accept. type AcceptOptions struct { - // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. + // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to - // reject it, close the connection if c.Subprotocol() == "". + // reject it, close the connection when c.Subprotocol() == "". Subprotocols []string - // InsecureSkipVerify disables Accept's origin verification - // behaviour. By default Accept only allows the handshake to - // succeed if the javascript that is initiating the handshake - // is on the same domain as the server. This is to prevent CSRF - // attacks when secure data is stored in a cookie as there is no same - // origin policy for WebSockets. In other words, javascript from - // any domain can perform a WebSocket dial on an arbitrary server. - // This dial will include cookies which means the arbitrary javascript - // can perform actions as the authenticated user. + // InsecureSkipVerify disables Accept's origin verification behaviour. By default, + // the connection will only be accepted if the request origin is equal to the request + // host. + // + // This is only required if you want javascript served from a different domain + // to access your WebSocket server. // // See https://stackoverflow.com/a/37837709/4283659 // - // The only time you need this is if your javascript is running on a different domain - // than your WebSocket server. - // Think carefully about whether you really need this option before you use it. - // If you do, remember that if you store secure data in cookies, you wil need to verify the - // Origin header yourself otherwise you are exposing yourself to a CSRF attack. + // Please ensure you understand the ramifications of enabling this. + // If used incorrectly your WebSocket server will be open to CSRF attacks. InsecureSkipVerify bool // CompressionMode sets the compression mode. - // See docs on the CompressionMode type and defined constants. + // See docs on the CompressionMode type. CompressionMode CompressionMode } -// Accept accepts a WebSocket HTTP handshake from a client and upgrades the +// Accept accepts a WebSocket handshake from a client and upgrades the // the connection to a WebSocket. // -// Accept will reject the handshake if the Origin domain is not the same as the Host unless -// the InsecureSkipVerify option is set. In other words, by default it does not allow -// cross origin requests. +// Accept will not allow cross origin requests by default. +// See the InsecureSkipVerify option to allow cross origin requests. // -// If an error occurs, Accept will write a response with a safe error message to w. +// Accept will write a response to w on all errors. func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - c, err := accept(w, r, opts) - if err != nil { - return nil, fmt.Errorf("failed to accept websocket connection: %w", err) - } - return c, nil + return accept(w, r, opts) } -func (opts *AcceptOptions) ensure() *AcceptOptions { +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { + defer errd.Wrap(&err, "failed to accept WebSocket connection") + if opts == nil { - return &AcceptOptions{} + opts = &AcceptOptions{} } - return opts -} - -func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - opts = opts.ensure() - err := verifyClientRequest(w, r) + err = verifyClientRequest(r) if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } @@ -85,7 +73,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, hj, ok := w.(http.Hijacker) if !ok { - err = errors.New("passed ResponseWriter does not implement http.Hijacker") + err = errors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } @@ -93,7 +81,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") - handleSecWebSocketKey(w, r) + key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { @@ -102,7 +91,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, copts, err := acceptCompression(r, w, opts.CompressionMode) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } @@ -129,41 +117,29 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, }), nil } -func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { +func verifyClientRequest(r *http.Request) error { if !r.ProtoAtLeast(1, 1) { - err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) - http.Error(w, err.Error(), http.StatusBadRequest) - return err + return fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsToken(r.Header, "Connection", "Upgrade") { - err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err + return fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } - if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { - err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err + if !headerContainsToken(r.Header, "Upgrade", "websocket") { + return fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { - err := fmt.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) - http.Error(w, err.Error(), http.StatusBadRequest) - return err + return fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { - err := fmt.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err + return fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { - err := errors.New("websocket protocol violation: missing Sec-WebSocket-Key") - http.Error(w, err.Error(), http.StatusBadRequest) - return err + return errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } return nil @@ -171,30 +147,20 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { func authenticateOrigin(r *http.Request) error { origin := r.Header.Get("Origin") - if origin == "" { - return nil - } - u, err := url.Parse(origin) - if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) - } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + if origin != "" { + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + if !strings.EqualFold(u.Host, r.Host) { + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + } } return nil } -func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { - key := r.Header.Get("Sec-WebSocket-Key") - w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) -} - func selectSubprotocol(r *http.Request, subprotocols []string) string { cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") - if len(cps) == 0 { - return "" - } - for _, sp := range subprotocols { for _, cp := range cps { if strings.EqualFold(sp, cp) { @@ -236,7 +202,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi continue } - return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err } copts.setHeader(w.Header()) @@ -264,7 +232,9 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com // // Either way, we're only implementing this for webkit which never sends the max_window_bits // parameter so we don't need to worry about it. - return nil, fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err } s := "x-webkit-deflate-frame" diff --git a/accept_test.go b/accept_test.go index 9598cd58..a8ab7d69 100644 --- a/accept_test.go +++ b/accept_test.go @@ -114,7 +114,6 @@ func Test_verifyClientHandshake(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - w := httptest.NewRecorder() r := httptest.NewRequest(tc.method, "/", nil) r.ProtoMajor = 1 @@ -127,7 +126,7 @@ func Test_verifyClientHandshake(t *testing.T) { r.Header.Set(k, v) } - err := verifyClientRequest(w, r) + err := verifyClientRequest(r) if (err == nil) != tc.success { t.Fatalf("unexpected error value: %+v", err) } diff --git a/ci/fmt.mk b/ci/fmt.mk index f82d74dd..f3969721 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -22,4 +22,4 @@ prettier: prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md") gen: - stringer -type=opcode,MessageType,StatusCode -output=websocket_stringer.go + stringer -type=opcode,MessageType,StatusCode -output=stringer.go diff --git a/ci/test.mk b/ci/test.mk index 9e4e0803..f9a6e09a 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -19,5 +19,5 @@ coveralls: gotest gotest: go test -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... - sed -i '/_stringer\.go/d' ci/out/coverage.prof + sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/assert/d' ci/out/coverage.prof diff --git a/close.go b/close.go index 432019c6..6bb48bd5 100644 --- a/close.go +++ b/close.go @@ -97,7 +97,7 @@ func CloseStatus(err error) StatusCode { func (c *Conn) Close(code StatusCode, reason string) error { err := c.closeHandshake(code, reason) if err != nil { - return fmt.Errorf("failed to close websocket: %w", err) + return fmt.Errorf("failed to close WebSocket: %w", err) } return nil } @@ -236,7 +236,7 @@ func (c *Conn) setCloseErr(err error) { func (c *Conn) setCloseErrNoLock(err error) { if c.closeErr == nil { - c.closeErr = fmt.Errorf("websocket closed: %w", err) + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } } diff --git a/dial.go b/dial.go index 8fa0f7ab..3a2165ab 100644 --- a/dial.go +++ b/dial.go @@ -47,7 +47,7 @@ type DialOptions struct { func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { c, r, err := dial(ctx, u, opts) if err != nil { - return nil, r, fmt.Errorf("failed to websocket dial: %w", err) + return nil, r, fmt.Errorf("failed to WebSocket dial: %w", err) } return c, r, nil } @@ -158,22 +158,22 @@ func verifyServerResponse(r *http.Request, resp *http.Response) (*compressionOpt } if !headerContainsToken(resp.Header, "Connection", "Upgrade") { - return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { - return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), r.Header.Get("Sec-WebSocket-Key"), ) } if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + return nil, fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } copts, err := verifyServerExtensions(resp.Header) diff --git a/example_echo_test.go b/example_echo_test.go index 16d003d9..cd195d2e 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -93,7 +93,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { } } -// echo reads from the websocket connection and then writes +// echo reads from the WebSocket connection and then writes // the received message back to it. // The entire function has 10s to complete. func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { diff --git a/internal/wsjs/wsjs_js.go b/internal/wsjs/wsjs_js.go index d48691d4..26ffb456 100644 --- a/internal/wsjs/wsjs_js.go +++ b/internal/wsjs/wsjs_js.go @@ -102,7 +102,7 @@ type MessageEvent struct { // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent } -// OnMessage registers a function to be called when the websocket receives a message. +// OnMessage registers a function to be called when the WebSocket receives a message. func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { return c.addEventListener("message", func(e js.Value) { var data interface{} @@ -128,7 +128,7 @@ func (c WebSocket) Subprotocol() string { return c.v.Get("protocol").String() } -// OnOpen registers a function to be called when the websocket is opened. +// OnOpen registers a function to be called when the WebSocket is opened. func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { return c.addEventListener("open", fn) } diff --git a/websocket_stringer.go b/stringer.go similarity index 98% rename from websocket_stringer.go rename to stringer.go index 571e505f..5a66ba29 100644 --- a/websocket_stringer.go +++ b/stringer.go @@ -1,4 +1,4 @@ -// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=websocket_stringer.go"; DO NOT EDIT. +// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. package websocket diff --git a/ws_js.go b/ws_js.go index 2e654feb..7f10ee17 100644 --- a/ws_js.go +++ b/ws_js.go @@ -105,7 +105,7 @@ func (c *Conn) closeWithInternal() { // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, errors.New("websocket connection read closed") + return 0, nil, errors.New("WebSocket connection read closed") } typ, p, err := c.read(ctx) @@ -188,14 +188,14 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { } } -// Close closes the websocket with the given code and reason. +// Close closes the WebSocket with the given code and reason. // It will wait until the peer responds with a close frame // or the connection is closed. // It thus performs the full WebSocket close handshake. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return fmt.Errorf("failed to close websocket: %w", err) + return fmt.Errorf("failed to close WebSocket: %w", err) } return nil } @@ -245,7 +245,7 @@ type DialOptions struct { func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { c, resp, err := dial(ctx, url, opts) if err != nil { - return nil, resp, fmt.Errorf("failed to websocket dial %q: %w", url, err) + return nil, resp, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) } return c, resp, nil } @@ -359,7 +359,7 @@ func (c *Conn) SetReadLimit(n int64) { func (c *Conn) setCloseErr(err error) { c.closeErrOnce.Do(func() { - c.closeErr = fmt.Errorf("websocket closed: %w", err) + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) }) } From 9f159635813f2022ea18007636be4f89e6043afa Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 17:10:42 -0500 Subject: [PATCH 13/55] Simplify dial.go --- accept.go | 7 +- accept_test.go | 27 ++++--- autobahn_test.go | 21 ++---- dial.go | 182 +++++++++++++++++++++++++---------------------- dial_test.go | 6 +- 5 files changed, 124 insertions(+), 119 deletions(-) diff --git a/accept.go b/accept.go index dbfb2c30..964e0401 100644 --- a/accept.go +++ b/accept.go @@ -10,11 +10,12 @@ import ( "net/http" "net/textproto" "net/url" - "nhooyr.io/websocket/internal/errd" "strings" + + "nhooyr.io/websocket/internal/errd" ) -// AcceptOptions represents the options available to pass to Accept. +// AcceptOptions represents Accept's options. type AcceptOptions struct { // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to @@ -35,7 +36,7 @@ type AcceptOptions struct { InsecureSkipVerify bool // CompressionMode sets the compression mode. - // See docs on the CompressionMode type. + // See the docs on CompressionMode. CompressionMode CompressionMode } diff --git a/accept_test.go b/accept_test.go index a8ab7d69..d68d4d6d 100644 --- a/accept_test.go +++ b/accept_test.go @@ -4,6 +4,8 @@ import ( "net/http/httptest" "strings" "testing" + + "nhooyr.io/websocket/internal/assert" ) func TestAccept(t *testing.T) { @@ -16,10 +18,7 @@ func TestAccept(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) - if err == nil { - t.Fatalf("unexpected error value: %v", err) - } - + assert.ErrorContains(t, err, "protocol violation") }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -33,9 +32,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { - t.Fatalf("unexpected error value: %v", err) - } + assert.ErrorContains(t, err, "http.ResponseWriter does not implement http.Hijacker") }) } @@ -127,8 +124,10 @@ func Test_verifyClientHandshake(t *testing.T) { } err := verifyClientRequest(r) - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } @@ -178,9 +177,7 @@ func Test_selectSubprotocol(t *testing.T) { r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) - if tc.negotiated != negotiated { - t.Fatalf("expected %q but got %q", tc.negotiated, negotiated) - } + assert.Equal(t, tc.negotiated, negotiated, "negotiated") }) } } @@ -234,8 +231,10 @@ func Test_authenticateOrigin(t *testing.T) { r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r) - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } diff --git a/autobahn_test.go b/autobahn_test.go index 21a30b4f..30c96a7c 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -18,21 +18,19 @@ import ( "nhooyr.io/websocket" ) +// https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahn(t *testing.T) { - // This test contains the old autobahn test suite tests that use the - // python binary. The approach is clunky and slow so new tests - // have been written in pure Go in websocket_test.go. - // These have been kept for correctness purposes and are occasionally ran. + t.Parallel() + if os.Getenv("AUTOBAHN") == "" { t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite") } - t.Run("server", testServerAutobahnPython) - t.Run("client", testClientAutobahnPython) + t.Run("server", testServerAutobahn) + t.Run("client", testClientAutobahn) } -// https://github.com/crossbario/autobahn-python/tree/master/wstest -func testServerAutobahnPython(t *testing.T) { +func testServerAutobahn(t *testing.T) { t.Parallel() s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -101,14 +99,9 @@ func unusedListenAddr() (string, error) { return l.Addr().String(), nil } -// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py -func testClientAutobahnPython(t *testing.T) { +func testClientAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") - } - serverAddr, err := unusedListenAddr() if err != nil { t.Fatalf("failed to get unused listen addr for wstest: %v", err) diff --git a/dial.go b/dial.go index 3a2165ab..a1a10556 100644 --- a/dial.go +++ b/dial.go @@ -14,51 +14,50 @@ import ( "net/url" "strings" "sync" + + "nhooyr.io/websocket/internal/errd" ) -// DialOptions represents the options available to pass to Dial. +// DialOptions represents Dial's options. type DialOptions struct { - // HTTPClient is the http client used for the handshake. - // Its Transport must return writable bodies - // for WebSocket handshakes. - // http.Transport does this correctly beginning with Go 1.12. + // HTTPClient is used for the connection. + // Its Transport must return writable bodies for WebSocket handshakes. + // http.Transport does beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. HTTPHeader http.Header - // Subprotocols lists the subprotocols to negotiate with the server. + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string - // See docs on CompressionMode. + // CompressionMode sets the compression mode. + // See the docs on CompressionMode. CompressionMode CompressionMode } -// Dial performs a WebSocket handshake on the given url with the given options. +// Dial performs a WebSocket handshake on url. +// // The response is the WebSocket handshake response from the server. -// If an error occurs, the returned response may be non nil. However, you can only -// read the first 1024 bytes of its body. +// You never need to close resp.Body yourself. // -// You never need to close the resp.Body yourself. +// If an error occurs, the returned response may be non nil. +// However, you can only read the first 1024 bytes of the body. // -// This function requires at least Go 1.12 to succeed as it uses a new feature -// in net/http to perform WebSocket handshakes and get a writable body -// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 +// This function requires at least Go 1.12 as it uses a new feature +// in net/http to perform WebSocket handshakes. +// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - c, r, err := dial(ctx, u, opts) - if err != nil { - return nil, r, fmt.Errorf("failed to WebSocket dial: %w", err) - } - return c, r, nil + return dial(ctx, u, opts) } -func (opts *DialOptions) ensure() *DialOptions { +func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + defer errd.Wrap(&err, "failed to WebSocket dial") + if opts == nil { opts = &DialOptions{} - } else { - opts = &*opts } - + opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } @@ -66,71 +65,35 @@ func (opts *DialOptions) ensure() *DialOptions { opts.HTTPHeader = http.Header{} } - return opts -} - -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - opts = opts.ensure() - - if opts.HTTPClient.Timeout > 0 { - return nil, nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - - parsedURL, err := url.Parse(u) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse url: %w", err) - } - - switch parsedURL.Scheme { - case "ws": - parsedURL.Scheme = "http" - case "wss": - parsedURL.Scheme = "https" - default: - return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) - } - - req, _ := http.NewRequest("GET", parsedURL.String(), nil) - req = req.WithContext(ctx) - req.Header = opts.HTTPHeader - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") secWebSocketKey, err := secWebSocketKey() if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } - req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) - if len(opts.Subprotocols) > 0 { - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) - } - if opts.CompressionMode != CompressionDisabled { - copts := opts.CompressionMode.opts() - copts.setHeader(req.Header) - } - resp, err := opts.HTTPClient.Do(req) + resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey) if err != nil { - return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) + return nil, resp, err } + respBody := resp.Body + resp.Body = nil defer func() { if err != nil { // We read a bit of the body for easier debugging. - r := io.LimitReader(resp.Body, 1024) + r := io.LimitReader(respBody, 1024) b, _ := ioutil.ReadAll(r) - resp.Body.Close() + respBody.Close() resp.Body = ioutil.NopCloser(bytes.NewReader(b)) } }() - copts, err := verifyServerResponse(req, resp) + copts, err := verifyServerResponse(opts, secWebSocketKey, resp) if err != nil { return nil, resp, err } - rwc, ok := resp.Body.(io.ReadWriteCloser) + rwc, ok := respBody.(io.ReadWriteCloser) if !ok { - return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ @@ -143,6 +106,46 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re }), resp, nil } +func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) { + if opts.HTTPClient.Timeout > 0 { + return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + } + + u, err := url.Parse(urls) + if err != nil { + return nil, fmt.Errorf("failed to parse url: %w", err) + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) + } + + req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req.Header = opts.HTTPHeader.Clone() + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + if opts.CompressionMode != CompressionDisabled { + copts := opts.CompressionMode.opts() + copts.setHeader(req.Header) + } + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send handshake request: %w", err) + } + return resp, nil +} + func secWebSocketKey() (string, error) { b := make([]byte, 16) _, err := io.ReadFull(rand.Reader, b) @@ -152,7 +155,7 @@ func secWebSocketKey() (string, error) { return base64.StdEncoding.EncodeToString(b), nil } -func verifyServerResponse(r *http.Request, resp *http.Response) (*compressionOptions, error) { +func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } @@ -165,23 +168,34 @@ func verifyServerResponse(r *http.Request, resp *http.Response) (*compressionOpt return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } - if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), - r.Header.Get("Sec-WebSocket-Key"), + secWebSocketKey, ) } - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return nil, fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) - } - - copts, err := verifyServerExtensions(resp.Header) + err := verifySubprotocol(opts.Subprotocols, resp) if err != nil { return nil, err } - return copts, nil + return verifyServerExtensions(resp.Header) +} + +func verifySubprotocol(subprotos []string, resp *http.Response) error { + proto := resp.Header.Get("Sec-WebSocket-Protocol") + if proto == "" { + return nil + } + + for _, sp2 := range subprotos { + if strings.EqualFold(sp2, proto) { + return nil + } + } + + return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(h http.Header) (*compressionOptions, error) { @@ -191,12 +205,8 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { } ext := exts[0] - if ext.name != "permessage-deflate" { - return nil, fmt.Errorf("unexpected extension from server: %q", ext) - } - - if len(exts) > 1 { - return nil, fmt.Errorf("unexpected extra extensions from server: %+v", exts[1:]) + if ext.name != "permessage-deflate" || len(exts) > 1 { + return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } copts := &compressionOptions{} @@ -204,13 +214,11 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { switch p { case "client_no_context_takeover": copts.clientNoContextTakeover = true - continue case "server_no_context_takeover": copts.serverNoContextTakeover = true - continue + default: + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } - - return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) } return copts, nil diff --git a/dial_test.go b/dial_test.go index 5eeb904a..6286f0ff 100644 --- a/dial_test.go +++ b/dial_test.go @@ -6,6 +6,7 @@ import ( "context" "net/http" "net/http/httptest" + "strings" "testing" "time" ) @@ -140,7 +141,10 @@ func Test_verifyServerHandshake(t *testing.T) { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - _, err = verifyServerResponse(r, resp) + opts := &DialOptions{ + Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), + } + _, err = verifyServerResponse(opts, key, resp) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) } From 120911b598eab98c2ad624baaa0f81b473e7baad Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 18:13:37 -0500 Subject: [PATCH 14/55] Remove use of math/rand.Init --- assert_test.go | 38 +++++++++++++++----------------------- close.go | 43 ++++++++++++++++++++++--------------------- close_test.go | 25 +++++++++++-------------- doc.go | 43 +++++++++---------------------------------- frame.go | 2 +- frame_test.go | 23 +++++++++-------------- go.mod | 2 +- ws_js.go | 14 +++++--------- ws_js_test.go | 2 +- 9 files changed, 74 insertions(+), 118 deletions(-) diff --git a/assert_test.go b/assert_test.go index 6e4e75e6..b6e50a47 100644 --- a/assert_test.go +++ b/assert_test.go @@ -2,38 +2,31 @@ package websocket_test import ( "context" - "math/rand" + "crypto/rand" + "io" "strings" "testing" - "time" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/wsjson" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBytes(n int) []byte { +func randBytes(t *testing.T, n int) []byte { b := make([]byte, n) - rand.Read(b) + _, err := io.ReadFull(rand.Reader, b) + assert.Success(t, err) return b } func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { t.Helper() - exp := randString(n) + exp := randString(t, n) err := wsjson.Write(ctx, c, exp) assert.Success(t, err) - var act interface{} - err = wsjson.Read(ctx, c, &act) - assert.Success(t, err) - - assert.Equal(t, exp, act, "unexpected JSON") + assertJSONRead(t, ctx, c, exp) } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { @@ -43,11 +36,11 @@ func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp in err := wsjson.Read(ctx, c, &act) assert.Success(t, err) - assert.Equal(t, exp, act, "unexpected JSON") + assert.Equal(t, exp, act, "JSON") } -func randString(n int) string { - s := strings.ToValidUTF8(string(randBytes(n)), "_") +func randString(t *testing.T, n int) string { + s := strings.ToValidUTF8(string(randBytes(t, n)), "_") if len(s) > n { return s[:n] } @@ -62,25 +55,24 @@ func randString(n int) string { func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { t.Helper() - p := randBytes(n) + p := randBytes(t, n) err := c.Write(ctx, typ, p) assert.Success(t, err) typ2, p2, err := c.Read(ctx) assert.Success(t, err) - assert.Equal(t, typ, typ2, "unexpected data type") - assert.Equal(t, p, p2, "unexpected payload") + assert.Equal(t, typ, typ2, "data type") + assert.Equal(t, p, p2, "payload") } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { t.Helper() - assert.Equal(t, exp, c.Subprotocol(), "unexpected subprotocol") + assert.Equal(t, exp, c.Subprotocol(), "subprotocol") } func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { t.Helper() - - assert.Equal(t, exp, websocket.CloseStatus(err), "unexpected status code") + assert.Equal(t, exp, websocket.CloseStatus(err), "StatusCode") } diff --git a/close.go b/close.go index 6bb48bd5..baa1a7e0 100644 --- a/close.go +++ b/close.go @@ -15,11 +15,13 @@ import ( // https://tools.ietf.org/html/rfc6455#section-7.4 type StatusCode int -// These codes were retrieved from: // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // -// The defined constants only represent the status codes registered with IANA. -// The 4000-4999 range of status codes is reserved for arbitrary use by applications. +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. const ( StatusNormalClosure StatusCode = 1000 StatusGoingAway StatusCode = 1001 @@ -31,11 +33,12 @@ const ( // StatusNoStatusRcvd cannot be sent in a close message. // It is reserved for when a close message is received without - // an explicit status. + // a status code. StatusNoStatusRcvd StatusCode = 1005 - // StatusAbnormalClosure is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether the connection was closed or not or what happened. + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. StatusAbnormalClosure StatusCode = 1006 StatusInvalidFramePayloadData StatusCode = 1007 @@ -48,15 +51,15 @@ const ( StatusBadGateway StatusCode = 1014 // StatusTLSHandshake is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether there was a TLS handshake failure. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. StatusTLSHandshake StatusCode = 1015 ) -// CloseError represents a WebSocket close frame. -// It is returned by Conn's methods when a WebSocket close frame is received from -// the peer. -// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, -// to check for this error. See the CloseError example. +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's errors.As to check for this error. +// Also see the CloseStatus helper. type CloseError struct { Code StatusCode Reason string @@ -66,9 +69,10 @@ func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } -// CloseStatus is a convenience wrapper around errors.As to grab -// the status code from a *CloseError. If the passed error is nil -// or not a *CloseError, the returned StatusCode will be -1. +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError if errors.As(err, &ce) { @@ -77,19 +81,16 @@ func CloseStatus(err error) StatusCode { return -1 } -// Close closes the WebSocket connection with the given status code and reason. +// Close performs the WebSocket close handshake with the given status code and reason. // // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for // the peer to send a close frame. -// Thus, it implements the full WebSocket close handshake. -// All data messages received from the peer during the close handshake -// will be discarded. +// All data messages received from the peer during the close handshake will be discarded. // // The connection can only be closed once. Additional calls to Close // are no-ops. // -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid +// The maximum length of reason must be 125 bytes. Avoid // sending a dynamic reason. // // Close will unblock all goroutines interacting with the connection once diff --git a/close_test.go b/close_test.go index c2d11bb8..9551699a 100644 --- a/close_test.go +++ b/close_test.go @@ -6,8 +6,6 @@ import ( "strings" "testing" - "github.com/google/go-cmp/cmp" - "nhooyr.io/websocket/internal/assert" ) @@ -51,8 +49,10 @@ func TestCloseError(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) + if (tc.success) { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } @@ -101,12 +101,11 @@ func Test_parseClosePayload(t *testing.T) { t.Parallel() ce, err := parseClosePayload(tc.p) - if (err == nil) != tc.success { - t.Fatalf("unexpected expected error value: %+v", err) - } - - if tc.success && tc.ce != ce { - t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + if (tc.success) { + assert.Success(t, err) + assert.Equal(t, tc.ce, ce, "CloseError") + } else { + assert.Error(t, err) } }) } @@ -152,9 +151,7 @@ func Test_validWireCloseCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - if valid := validWireCloseCode(tc.code); tc.valid != valid { - t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) - } + assert.Equal(t, tc.code, validWireCloseCode(tc.code), "validWireCloseCode") }) } } @@ -191,7 +188,7 @@ func TestCloseStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tc.exp, CloseStatus(tc.in), "unexpected close status") + assert.Equal(t, tc.exp, CloseStatus(tc.in), "CloseStatus") }) } } diff --git a/doc.go b/doc.go index 5285a780..54b7e1ea 100644 --- a/doc.go +++ b/doc.go @@ -4,51 +4,26 @@ // // https://tools.ietf.org/html/rfc6455 // -// Conn, Dial, and Accept are the main entrypoints into this package. Use Dial to dial -// a WebSocket server, Accept to accept a WebSocket client dial and then Conn to interact -// with the resulting WebSocket connections. +// Use Dial to dial a WebSocket server and Accept to accept a WebSocket client. +// Conn represents the resulting WebSocket connection. // // The examples are the best way to understand how to correctly use the library. // -// The wsjson and wspb subpackages contain helpers for JSON and ProtoBuf messages. +// The wsjson and wspb subpackages contain helpers for JSON and Protobuf messages. // -// See https://nhooyr.io/websocket for more overview docs and a -// comparison with existing implementations. -// -// Use the errors.As function new in Go 1.13 to check for websocket.CloseError. -// Or use the CloseStatus function to grab the StatusCode out of a websocket.CloseError -// See the CloseStatus example. +// See https://nhooyr.io/websocket for further information. // // Wasm // -// The client side fully supports compiling to Wasm. +// The client side supports compiling to Wasm. // It wraps the WebSocket browser API. // // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // -// Thus the unsupported features (not compiled in) for Wasm are: -// -// - Accept and AcceptOptions -// - Conn.Ping -// - HTTPClient and HTTPHeader fields in DialOptions -// - CompressionOptions -// -// The *http.Response returned by Dial will always either be nil or &http.Response{} as -// we do not have access to the handshake response in the browser. -// -// The Writer method on the Conn buffers everything in memory and then sends it as a message -// when the writer is closed. -// -// The Reader method also reads the entire response and then returns a reader that -// reads from the byte slice. -// -// SetReadLimit cannot actually limit the number of bytes read from the connection so instead -// when a message beyond the limit is fully read, it throws an error. -// -// Writes are also always async so the passed context is no-op. +// Some important caveats to be aware of: // -// Everything else is fully supported. This includes the wsjson and wspb helper packages. +// - Conn.Ping is no-op +// - HTTPClient, HTTPHeader and CompressionOptions in DialOptions are no-op +// - *http.Response from Dial is &http.Response{} on success // -// Once https://github.com/gopherjs/gopherjs/issues/929 is closed, GopherJS should be supported -// as well. package websocket // import "nhooyr.io/websocket" diff --git a/frame.go b/frame.go index e55c8f2c..0257835e 100644 --- a/frame.go +++ b/frame.go @@ -12,7 +12,7 @@ import ( // opcode represents a WebSocket opcode. type opcode int -// List at https://tools.ietf.org/html/rfc6455#section-11.8. +// https://tools.ietf.org/html/rfc6455#section-11.8. const ( opContinuation opcode = iota opText diff --git a/frame_test.go b/frame_test.go index fa231c57..68455cfa 100644 --- a/frame_test.go +++ b/frame_test.go @@ -19,10 +19,6 @@ import ( "nhooyr.io/websocket/internal/assert" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - func TestHeader(t *testing.T) { t.Parallel() @@ -56,8 +52,9 @@ func TestHeader(t *testing.T) { t.Run("fuzz", func(t *testing.T) { t.Parallel() + r := rand.New(rand.NewSource(time.Now().UnixNano())) randBool := func() bool { - return rand.Intn(1) == 0 + return r.Intn(1) == 0 } for i := 0; i < 10000; i++ { @@ -66,11 +63,11 @@ func TestHeader(t *testing.T) { rsv1: randBool(), rsv2: randBool(), rsv3: randBool(), - opcode: opcode(rand.Intn(16)), + opcode: opcode(r.Intn(16)), masked: randBool(), - maskKey: rand.Uint32(), - payloadLength: rand.Int63(), + maskKey: r.Uint32(), + payloadLength: r.Int63(), } testHeader(t, h) @@ -91,7 +88,7 @@ func testHeader(t *testing.T, h header) { h2, err := readFrameHeader(r) assert.Success(t, err) - assert.Equal(t, h, h2, "written and read headers differ") + assert.Equal(t, h, h2, "header") } func Test_mask(t *testing.T) { @@ -102,8 +99,8 @@ func Test_mask(t *testing.T) { p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(key32, p) - assert.Equal(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "unexpected mask") - assert.Equal(t, bits.RotateLeft32(key32, -8), gotKey32, "unexpected mask key") + assert.Equal(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "mask") + assert.Equal(t, bits.RotateLeft32(key32, -8), gotKey32, "mask key") } func basicMask(maskKey [4]byte, pos int, b []byte) int { @@ -173,9 +170,7 @@ func Benchmark_mask(b *testing.B) { }, } - var key [4]byte - _, err := rand.Read(key[:]) - assert.Success(b, err) + key := [4]byte{1, 2, 3, 4} for _, size := range sizes { p := make([]byte, size) diff --git a/go.mod b/go.mod index 3108c020..1a2b08f4 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 - github.com/golang/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.3.2 github.com/google/go-cmp v0.3.1 github.com/gorilla/websocket v1.4.1 github.com/kr/pretty v0.1.0 // indirect diff --git a/ws_js.go b/ws_js.go index 7f10ee17..3043106b 100644 --- a/ws_js.go +++ b/ws_js.go @@ -1,5 +1,3 @@ -// +build js - package websocket // import "nhooyr.io/websocket" import ( @@ -8,7 +6,6 @@ import ( "errors" "fmt" "io" - "net/http" "reflect" "runtime" "sync" @@ -242,15 +239,15 @@ type DialOptions struct { // The passed context bounds the maximum time spent waiting for the connection to open. // The returned *http.Response is always nil or the zero value. It's only in the signature // to match the core API. -func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { - c, resp, err := dial(ctx, url, opts) +func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { + c, err := dial(ctx, url, opts) if err != nil { return nil, resp, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) } - return c, resp, nil + return c, nil } -func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { +func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { if opts == nil { opts = &DialOptions{} } @@ -280,8 +277,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp return c, nil, c.closeErr } - // Have to return a non nil response as the normal API does that. - return c, &http.Response{}, nil + return c, nil } // Reader attempts to read a message from the connection. diff --git a/ws_js_test.go b/ws_js_test.go index 9330b411..ea888b59 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -23,7 +23,7 @@ func TestEcho(t *testing.T) { defer c.Close(websocket.StatusInternalError, "") assertSubprotocol(t, c, "echo") - assert.Equalf(t, &http.Response{}, resp, "unexpected http response") + assert.Equalf(t, &http.Response{}, resp, "http.Response") assertJSONEcho(t, ctx, c, 1024) assertEcho(t, ctx, c, websocket.MessageBinary, 1024) From 7ad15141157fc06bdeb2505085811563f182688d Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 19:13:42 -0500 Subject: [PATCH 15/55] Update README.md comparison --- README.md | 136 +++++++++++++++++++----------------------------------- close.go | 17 ++++--- conn.go | 26 +++++------ read.go | 4 +- 4 files changed, 70 insertions(+), 113 deletions(-) diff --git a/README.md b/README.md index c927e8c1..477a59ff 100644 --- a/README.md +++ b/README.md @@ -16,17 +16,17 @@ go get nhooyr.io/websocket ## Features - Minimal and idiomatic API -- Tiny codebase at 2200 lines - First class [context.Context](https://blog.golang.org/context) support - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - [Zero dependencies](https://godoc.org/nhooyr.io/websocket?imports) - JSON and ProtoBuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages -- Highly optimized by default - - Zero alloc reads and writes -- Concurrent writes out of the box -- [Complete Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) support -- [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) -- Full support of [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression extension +- Zero alloc reads and writes +- Concurrent writes +- WebSocket [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) +- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper +- WebSocket [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping) +- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression +- [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) ## Roadmap @@ -34,11 +34,7 @@ go get nhooyr.io/websocket ## Examples -For a production quality example that shows off the full API, see the [echo example on the godoc](https://godoc.org/nhooyr.io/websocket#example-package--Echo). On github, the example is at [example_echo_test.go](./example_echo_test.go). - -Use the [errors.As](https://golang.org/pkg/errors/#As) function [new in Go 1.13](https://golang.org/doc/go1.13#error_wrapping) to check for [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). -There is also [websocket.CloseStatus](https://godoc.org/nhooyr.io/websocket#CloseStatus) to quickly grab the close status code out of a [websocket.CloseError](https://godoc.org/nhooyr.io/websocket#CloseError). -See the [CloseStatus godoc example](https://godoc.org/nhooyr.io/websocket#example-CloseStatus). +For a production quality example that demonstrates the full API, see the [echo example](https://godoc.org/nhooyr.io/websocket#example-package--Echo). ### Server @@ -87,83 +83,45 @@ c.Close(websocket.StatusNormalClosure, "") ## Comparison -Before the comparison, I want to point out that gorilla/websocket was extremely useful in implementing the -WebSocket protocol correctly so _big thanks_ to its authors. In particular, I made sure to go through the -issue tracker of gorilla/websocket to ensure I implemented details correctly and understood how people were -using WebSockets in production. - -### gorilla/websocket - -https://github.com/gorilla/websocket - -The implementation of gorilla/websocket is 6 years old. As such, it is -widely used and very mature compared to nhooyr.io/websocket. - -On the other hand, it has grown organically and now there are too many ways to do -the same thing. Compare the godoc of -[nhooyr/websocket](https://godoc.org/nhooyr.io/websocket) with -[gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. - -The API for nhooyr.io/websocket has been designed such that there is only one way to do things. -This makes it easy to use correctly. Not only is the API simpler, the implementation is -only 2200 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain, -more code to test, more code to document and more surface area for bugs. - -Moreover, nhooyr.io/websocket supports newer Go idioms such as context.Context. -It also uses net/http's Client and ResponseWriter directly for WebSocket handshakes. -gorilla/websocket writes its handshakes to the underlying net.Conn. -Thus it has to reinvent hooks for TLS and proxies and prevents easy support of HTTP/2. - -Some more advantages of nhooyr.io/websocket are that it supports concurrent writes and -makes it very easy to close the connection with a status code and reason. In fact, -nhooyr.io/websocket even implements the complete WebSocket close handshake for you whereas -with gorilla/websocket you have to perform it manually. See [gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448). - -The ping API is also nicer. gorilla/websocket requires registering a pong handler on the Conn -which results in awkward control flow. With nhooyr.io/websocket you use the Ping method on the Conn -that sends a ping and also waits for the pong. - -Additionally, nhooyr.io/websocket can compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) for the browser. - -In terms of performance, the differences mostly depend on your application code. nhooyr.io/websocket -reuses message buffers out of the box if you use the wsjson and wspb subpackages. -As mentioned above, nhooyr.io/websocket also supports concurrent writers. - -The WebSocket masking algorithm used by this package is [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) -faster than gorilla/websocket while using only pure safe Go. - -The [permessage-deflate compression extension](https://tools.ietf.org/html/rfc7692) is fully supported by this library -whereas gorilla only supports no context takeover mode. See our godoc for the differences. This will make a big -difference on bandwidth used in most use cases. - -The only performance con to nhooyr.io/websocket is that it uses a goroutine to support -cancellation with context.Context. This costs 2 KB of memory which is cheap compared to -the benefits. - -### x/net/websocket - -https://godoc.org/golang.org/x/net/websocket - -Unmaintained and the API does not reflect WebSocket semantics. Should never be used. - -See https://github.com/golang/go/issues/18152 - -### gobwas/ws - -https://github.com/gobwas/ws - -This library has an extremely flexible API but that comes at the cost of usability -and clarity. - -Due to its flexibility, it can be used in a event driven style for performance. -Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb) about performant WebSocket servers. - -If you want a library that gives you absolute control over everything, this is the library. -But for 99.9% of use cases, nhooyr.io/websocket will fit better as it is both easier and -faster for normal idiomatic Go. The masking implementation is [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) -faster, the compression extensions are fully supported and as much as possible is reused by default. - -See the gorilla/websocket comparison for more performance details. +### [gorilla/websocket](https://github.com/gorilla/websocket) + +Advantages of nhooyr.io/websocket: + - Minimal and idiomatic API + - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. + - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper + - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) + - Full [context.Context](https://blog.golang.org/context) support + - Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing + - Will enable easy HTTP/2 support in the future + - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client. + - Concurrent writes + - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) + - Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API + - gorilla/websocket requires registering a pong callback and then sending a Ping + - Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) + - Transparent buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages + - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go + - Gorilla's implementation depends on unsafe and is slower + - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support + - Gorilla only supports no context takeover mode + - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper + - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) + +Advantages of gorilla/websocket: + - Widely used and mature + +### [x/net/websocket](https://godoc.org/golang.org/x/net/websocket) + +Deprecated. See ([golang/go/issues/18152](https://github.com/golang/go/issues/18152)). + +The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning to nhooyr.io/websocket. + +### [gobwas/ws](https://github.com/gobwas/ws) + +This library has an extremely flexible API that allows it to be used in an unidiomatic event driven style +for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). + +When writing idiomatic Go, nhooyr.io/websocket is a better choice as it will be faster and easier to use. ## Users diff --git a/close.go b/close.go index baa1a7e0..a02dc7d9 100644 --- a/close.go +++ b/close.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log" + "nhooyr.io/websocket/internal/errd" "time" "nhooyr.io/websocket/internal/bpool" @@ -96,15 +97,13 @@ func CloseStatus(err error) StatusCode { // Close will unblock all goroutines interacting with the connection once // complete. func (c *Conn) Close(code StatusCode, reason string) error { - err := c.closeHandshake(code, reason) - if err != nil { - return fmt.Errorf("failed to close WebSocket: %w", err) - } - return nil + return c.closeHandshake(code, reason) } -func (c *Conn) closeHandshake(code StatusCode, reason string) error { - err := c.cw.sendClose(code, reason) +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + err = c.cw.sendClose(code, reason) if err != nil { return err } @@ -115,7 +114,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error { func (cw *connWriter) error(code StatusCode, err error) { cw.c.setCloseErr(err) cw.sendClose(code, err.Error()) - cw.c.close(nil) + cw.c.closeWithErr(nil) } func (cw *connWriter) sendClose(code StatusCode, reason string) error { @@ -135,7 +134,7 @@ func (cw *connWriter) sendClose(code StatusCode, reason string) error { } func (cr *connReader) waitClose() error { - defer cr.c.close(nil) + defer cr.c.closeWithErr(nil) return nil diff --git a/conn.go b/conn.go index 5c041b8d..d9001791 100644 --- a/conn.go +++ b/conn.go @@ -33,11 +33,10 @@ const ( // frames will not be handled. See the docs on Reader and CloseRead. // // Be sure to call Close on the connection when you -// are finished with it to release the associated resources. +// are finished with it to release associated resources. // -// Every error from Read or Reader will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. +// On any error from any method, the connection is closed +// with an appropriate reason. type Conn struct { subprotocol string rwc io.ReadWriteCloser @@ -69,11 +68,12 @@ type connConfig struct { } func newConn(cfg connConfig) *Conn { - c := &Conn{} - c.subprotocol = cfg.subprotocol - c.rwc = cfg.rwc - c.client = cfg.client - c.copts = cfg.copts + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + } c.cr.init(c, cfg.br) c.cw.init(c, cfg.bw) @@ -82,7 +82,7 @@ func newConn(cfg connConfig) *Conn { c.activePings = make(map[string]chan<- struct{}) runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.closeWithErr(errors.New("connection garbage collected")) }) go c.timeoutLoop() @@ -96,7 +96,7 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) close(err error) { +func (c *Conn) closeWithErr(err error) { c.closeMu.Lock() defer c.closeMu.Unlock() @@ -135,7 +135,7 @@ func (c *Conn) timeoutLoop() { c.cw.error(StatusPolicyViolation, errors.New("timed out")) return case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.closeWithErr(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } @@ -185,7 +185,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { return c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) + c.closeWithErr(err) return err case <-pong: return nil diff --git a/read.go b/read.go index 13c8d703..7dba832a 100644 --- a/read.go +++ b/read.go @@ -199,7 +199,7 @@ func (cr *connReader) frameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - cr.c.close(err) + cr.c.closeWithErr(err) return header{}, err } } @@ -229,7 +229,7 @@ func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) { return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) - cr.c.close(err) + cr.c.closeWithErr(err) return n, err } } From 746140b8b5604d895bea36e23cf511e749dfd66c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 28 Nov 2019 19:23:14 -0500 Subject: [PATCH 16/55] Further improve README --- README.md | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 477a59ff..efb4a592 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,11 @@ go get nhooyr.io/websocket - JSON and ProtoBuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes -- WebSocket [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) +- [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper -- WebSocket [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping) +- [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping) - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) +- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) ## Roadmap @@ -83,7 +83,9 @@ c.Close(websocket.StatusNormalClosure, "") ## Comparison -### [gorilla/websocket](https://github.com/gorilla/websocket) +### gorilla/websocket + +[gorilla/websocket](https://github.com/gorilla/websocket) is a widely used and mature library. Advantages of nhooyr.io/websocket: - Minimal and idiomatic API @@ -103,25 +105,24 @@ Advantages of nhooyr.io/websocket: - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - Gorilla's implementation depends on unsafe and is slower - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - - Gorilla only supports no context takeover mode + - Gorilla only supports no context takeover mode - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) -Advantages of gorilla/websocket: - - Widely used and mature - -### [x/net/websocket](https://godoc.org/golang.org/x/net/websocket) +#### golang.org/x/net/websocket -Deprecated. See ([golang/go/issues/18152](https://github.com/golang/go/issues/18152)). +[golang.org/x/net/websocket](https://godoc.org/golang.org/x/net/websocket) is deprecated. +See ([golang/go/issues/18152](https://github.com/golang/go/issues/18152)). -The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning to nhooyr.io/websocket. +The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning +to nhooyr.io/websocket. -### [gobwas/ws](https://github.com/gobwas/ws) +#### gobwas/ws -This library has an extremely flexible API that allows it to be used in an unidiomatic event driven style -for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). +[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used +in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). -When writing idiomatic Go, nhooyr.io/websocket is a better choice as it will be faster and easier to use. +However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. ## Users From 43cb01eaf9fad1e2052a18b69b777db62820aae7 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 29 Nov 2019 00:00:52 -0500 Subject: [PATCH 17/55] Refactor read.go/write.go --- README.md | 43 +++--- assert_test.go | 13 +- close.go | 64 +++++---- conn.go | 92 ++++++++++--- conn_test.go | 3 +- internal/assert/assert.go | 2 +- read.go | 266 ++++++++++++++++---------------------- write.go | 215 +++++++++++++----------------- wsjson/wsjson.go | 1 - 9 files changed, 345 insertions(+), 354 deletions(-) diff --git a/README.md b/README.md index efb4a592..f0babdfc 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ go get nhooyr.io/websocket - Concurrent writes - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper -- [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping) +- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) @@ -88,26 +88,27 @@ c.Close(websocket.StatusNormalClosure, "") [gorilla/websocket](https://github.com/gorilla/websocket) is a widely used and mature library. Advantages of nhooyr.io/websocket: - - Minimal and idiomatic API - - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. - - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - - Full [context.Context](https://blog.golang.org/context) support - - Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing - - Will enable easy HTTP/2 support in the future - - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client. - - Concurrent writes - - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - - Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API - - gorilla/websocket requires registering a pong callback and then sending a Ping - - Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - - Transparent buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - - Gorilla's implementation depends on unsafe and is slower - - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support + +- Minimal and idiomatic API + - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. +- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper +- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) +- Full [context.Context](https://blog.golang.org/context) support +- Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing + - Will enable easy HTTP/2 support in the future + - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client. +- Concurrent writes +- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) +- Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API + - gorilla/websocket requires registering a pong callback and then sending a Ping +- Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) +- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go + - Gorilla's implementation depends on unsafe and is slower +- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper - - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) +- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper +- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket @@ -120,7 +121,7 @@ to nhooyr.io/websocket. #### gobwas/ws [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used -in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). +in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. diff --git a/assert_test.go b/assert_test.go index b6e50a47..e4319938 100644 --- a/assert_test.go +++ b/assert_test.go @@ -4,12 +4,11 @@ import ( "context" "crypto/rand" "io" - "strings" - "testing" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/wsjson" + "strings" + "testing" ) func randBytes(t *testing.T, n int) []byte { @@ -21,12 +20,15 @@ func randBytes(t *testing.T, n int) []byte { func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { t.Helper() + defer c.Close(websocket.StatusInternalError, "") exp := randString(t, n) err := wsjson.Write(ctx, c, exp) assert.Success(t, err) assertJSONRead(t, ctx, c, exp) + + c.Close(websocket.StatusNormalClosure, "") } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { @@ -74,5 +76,10 @@ func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { t.Helper() + defer func() { + if t.Failed() { + t.Logf("error: %+v", err) + } + }() assert.Equal(t, exp, websocket.CloseStatus(err), "StatusCode") } diff --git a/close.go b/close.go index a02dc7d9..4c474b78 100644 --- a/close.go +++ b/close.go @@ -7,9 +7,6 @@ import ( "fmt" "log" "nhooyr.io/websocket/internal/errd" - "time" - - "nhooyr.io/websocket/internal/bpool" ) // StatusCode represents a WebSocket status code. @@ -103,59 +100,58 @@ func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - err = c.cw.sendClose(code, reason) + err = c.writeClose(code, reason) if err != nil { return err } - return c.cr.waitClose() + return c.waitClose() } -func (cw *connWriter) error(code StatusCode, err error) { - cw.c.setCloseErr(err) - cw.sendClose(code, err.Error()) - cw.c.closeWithErr(nil) +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.closeWithErr(nil) } -func (cw *connWriter) sendClose(code StatusCode, reason string) error { +func (c *Conn) writeClose(code StatusCode, reason string) error { ce := CloseError{ Code: code, Reason: reason, } - cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) var p []byte if ce.Code != StatusNoStatusRcvd { p = ce.bytes() } - return cw.control(context.Background(), opClose, p) + return c.writeControl(context.Background(), opClose, p) } -func (cr *connReader) waitClose() error { - defer cr.c.closeWithErr(nil) +func (c *Conn) waitClose() error { + defer c.closeWithErr(nil) return nil - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := cr.mu.Lock(ctx) - if err != nil { - return err - } - defer cr.mu.Unlock() - - b := bpool.Get() - buf := b.Bytes() - buf = buf[:cap(buf)] - defer bpool.Put(b) - - for { - // TODO - return nil - } + // ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + // defer cancel() + // + // err := cr.mu.Lock(ctx) + // if err != nil { + // return err + // } + // defer cr.mu.Unlock() + // + // b := bpool.Get() + // buf := b.Bytes() + // buf = buf[:cap(buf)] + // defer bpool.Put(b) + // + // for { + // return nil + // } } func parseClosePayload(p []byte) (CloseError, error) { @@ -230,11 +226,11 @@ func (ce CloseError) bytesErr() ([]byte, error) { func (c *Conn) setCloseErr(err error) { c.closeMu.Lock() - c.setCloseErrNoLock(err) + c.setCloseErrLocked(err) c.closeMu.Unlock() } -func (c *Conn) setCloseErrNoLock(err error) { +func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } diff --git a/conn.go b/conn.go index d9001791..dc067d18 100644 --- a/conn.go +++ b/conn.go @@ -30,7 +30,7 @@ const ( // All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control -// frames will not be handled. See the docs on Reader and CloseRead. +// frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. @@ -42,9 +42,22 @@ type Conn struct { rwc io.ReadWriteCloser client bool copts *compressionOptions + br *bufio.Reader + bw *bufio.Writer - cr connReader - cw connWriter + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu mu + readControlBuf [maxControlPayload]byte + msgReader *msgReader + + // Write state. + msgWriter *msgWriter + writeFrameMu mu + writeBuf []byte + writeHeader header closed chan struct{} @@ -63,8 +76,8 @@ type connConfig struct { client bool copts *compressionOptions - bw *bufio.Writer br *bufio.Reader + bw *bufio.Writer } func newConn(cfg connConfig) *Conn { @@ -73,13 +86,23 @@ func newConn(cfg connConfig) *Conn { rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), } - c.cr.init(c, cfg.br) - c.cw.init(c, cfg.bw) + c.msgReader = newMsgReader(c) - c.closed = make(chan struct{}) - c.activePings = make(map[string]chan<- struct{}) + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } runtime.SetFinalizer(c, func(c *Conn) { c.closeWithErr(errors.New("connection garbage collected")) @@ -90,6 +113,34 @@ func newConn(cfg connConfig) *Conn { return c } +func newMsgReader(c *Conn) *msgReader { + mr := &msgReader{ + c: c, + fin: true, + } + + mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) + if c.deflateNegotiated() && mr.contextTakeover() { + mr.ensureFlateReader() + } + + return mr +} + +func newMsgWriter(c *Conn) *msgWriter { + mw := &msgWriter{ + c: c, + } + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + if c.deflateNegotiated() && mw.contextTakeover() { + mw.ensureFlateWriter() + } + + return mw +} + // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { @@ -105,7 +156,7 @@ func (c *Conn) closeWithErr(err error) { } close(c.closed) runtime.SetFinalizer(c, nil) - c.setCloseErrNoLock(err) + c.setCloseErrLocked(err) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns @@ -113,8 +164,18 @@ func (c *Conn) closeWithErr(err error) { c.rwc.Close() go func() { - c.cr.close() - c.cw.close() + if c.client { + c.writeFrameMu.Lock(context.Background()) + putBufioWriter(c.bw) + } + c.msgWriter.close() + + if c.client { + c.readMu.Lock(context.Background()) + putBufioReader(c.br) + c.readMu.Unlock() + } + c.msgReader.close() }() } @@ -127,13 +188,12 @@ func (c *Conn) timeoutLoop() { case <-c.closed: return - case writeCtx = <-c.cw.timeout: - case readCtx = <-c.cr.timeout: + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.cw.error(StatusPolicyViolation, errors.New("timed out")) - return + go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): c.closeWithErr(fmt.Errorf("write timed out: %w", writeCtx.Err())) return @@ -175,7 +235,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { c.activePingsMu.Unlock() }() - err := c.cw.control(ctx, opPing, []byte(p)) + err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } diff --git a/conn_test.go b/conn_test.go index 6b8a778b..cf2334f7 100644 --- a/conn_test.go +++ b/conn_test.go @@ -25,6 +25,7 @@ func TestConn(t *testing.T) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, + // CompressionMode: websocket.CompressionDisabled, }) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") @@ -41,12 +42,12 @@ func TestConn(t *testing.T) { opts := &websocket.DialOptions{ Subprotocols: []string{"echo"}, + // CompressionMode: websocket.CompressionDisabled, } opts.HTTPClient = s.Client() c, _, err := websocket.Dial(ctx, wsURL, opts) assert.Success(t, err) - assertJSONEcho(t, ctx, c, 2) }) } diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 4ebdb511..b448711a 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -23,7 +23,7 @@ func NotEqual(t testing.TB, exp, act interface{}, name string) { func Success(t testing.TB, err error) { t.Helper() if err != nil { - t.Fatalf("unexpected error : %+v", err) + t.Fatalf("unexpected error: %+v", err) } } diff --git a/read.go b/read.go index 7dba832a..d8691d65 100644 --- a/read.go +++ b/read.go @@ -1,7 +1,6 @@ package websocket import ( - "bufio" "context" "errors" "fmt" @@ -14,41 +13,22 @@ import ( "nhooyr.io/websocket/internal/errd" ) -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. +// Reader reads from the connection until until there is a WebSocket +// data message to be read. It will handle ping, pong and close frames as appropriate. +// +// It returns the type of the message and an io.Reader to read it. // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // -// All returned errors will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. -// -// You must read from the connection for control frames to be handled. -// Thus if you expect messages to take a long time to be responded to, -// you should handle such messages async to reading from the connection -// to ensure control frames are promptly handled. -// -// If you do not expect any data messages from the peer, call CloseRead. +// Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. -// -// If you need a separate timeout on the Reader call and then the message -// Read, use time.AfterFunc to cancel the context passed in early. -// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 -// Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - typ, r, err := c.cr.reader(ctx) - if err != nil { - return 0, nil, fmt.Errorf("failed to get reader: %w", err) - } - return typ, r, nil + return c.reader(ctx) } -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method to reuse buffers or for streaming. -// The docs on Reader apply to this method as well. +// Read is a convenience method around Reader to read a single message +// from the connection. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { @@ -59,14 +39,17 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return typ, b, err } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. +// CloseRead starts a goroutine to read from the connection until it is closed +// or a data message is received. +// +// Once CloseRead is called you cannot read any messages from the connection. // The returned context will be cancelled when the connection is closed. // -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. +// If a data message is received, the connection will be closed with StatusPolicyViolation. +// +// Call CloseRead when you do not expect to read any more messages. +// Since it actively reads from the connection, it will ensure that ping, pong and close +// frames are responded to. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) go func() { @@ -84,60 +67,32 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // // When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { - c.cr.mr.lr.limit.Store(n) -} - -type connReader struct { - c *Conn - br *bufio.Reader - timeout chan context.Context - - mu mu - controlPayloadBuf [maxControlPayload]byte - mr *msgReader -} - -func (cr *connReader) init(c *Conn, br *bufio.Reader) { - cr.c = c - cr.br = br - cr.timeout = make(chan context.Context) - - cr.mr = &msgReader{ - cr: cr, - fin: true, - } - - cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768) - if c.deflateNegotiated() && cr.contextTakeover() { - cr.ensureFlateReader() - } + c.msgReader.limitReader.setLimit(n) } -func (cr *connReader) ensureFlateReader() { - cr.mr.fr = getFlateReader(readerFunc(cr.mr.read)) - cr.mr.lr.reset(cr.mr.fr) +func (mr *msgReader) ensureFlateReader() { + mr.flateReader = getFlateReader(readerFunc(mr.read)) + mr.limitReader.reset(mr.flateReader) } -func (cr *connReader) close() { - cr.mu.Lock(context.Background()) - if cr.c.client { - putBufioReader(cr.br) - } - if cr.c.deflateNegotiated() && cr.contextTakeover() { - putFlateReader(cr.mr.fr) +func (mr *msgReader) close() { + if mr.c.deflateNegotiated() && mr.contextTakeover() { + mr.c.readMu.Lock(context.Background()) + putFlateReader(mr.flateReader) + mr.c.readMu.Unlock() } } -func (cr *connReader) contextTakeover() bool { - if cr.c.client { - return cr.c.copts.serverNoContextTakeover +func (mr *msgReader) contextTakeover() bool { + if mr.c.client { + return mr.c.copts.serverNoContextTakeover } - return cr.c.copts.clientNoContextTakeover + return mr.c.copts.clientNoContextTakeover } -func (cr *connReader) rsv1Illegal(h header) bool { +func (c *Conn) readRSV1Illegal(h header) bool { // If compression is enabled, rsv1 is always illegal. - if !cr.c.deflateNegotiated() { + if !c.deflateNegotiated() { return true } // rsv1 is only allowed on data frames beginning messages. @@ -147,26 +102,26 @@ func (cr *connReader) rsv1Illegal(h header) bool { return false } -func (cr *connReader) loop(ctx context.Context) (header, error) { +func (c *Conn) readLoop(ctx context.Context) (header, error) { for { - h, err := cr.frameHeader(ctx) + h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } - if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 { + if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return header{}, err } - if !cr.c.client && !h.masked { + if !c.client && !h.masked { return header{}, errors.New("received unmasked frame from client") } switch h.opcode { case opClose, opPing, opPong: - err = cr.control(ctx, h) + err = c.handleControl(ctx, h) if err != nil { // Pass through CloseErrors when receiving a close frame. if h.opcode == opClose && CloseStatus(err) != -1 { @@ -178,95 +133,89 @@ func (cr *connReader) loop(ctx context.Context) (header, error) { return h, nil default: err := fmt.Errorf("received unknown opcode %v", h.opcode) - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return header{}, err } } } -func (cr *connReader) frameHeader(ctx context.Context) (header, error) { +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { - case <-cr.c.closed: - return header{}, cr.c.closeErr - case cr.timeout <- ctx: + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- ctx: } - h, err := readFrameHeader(cr.br) + h, err := readFrameHeader(c.br) if err != nil { select { - case <-cr.c.closed: - return header{}, cr.c.closeErr + case <-c.closed: + return header{}, c.closeErr case <-ctx.Done(): return header{}, ctx.Err() default: - cr.c.closeWithErr(err) + c.closeWithErr(err) return header{}, err } } select { - case <-cr.c.closed: - return header{}, cr.c.closeErr - case cr.timeout <- context.Background(): + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- context.Background(): } return h, nil } -func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) { +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { - case <-cr.c.closed: - return 0, cr.c.closeErr - case cr.timeout <- ctx: + case <-c.closed: + return 0, c.closeErr + case c.readTimeout <- ctx: } - n, err := io.ReadFull(cr.br, p) + n, err := io.ReadFull(c.br, p) if err != nil { select { - case <-cr.c.closed: - return n, cr.c.closeErr + case <-c.closed: + return n, c.closeErr case <-ctx.Done(): return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) - cr.c.closeWithErr(err) + c.closeWithErr(err) return n, err } } select { - case <-cr.c.closed: - return n, cr.c.closeErr - case cr.timeout <- context.Background(): + case <-c.closed: + return n, c.closeErr + case c.readTimeout <- context.Background(): } return n, err } -func (cr *connReader) control(ctx context.Context, h header) error { - if h.payloadLength < 0 { - err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength) - cr.c.cw.error(StatusProtocolError, err) - return err - } - - if h.payloadLength > maxControlPayload { - err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) - cr.c.cw.error(StatusProtocolError, err) +func (c *Conn) handleControl(ctx context.Context, h header) error { + if h.payloadLength < 0 || h.payloadLength > maxControlPayload { + err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + c.writeError(StatusProtocolError, err) return err } if !h.fin { err := errors.New("received fragmented control frame") - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return err } ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - b := cr.controlPayloadBuf[:h.payloadLength] - _, err := cr.framePayload(ctx, b) + b := c.readControlBuf[:h.payloadLength] + _, err := c.readFramePayload(ctx, b) if err != nil { return err } @@ -277,11 +226,11 @@ func (cr *connReader) control(ctx context.Context, h header) error { switch h.opcode { case opPing: - return cr.c.cw.control(ctx, opPong, b) + return c.writeControl(ctx, opPong, b) case opPong: - cr.c.activePingsMu.Lock() - pong, ok := cr.c.activePings[string(b)] - cr.c.activePingsMu.Unlock() + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() if ok { close(pong) } @@ -291,53 +240,56 @@ func (cr *connReader) control(ctx context.Context, h header) error { ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return err } err = fmt.Errorf("received close frame: %w", ce) - cr.c.setCloseErr(err) - cr.c.cw.control(context.Background(), opClose, ce.bytes()) + c.setCloseErr(err) + c.writeControl(context.Background(), opClose, ce.bytes()) return err } -func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) { - err := cr.mu.Lock(ctx) +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { + defer errd.Wrap(&err, "failed to get reader") + + err = c.readMu.Lock(ctx) if err != nil { return 0, nil, err } - defer cr.mu.Unlock() + defer c.readMu.Unlock() - if !cr.mr.fin { + if !c.msgReader.fin { return 0, nil, errors.New("previous message not read to completion") } - h, err := cr.loop(ctx) + h, err := c.readLoop(ctx) if err != nil { return 0, nil, err } if h.opcode == opContinuation { err := errors.New("received continuation frame without text or binary frame") - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return 0, nil, err } - cr.mr.reset(ctx, h) + c.msgReader.reset(ctx, h) - return MessageType(h.opcode), cr.mr, nil + return MessageType(h.opcode), c.msgReader, nil } type msgReader struct { - cr *connReader - fr io.Reader - lr *limitReader + c *Conn ctx context.Context deflate bool + flateReader io.Reader deflateTail strings.Reader + limitReader *limitReader + payloadLength int64 maskKey uint32 fin bool @@ -348,8 +300,8 @@ func (mr *msgReader) reset(ctx context.Context, h header) { mr.deflate = h.rsv1 if mr.deflate { mr.deflateTail.Reset(deflateMessageTail) - if !mr.cr.contextTakeover() { - mr.cr.ensureFlateReader() + if !mr.contextTakeover() { + mr.ensureFlateReader() } } mr.setFrame(h) @@ -370,34 +322,42 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { } }() - err = mr.cr.mu.Lock(mr.ctx) + err = mr.c.readMu.Lock(mr.ctx) if err != nil { return 0, err } - defer mr.cr.mu.Unlock() + defer mr.c.readMu.Unlock() if mr.payloadLength == 0 && mr.fin { - if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() { - if mr.fr != nil { - putFlateReader(mr.fr) - mr.fr = nil + if mr.c.deflateNegotiated() && !mr.contextTakeover() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil } } return 0, io.EOF } - return mr.lr.Read(p) + return mr.limitReader.Read(p) } func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { - h, err := mr.cr.loop(mr.ctx) + if mr.fin { + if mr.deflate { + n, _ := mr.deflateTail.Read(p[:4]) + return n, nil + } + return 0, io.EOF + } + + h, err := mr.c.readLoop(mr.ctx) if err != nil { return 0, err } if h.opcode != opContinuation { err := errors.New("received new data message without finishing the previous message") - mr.cr.c.cw.error(StatusProtocolError, err) + mr.c.writeError(StatusProtocolError, err) return 0, err } mr.setFrame(h) @@ -407,14 +367,14 @@ func (mr *msgReader) read(p []byte) (int, error) { p = p[:mr.payloadLength] } - n, err := mr.cr.framePayload(mr.ctx, p) + n, err := mr.c.readFramePayload(mr.ctx, p) if err != nil { return n, err } mr.payloadLength -= int64(n) - if !mr.cr.c.client { + if !mr.c.client { mr.maskKey = mask(mr.maskKey, p) } @@ -442,10 +402,14 @@ func (lr *limitReader) reset(r io.Reader) { lr.r = r } +func (lr *limitReader) setLimit(limit int64) { + lr.limit.Store(limit) +} + func (lr *limitReader) Read(p []byte) (int, error) { if lr.n <= 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) - lr.c.cw.error(StatusMessageTooBig, err) + lr.c.writeError(StatusMessageTooBig, err) return 0, err } diff --git a/write.go b/write.go index 9cafc5c5..0ddf11e1 100644 --- a/write.go +++ b/write.go @@ -24,7 +24,7 @@ import ( // // Never close the returned writer twice. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - w, err := c.cw.writer(ctx, typ) + w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) } @@ -38,111 +38,68 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // If compression is disabled, then it is guaranteed to write the message // in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - _, err := c.cw.write(ctx, typ, p) + _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) } return nil } -type connWriter struct { - c *Conn - bw *bufio.Writer - - writeBuf []byte - - mw *messageWriter - frameMu mu - h header - - timeout chan context.Context +func (mw *msgWriter) ensureFlateWriter() { + mw.flateWriter = getFlateWriter(mw.trimWriter) } -func (cw *connWriter) init(c *Conn, bw *bufio.Writer) { - cw.c = c - cw.bw = bw - - if cw.c.client { - cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc) - } - - cw.timeout = make(chan context.Context) - - cw.mw = &messageWriter{ - cw: cw, +func (mw *msgWriter) contextTakeover() bool { + if mw.c.client { + return mw.c.copts.clientNoContextTakeover } - cw.mw.tw = &trimLastFourBytesWriter{ - w: writerFunc(cw.mw.write), - } - if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { - cw.mw.ensureFlateWriter() - } -} - -func (mw *messageWriter) ensureFlateWriter() { - mw.fw = getFlateWriter(mw.tw) + return mw.c.copts.serverNoContextTakeover } -func (cw *connWriter) close() { - if cw.c.client { - cw.frameMu.Lock(context.Background()) - putBufioWriter(cw.bw) - } - if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { - cw.mw.mu.Lock(context.Background()) - putFlateWriter(cw.mw.fw) - } -} - -func (mw *messageWriter) contextTakeover() bool { - if mw.cw.c.client { - return mw.cw.c.copts.clientNoContextTakeover - } - return mw.cw.c.copts.serverNoContextTakeover -} - -func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := cw.mw.reset(ctx, typ) +func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } - return cw.mw, nil + return c.msgWriter, nil } -func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) { - ww, err := cw.writer(ctx, typ) +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + mw, err := c.writer(ctx, typ) if err != nil { return 0, err } - if !cw.c.deflateNegotiated() { + if !c.deflateNegotiated() { // Fast single frame path. - defer cw.mw.mu.Unlock() - return cw.frame(ctx, true, cw.mw.opcode, p) + defer c.msgWriter.mu.Unlock() + return c.writeFrame(ctx, true, c.msgWriter.opcode, p) } - n, err := ww.Write(p) + n, err := mw.Write(p) if err != nil { return n, err } - err = ww.Close() + err = mw.Close() return n, err } -type messageWriter struct { - cw *connWriter +type msgWriter struct { + c *Conn - mu mu - compress bool - tw *trimLastFourBytesWriter - fw *flate.Writer - ctx context.Context - opcode opcode - closed bool + mu mu + + deflate bool + ctx context.Context + opcode opcode + closed bool + + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer } -func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.Lock(ctx) if err != nil { return err @@ -155,30 +112,30 @@ func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error { } // Write writes the given bytes to the WebSocket connection. -func (mw *messageWriter) Write(p []byte) (_ int, err error) { +func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") if mw.closed { return 0, errors.New("cannot use closed writer") } - if mw.cw.c.deflateNegotiated() { - if !mw.compress { + if mw.c.deflateNegotiated() { + if !mw.deflate { if !mw.contextTakeover() { mw.ensureFlateWriter() } - mw.tw.reset() - mw.compress = true + mw.trimWriter.reset() + mw.deflate = true } - return mw.fw.Write(p) + return mw.flateWriter.Write(p) } return mw.write(p) } -func (mw *messageWriter) write(p []byte) (int, error) { - n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p) +func (mw *msgWriter) write(p []byte) (int, error) { + n, err := mw.c.writeFrame(mw.ctx, false, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } @@ -187,8 +144,7 @@ func (mw *messageWriter) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (mw *messageWriter) Close() (err error) { +func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") if mw.closed { @@ -196,32 +152,39 @@ func (mw *messageWriter) Close() (err error) { } mw.closed = true - if mw.cw.c.deflateNegotiated() { - err = mw.fw.Flush() + if mw.c.deflateNegotiated() { + err = mw.flateWriter.Flush() if err != nil { return fmt.Errorf("failed to flush flate writer: %w", err) } } - _, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil) + _, err = mw.c.writeFrame(mw.ctx, true, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } - if mw.compress && !mw.contextTakeover() { - putFlateWriter(mw.fw) - mw.compress = false + if mw.deflate && !mw.contextTakeover() { + putFlateWriter(mw.flateWriter) + mw.deflate = false } mw.mu.Unlock() return nil } -func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error { +func (cw *msgWriter) close() { + if cw.c.deflateNegotiated() && cw.contextTakeover() { + cw.mu.Lock(context.Background()) + putFlateWriter(cw.flateWriter) + } +} + +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - _, err := cw.frame(ctx, true, opcode, p) + _, err := c.writeFrame(ctx, true, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } @@ -229,94 +192,94 @@ func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) erro } // frame handles all writes to the connection. -func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - err := cw.frameMu.Lock(ctx) +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + err := c.writeFrameMu.Lock(ctx) if err != nil { return 0, err } - defer cw.frameMu.Unlock() + defer c.writeFrameMu.Unlock() select { - case <-cw.c.closed: - return 0, cw.c.closeErr - case cw.timeout <- ctx: + case <-c.closed: + return 0, c.closeErr + case c.writeTimeout <- ctx: } - cw.h.fin = fin - cw.h.opcode = opcode - cw.h.masked = cw.c.client - cw.h.payloadLength = int64(len(p)) - - cw.h.rsv1 = false - if cw.mw.compress && (opcode == opText || opcode == opBinary) { - cw.h.rsv1 = true - } + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.payloadLength = int64(len(p)) - if cw.h.masked { - err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey) + if c.client { + c.writeHeader.masked = true + err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } } - err = writeFrameHeader(cw.h, cw.bw) + c.writeHeader.rsv1 = false + if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) { + c.writeHeader.rsv1 = true + } + + err = writeFrameHeader(c.writeHeader, c.bw) if err != nil { return 0, err } - n, err := cw.framePayload(p) + n, err := c.writeFramePayload(p) if err != nil { return n, err } - if cw.h.fin { - err = cw.bw.Flush() + if c.writeHeader.fin { + err = c.bw.Flush() if err != nil { return n, fmt.Errorf("failed to flush: %w", err) } } select { - case <-cw.c.closed: - return n, cw.c.closeErr - case cw.timeout <- context.Background(): + case <-c.closed: + return n, c.closeErr + case c.writeTimeout <- context.Background(): } return n, nil } -func (cw *connWriter) framePayload(p []byte) (_ int, err error) { +func (c *Conn) writeFramePayload(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write frame payload") - if !cw.h.masked { - return cw.bw.Write(p) + if !c.writeHeader.masked { + return c.bw.Write(p) } var n int - maskKey := cw.h.maskKey + maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush. - if cw.bw.Available() == 0 { - err = cw.bw.Flush() + if c.bw.Available() == 0 { + err = c.bw.Flush() if err != nil { return n, err } } // Start of next write in the buffer. - i := cw.bw.Buffered() + i := c.bw.Buffered() j := len(p) - if j > cw.bw.Available() { - j = cw.bw.Available() + if j > c.bw.Available() { + j = c.bw.Available() } - _, err := cw.bw.Write(p[:j]) + _, err := c.bw.Write(p[:j]) if err != nil { return n, err } - maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()]) + maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) p = p[j:] n += j diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 99996a69..36dd2dfd 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" From e8dfe270f06873c243fd98f2f19303093d5af85a Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 29 Nov 2019 13:17:04 -0500 Subject: [PATCH 18/55] Make CI pass --- accept.go | 2 + accept_test.go | 2 + assert_test.go | 5 ++- autobahn_test.go | 2 + ci/lint.mk | 2 +- close.go | 58 +++++++++++++++++----------- close_test.go | 8 ++-- compress.go | 24 ++++++------ conn.go | 80 ++++++++++++++------------------------- conn_test.go | 6 +-- dial.go | 2 + frame.go | 2 + go.mod | 6 --- go.sum | 24 ------------ internal/assert/assert.go | 6 +++ read.go | 40 +++++++++++++++----- write.go | 44 ++++++++++++++------- ws_js.go | 19 ++++++++-- ws_js_test.go | 2 +- wsjson/wsjson.go | 1 + 20 files changed, 183 insertions(+), 152 deletions(-) diff --git a/accept.go b/accept.go index 964e0401..ea7beebd 100644 --- a/accept.go +++ b/accept.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/accept_test.go b/accept_test.go index d68d4d6d..551fe4de 100644 --- a/accept_test.go +++ b/accept_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/assert_test.go b/assert_test.go index e4319938..dd4c30cd 100644 --- a/assert_test.go +++ b/assert_test.go @@ -4,11 +4,12 @@ import ( "context" "crypto/rand" "io" + "strings" + "testing" + "nhooyr.io/websocket" "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/wsjson" - "strings" - "testing" ) func randBytes(t *testing.T, n int) []byte { diff --git a/autobahn_test.go b/autobahn_test.go index 30c96a7c..6b3b5b72 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket_test import ( diff --git a/ci/lint.mk b/ci/lint.mk index a656ea8d..031f0de3 100644 --- a/ci/lint.mk +++ b/ci/lint.mk @@ -1,4 +1,4 @@ -lint: govet golint govet-wasm golint-wasm +lint: govet golint govet: go vet ./... diff --git a/close.go b/close.go index 4c474b78..af437553 100644 --- a/close.go +++ b/close.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -6,6 +8,8 @@ import ( "errors" "fmt" "log" + "time" + "nhooyr.io/websocket/internal/errd" ) @@ -99,19 +103,24 @@ func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") + defer c.close(nil) err = c.writeClose(code, reason) if err != nil { return err } - return c.waitClose() + err = c.waitCloseHandshake() + if CloseStatus(err) == -1 { + return err + } + return nil } func (c *Conn) writeError(code StatusCode, err error) { c.setCloseErr(err) c.writeClose(code, err.Error()) - c.closeWithErr(nil) + c.close(nil) } func (c *Conn) writeClose(code StatusCode, reason string) error { @@ -130,28 +139,33 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { return c.writeControl(context.Background(), opClose, p) } -func (c *Conn) waitClose() error { - defer c.closeWithErr(nil) +func (c *Conn) waitCloseHandshake() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() - return nil + err := c.readMu.Lock(ctx) + if err != nil { + return err + } + defer c.readMu.Unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } - // ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - // defer cancel() - // - // err := cr.mu.Lock(ctx) - // if err != nil { - // return err - // } - // defer cr.mu.Unlock() - // - // b := bpool.Get() - // buf := b.Bytes() - // buf = buf[:cap(buf)] - // defer bpool.Put(b) - // - // for { - // return nil - // } + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } } func parseClosePayload(p []byte) (CloseError, error) { diff --git a/close_test.go b/close_test.go index 9551699a..a2e0f67d 100644 --- a/close_test.go +++ b/close_test.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -49,7 +51,7 @@ func TestCloseError(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() - if (tc.success) { + if tc.success { assert.Success(t, err) } else { assert.Error(t, err) @@ -101,7 +103,7 @@ func Test_parseClosePayload(t *testing.T) { t.Parallel() ce, err := parseClosePayload(tc.p) - if (tc.success) { + if tc.success { assert.Success(t, err) assert.Equal(t, tc.ce, ce, "CloseError") } else { @@ -151,7 +153,7 @@ func Test_validWireCloseCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tc.code, validWireCloseCode(tc.code), "validWireCloseCode") + assert.Equal(t, tc.valid, validWireCloseCode(tc.code), "validWireCloseCode") }) } } diff --git a/compress.go b/compress.go index 9e075430..2410cb4e 100644 --- a/compress.go +++ b/compress.go @@ -19,16 +19,6 @@ import ( type CompressionMode int const ( - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. - // As most WebSocket protocols are repetitive, this is the default. - // - // The message will only be compressed if greater than or equal to 128 bytes. - // - // If the peer negotiates NoContextTakeover on the client or server side, it will be - // used instead as this is required by the RFC. - CompressionContextTakeover CompressionMode = iota - // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed // for every message. This applies to both server and client side. // @@ -36,8 +26,18 @@ const ( // will not be used but the memory overhead will be much lower if the connections // are long lived and seldom used. // - // The message will only be compressed if greater than or equal to 512 bytes. - CompressionNoContextTakeover + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // + // The message will only be compressed if greater than 128 bytes. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover // CompressionDisabled disables the deflate extension. // diff --git a/conn.go b/conn.go index dc067d18..10fe2e1a 100644 --- a/conn.go +++ b/conn.go @@ -49,21 +49,21 @@ type Conn struct { writeTimeout chan context.Context // Read state. - readMu mu - readControlBuf [maxControlPayload]byte - msgReader *msgReader + readMu *mu + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error // Write state. msgWriter *msgWriter - writeFrameMu mu + writeFrameMu *mu writeBuf []byte writeHeader header - closed chan struct{} - - closeMu sync.Mutex - closeErr error - closeHandshakeErr error + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose int64 pingCounter int32 activePingsMu sync.Mutex @@ -90,13 +90,16 @@ func newConn(cfg connConfig) *Conn { br: cfg.br, bw: cfg.bw, - readTimeout: make(chan context.Context), + readTimeout: make(chan context.Context), writeTimeout: make(chan context.Context), - closed: make(chan struct{}), + closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + c.msgReader = newMsgReader(c) c.msgWriter = newMsgWriter(c) @@ -105,7 +108,7 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.closeWithErr(errors.New("connection garbage collected")) + c.close(errors.New("connection garbage collected")) }) go c.timeoutLoop() @@ -113,41 +116,13 @@ func newConn(cfg connConfig) *Conn { return c } -func newMsgReader(c *Conn) *msgReader { - mr := &msgReader{ - c: c, - fin: true, - } - - mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) - if c.deflateNegotiated() && mr.contextTakeover() { - mr.ensureFlateReader() - } - - return mr -} - -func newMsgWriter(c *Conn) *msgWriter { - mw := &msgWriter{ - c: c, - } - mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), - } - if c.deflateNegotiated() && mw.contextTakeover() { - mw.ensureFlateWriter() - } - - return mw -} - // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) closeWithErr(err error) { +func (c *Conn) close(err error) { c.closeMu.Lock() defer c.closeMu.Unlock() @@ -195,13 +170,13 @@ func (c *Conn) timeoutLoop() { c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): - c.closeWithErr(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return } } } -func (c *Conn) deflateNegotiated() bool { +func (c *Conn) deflate() bool { return c.copts != nil } @@ -245,7 +220,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { return c.closeErr case <-ctx.Done(): err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.closeWithErr(err) + c.close(err) return err case <-pong: return nil @@ -253,19 +228,21 @@ func (c *Conn) ping(ctx context.Context, p string) error { } type mu struct { - once sync.Once - ch chan struct{} + c *Conn + ch chan struct{} } -func (m *mu) init() { - m.once.Do(func() { - m.ch = make(chan struct{}, 1) - }) +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } } func (m *mu) Lock(ctx context.Context) error { - m.init() select { + case <-m.c.closed: + return m.c.closeErr case <-ctx.Done(): return ctx.Err() case m.ch <- struct{}{}: @@ -274,7 +251,6 @@ func (m *mu) Lock(ctx context.Context) error { } func (m *mu) TryLock() bool { - m.init() select { case m.ch <- struct{}{}: return true diff --git a/conn_test.go b/conn_test.go index cf2334f7..9b628cfe 100644 --- a/conn_test.go +++ b/conn_test.go @@ -25,7 +25,7 @@ func TestConn(t *testing.T) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, - // CompressionMode: websocket.CompressionDisabled, + CompressionMode: websocket.CompressionNoContextTakeover, }) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") @@ -41,8 +41,8 @@ func TestConn(t *testing.T) { defer cancel() opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - // CompressionMode: websocket.CompressionDisabled, + Subprotocols: []string{"echo"}, + CompressionMode: websocket.CompressionNoContextTakeover, } opts.HTTPClient = s.Client() diff --git a/dial.go b/dial.go index a1a10556..6cde30e7 100644 --- a/dial.go +++ b/dial.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/frame.go b/frame.go index 0257835e..47ff40f7 100644 --- a/frame.go +++ b/frame.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( diff --git a/go.mod b/go.mod index 1a2b08f4..6cd368b4 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,11 @@ module nhooyr.io/websocket go 1.13 require ( - github.com/davecgh/go-spew v1.1.1 // indirect github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 github.com/golang/protobuf v1.3.2 github.com/google/go-cmp v0.3.1 github.com/gorilla/websocket v1.4.1 - github.com/kr/pretty v0.1.0 // indirect - github.com/stretchr/testify v1.4.0 // indirect - go.uber.org/atomic v1.4.0 // indirect - go.uber.org/multierr v1.1.0 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index d2f1f0e4..c639eb64 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,3 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= @@ -14,25 +10,5 @@ github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= -go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= -go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= -go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/assert/assert.go b/internal/assert/assert.go index b448711a..b20d9420 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -1,3 +1,4 @@ +// Package assert contains helpers for test assertions. package assert import ( @@ -5,6 +6,7 @@ import ( "testing" ) +// Equal asserts exp == act. func Equal(t testing.TB, exp, act interface{}, name string) { t.Helper() diff := cmpDiff(exp, act) @@ -13,6 +15,7 @@ func Equal(t testing.TB, exp, act interface{}, name string) { } } +// NotEqual asserts exp != act. func NotEqual(t testing.TB, exp, act interface{}, name string) { t.Helper() if cmpDiff(exp, act) == "" { @@ -20,6 +23,7 @@ func NotEqual(t testing.TB, exp, act interface{}, name string) { } } +// Success asserts exp == nil. func Success(t testing.TB, err error) { t.Helper() if err != nil { @@ -27,6 +31,7 @@ func Success(t testing.TB, err error) { } } +// Error asserts exp != nil. func Error(t testing.TB, err error) { t.Helper() if err == nil { @@ -34,6 +39,7 @@ func Error(t testing.TB, err error) { } } +// ErrorContains asserts the error string from err contains sub. func ErrorContains(t testing.TB, err error, sub string) { t.Helper() Error(t, err) diff --git a/read.go b/read.go index d8691d65..c72b6c17 100644 --- a/read.go +++ b/read.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -70,13 +72,27 @@ func (c *Conn) SetReadLimit(n int64) { c.msgReader.limitReader.setLimit(n) } -func (mr *msgReader) ensureFlateReader() { +func newMsgReader(c *Conn) *msgReader { + mr := &msgReader{ + c: c, + fin: true, + } + + mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) + if c.deflate() && mr.contextTakeover() { + mr.initFlateReader() + } + + return mr +} + +func (mr *msgReader) initFlateReader() { mr.flateReader = getFlateReader(readerFunc(mr.read)) mr.limitReader.reset(mr.flateReader) } func (mr *msgReader) close() { - if mr.c.deflateNegotiated() && mr.contextTakeover() { + if mr.c.deflate() && mr.contextTakeover() { mr.c.readMu.Lock(context.Background()) putFlateReader(mr.flateReader) mr.c.readMu.Unlock() @@ -92,7 +108,7 @@ func (mr *msgReader) contextTakeover() bool { func (c *Conn) readRSV1Illegal(h header) bool { // If compression is enabled, rsv1 is always illegal. - if !c.deflateNegotiated() { + if !c.deflate() { return true } // rsv1 is only allowed on data frames beginning messages. @@ -154,7 +170,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case <-ctx.Done(): return header{}, ctx.Err() default: - c.closeWithErr(err) + c.close(err) return header{}, err } } @@ -184,7 +200,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) - c.closeWithErr(err) + c.close(err) return n, err } } @@ -198,7 +214,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { return n, err } -func (c *Conn) handleControl(ctx context.Context, h header) error { +func (c *Conn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) @@ -215,7 +231,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { defer cancel() b := c.readControlBuf[:h.payloadLength] - _, err := c.readFramePayload(ctx, b) + _, err = c.readFramePayload(ctx, b) if err != nil { return err } @@ -237,6 +253,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { return nil } + defer func() { + c.readCloseFrameErr = err + }() + ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) @@ -301,7 +321,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) { if mr.deflate { mr.deflateTail.Reset(deflateMessageTail) if !mr.contextTakeover() { - mr.ensureFlateReader() + mr.initFlateReader() } } mr.setFrame(h) @@ -329,7 +349,7 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { defer mr.c.readMu.Unlock() if mr.payloadLength == 0 && mr.fin { - if mr.c.deflateNegotiated() && !mr.contextTakeover() { + if mr.c.deflate() && !mr.contextTakeover() { if mr.flateReader != nil { putFlateReader(mr.flateReader) mr.flateReader = nil @@ -345,7 +365,7 @@ func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { if mr.fin { if mr.deflate { - n, _ := mr.deflateTail.Read(p[:4]) + n, _ := mr.deflateTail.Read(p) return n, nil } return 0, io.EOF diff --git a/write.go b/write.go index 0ddf11e1..526b3b66 100644 --- a/write.go +++ b/write.go @@ -1,3 +1,5 @@ +// +build !js + package websocket import ( @@ -45,11 +47,26 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { return nil } +func newMsgWriter(c *Conn) *msgWriter { + mw := &msgWriter{ + c: c, + mu: newMu(c), + } + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + if c.deflate() && mw.deflateContextTakeover() { + mw.ensureFlateWriter() + } + + return mw +} + func (mw *msgWriter) ensureFlateWriter() { mw.flateWriter = getFlateWriter(mw.trimWriter) } -func (mw *msgWriter) contextTakeover() bool { +func (mw *msgWriter) deflateContextTakeover() bool { if mw.c.client { return mw.c.copts.clientNoContextTakeover } @@ -70,7 +87,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return 0, err } - if !c.deflateNegotiated() { + if !c.deflate() { // Fast single frame path. defer c.msgWriter.mu.Unlock() return c.writeFrame(ctx, true, c.msgWriter.opcode, p) @@ -88,15 +105,15 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error type msgWriter struct { c *Conn - mu mu + mu *mu deflate bool ctx context.Context opcode opcode closed bool - trimWriter *trimLastFourBytesWriter - flateWriter *flate.Writer + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer } func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { @@ -108,6 +125,7 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { mw.closed = false mw.ctx = ctx mw.opcode = opcode(typ) + mw.deflate = false return nil } @@ -119,9 +137,9 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, errors.New("cannot use closed writer") } - if mw.c.deflateNegotiated() { + if mw.c.deflate() { if !mw.deflate { - if !mw.contextTakeover() { + if !mw.deflateContextTakeover() { mw.ensureFlateWriter() } mw.trimWriter.reset() @@ -152,7 +170,7 @@ func (mw *msgWriter) Close() (err error) { } mw.closed = true - if mw.c.deflateNegotiated() { + if mw.c.deflate() { err = mw.flateWriter.Flush() if err != nil { return fmt.Errorf("failed to flush flate writer: %w", err) @@ -164,7 +182,7 @@ func (mw *msgWriter) Close() (err error) { return fmt.Errorf("failed to write fin frame: %w", err) } - if mw.deflate && !mw.contextTakeover() { + if mw.deflate && !mw.deflateContextTakeover() { putFlateWriter(mw.flateWriter) mw.deflate = false } @@ -173,10 +191,10 @@ func (mw *msgWriter) Close() (err error) { return nil } -func (cw *msgWriter) close() { - if cw.c.deflateNegotiated() && cw.contextTakeover() { - cw.mu.Lock(context.Background()) - putFlateWriter(cw.flateWriter) +func (mw *msgWriter) close() { + if mw.c.deflate() && mw.deflateContextTakeover() { + mw.mu.Lock(context.Background()) + putFlateWriter(mw.flateWriter) } } diff --git a/ws_js.go b/ws_js.go index 3043106b..950aa01b 100644 --- a/ws_js.go +++ b/ws_js.go @@ -13,7 +13,18 @@ import ( "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" - "nhooyr.io/websocket/internal/wssync" +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like Protobufs. + MessageBinary ) // Conn provides a wrapper around the browser WebSocket API. @@ -21,10 +32,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit *wssync.Int64 + msgReadLimit atomicInt64 closingMu sync.Mutex - isReadClosed *wssync.Int64 + isReadClosed atomicInt64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -337,6 +348,7 @@ func (w writer) Close() error { return nil } +// CloseRead implements *Conn.CloseRead for wasm. func (c *Conn) CloseRead(ctx context.Context) context.Context { c.isReadClosed.Store(1) @@ -349,6 +361,7 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { return ctx } +// SetReadLimit implements *Conn.SetReadLimit for wasm. func (c *Conn) SetReadLimit(n int64) { c.msgReadLimit.Store(n) } diff --git a/ws_js_test.go b/ws_js_test.go index ea888b59..6e87480b 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -1,4 +1,4 @@ -package websocket +package websocket_test import ( "context" diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 36dd2dfd..99996a69 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" From f6137f3f404630d19a84bc0fd59570ff6a967004 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 6 Dec 2019 13:35:49 -0600 Subject: [PATCH 19/55] Add minor improvements Closes #179 --- README.md | 39 +++--- accept_test.go | 20 +-- assert_test.go | 26 ++-- ci/image/Dockerfile | 2 - ci/test.mk | 8 +- close.go | 11 +- close_test.go | 16 +-- conn.go | 6 +- conn_test.go | 7 +- doc.go | 9 +- frame_test.go | 15 ++- go.mod | 4 +- go.sum | 247 +++++++++++++++++++++++++++++++++++++- internal/assert/assert.go | 50 -------- internal/assert/cmp.go | 53 -------- read.go | 64 +++++----- 16 files changed, 359 insertions(+), 218 deletions(-) delete mode 100644 internal/assert/assert.go delete mode 100644 internal/assert/cmp.go diff --git a/README.md b/README.md index f0babdfc..e958d2ab 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # websocket -[![version](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) -[![docs](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) +[![release](https://img.shields.io/github/v/release/nhooyr/websocket?color=6b9ded&sort=semver)](https://github.com/nhooyr/websocket/releases) +[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) [![coverage](https://img.shields.io/coveralls/github/nhooyr/websocket?color=65d6a4)](https://coveralls.io/github/nhooyr/websocket) [![ci](https://github.com/nhooyr/websocket/workflows/ci/badge.svg)](https://github.com/nhooyr/websocket/actions) @@ -19,14 +19,14 @@ go get nhooyr.io/websocket - First class [context.Context](https://blog.golang.org/context) support - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - [Zero dependencies](https://godoc.org/nhooyr.io/websocket?imports) -- JSON and ProtoBuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper -- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) +- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) +- Can target [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) ## Roadmap @@ -85,7 +85,11 @@ c.Close(websocket.StatusNormalClosure, "") ### gorilla/websocket -[gorilla/websocket](https://github.com/gorilla/websocket) is a widely used and mature library. +Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): + +- Mature and widely used +- [Prepared writes](https://godoc.org/github.com/gorilla/websocket#PreparedMessage) +- Configurable [buffer sizes](https://godoc.org/github.com/gorilla/websocket#hdr-Buffers) Advantages of nhooyr.io/websocket: @@ -94,26 +98,26 @@ Advantages of nhooyr.io/websocket: - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Full [context.Context](https://blog.golang.org/context) support -- Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing +- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Will enable easy HTTP/2 support in the future - - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client. + - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Concurrent writes - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) -- Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API - - gorilla/websocket requires registering a pong callback and then sending a Ping -- Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) +- Idiomatic [ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API + - Gorilla requires registering a pong callback before sending a Ping +- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - - Gorilla's implementation depends on unsafe and is slower + - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode -- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper +- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket [golang.org/x/net/websocket](https://godoc.org/golang.org/x/net/websocket) is deprecated. -See ([golang/go/issues/18152](https://github.com/golang/go/issues/18152)). +See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). The [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper will ease in transitioning to nhooyr.io/websocket. @@ -124,10 +128,3 @@ to nhooyr.io/websocket. in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. - -## Users - -If your company or project is using this library, feel free to open an issue or PR to amend this list. - -- [Coder](https://github.com/cdr) -- [Tatsu Works](https://github.com/tatsuworks) - Ingresses 20 TB in WebSocket data every month on their Discord bot. diff --git a/accept_test.go b/accept_test.go index 551fe4de..2a784d19 100644 --- a/accept_test.go +++ b/accept_test.go @@ -7,7 +7,7 @@ import ( "strings" "testing" - "nhooyr.io/websocket/internal/assert" + "cdr.dev/slog/sloggers/slogtest/assert" ) func TestAccept(t *testing.T) { @@ -20,7 +20,7 @@ func TestAccept(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) - assert.ErrorContains(t, err, "protocol violation") + assert.ErrorContains(t, "Accept", err, "protocol violation") }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -34,7 +34,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - assert.ErrorContains(t, err, "http.ResponseWriter does not implement http.Hijacker") + assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") }) } @@ -127,9 +127,9 @@ func Test_verifyClientHandshake(t *testing.T) { err := verifyClientRequest(r) if tc.success { - assert.Success(t, err) + assert.Success(t, "verifyClientRequest", err) } else { - assert.Error(t, err) + assert.Error(t, "verifyClientRequest", err) } }) } @@ -179,7 +179,7 @@ func Test_selectSubprotocol(t *testing.T) { r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) - assert.Equal(t, tc.negotiated, negotiated, "negotiated") + assert.Equal(t, "negotiated", tc.negotiated, negotiated) }) } } @@ -234,10 +234,14 @@ func Test_authenticateOrigin(t *testing.T) { err := authenticateOrigin(r) if tc.success { - assert.Success(t, err) + assert.Success(t, "authenticateOrigin", err) } else { - assert.Error(t, err) + assert.Error(t, "authenticateOrigin", err) } }) } } + +func Test_acceptCompression(t *testing.T) { + +} diff --git a/assert_test.go b/assert_test.go index dd4c30cd..cd78fbb3 100644 --- a/assert_test.go +++ b/assert_test.go @@ -3,19 +3,19 @@ package websocket_test import ( "context" "crypto/rand" - "io" "strings" "testing" + "cdr.dev/slog/sloggers/slogtest/assert" + "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/wsjson" ) func randBytes(t *testing.T, n int) []byte { b := make([]byte, n) - _, err := io.ReadFull(rand.Reader, b) - assert.Success(t, err) + _, err := rand.Reader.Read(b) + assert.Success(t, "readRandBytes", err) return b } @@ -25,7 +25,7 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) exp := randString(t, n) err := wsjson.Write(ctx, c, exp) - assert.Success(t, err) + assert.Success(t, "wsjson.Write", err) assertJSONRead(t, ctx, c, exp) @@ -37,9 +37,9 @@ func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp in var act interface{} err := wsjson.Read(ctx, c, &act) - assert.Success(t, err) + assert.Success(t, "wsjson.Read", err) - assert.Equal(t, exp, act, "JSON") + assert.Equal(t, "json", exp, act) } func randString(t *testing.T, n int) string { @@ -60,19 +60,19 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc p := randBytes(t, n) err := c.Write(ctx, typ, p) - assert.Success(t, err) + assert.Success(t, "write", err) typ2, p2, err := c.Read(ctx) - assert.Success(t, err) + assert.Success(t, "read", err) - assert.Equal(t, typ, typ2, "data type") - assert.Equal(t, p, p2, "payload") + assert.Equal(t, "dataType", typ, typ2) + assert.Equal(t, "payload", p, p2) } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { t.Helper() - assert.Equal(t, exp, c.Subprotocol(), "subprotocol") + assert.Equal(t, "subprotocol", exp, c.Subprotocol()) } func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { @@ -82,5 +82,5 @@ func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { t.Logf("error: %+v", err) } }() - assert.Equal(t, exp, websocket.CloseStatus(err), "StatusCode") + assert.Equal(t, "closeStatus", exp, websocket.CloseStatus(err)) } diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index ccfac109..bfc05fc8 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -2,8 +2,6 @@ FROM golang:1 RUN apt-get update RUN apt-get install -y chromium -RUN apt-get install -y npm -RUN apt-get install -y jq ENV GOFLAGS="-mod=readonly" ENV PAGER=cat diff --git a/ci/test.mk b/ci/test.mk index f9a6e09a..95e049b2 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -9,13 +9,7 @@ ci/out/coverage.html: gotest coveralls: gotest # https://github.com/coverallsapp/github-action/blob/master/src/run.ts echo "--- coveralls" - export GIT_BRANCH="$$GITHUB_REF" - export BUILD_NUMBER="$$GITHUB_SHA" - if [[ $$GITHUB_EVENT_NAME == pull_request ]]; then - export CI_PULL_REQUEST="$$(jq .number "$$GITHUB_EVENT_PATH")" - BUILD_NUMBER="$$BUILD_NUMBER-PR-$$CI_PULL_REQUEST" - fi - goveralls -coverprofile=ci/out/coverage.prof -service=github + goveralls -coverprofile=ci/out/coverage.prof gotest: go test -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... diff --git a/close.go b/close.go index af437553..7ccdb173 100644 --- a/close.go +++ b/close.go @@ -30,7 +30,7 @@ const ( StatusProtocolError StatusCode = 1002 StatusUnsupportedData StatusCode = 1003 - // 1004 is reserved and so not exported. + // 1004 is reserved and so unexported. statusReserved StatusCode = 1004 // StatusNoStatusRcvd cannot be sent in a close message. @@ -103,7 +103,6 @@ func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - defer c.close(nil) err = c.writeClose(code, reason) if err != nil { @@ -124,6 +123,14 @@ func (c *Conn) writeError(code StatusCode, err error) { } func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + closing := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if closing { + return errors.New("already wrote close") + } + ce := CloseError{ Code: code, Reason: reason, diff --git a/close_test.go b/close_test.go index a2e0f67d..16b570d0 100644 --- a/close_test.go +++ b/close_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "nhooyr.io/websocket/internal/assert" + "cdr.dev/slog/sloggers/slogtest/assert" ) func TestCloseError(t *testing.T) { @@ -52,9 +52,9 @@ func TestCloseError(t *testing.T) { _, err := tc.ce.bytesErr() if tc.success { - assert.Success(t, err) + assert.Success(t, "CloseError.bytesErr", err) } else { - assert.Error(t, err) + assert.Error(t, "CloseError.bytesErr", err) } }) } @@ -104,10 +104,10 @@ func Test_parseClosePayload(t *testing.T) { ce, err := parseClosePayload(tc.p) if tc.success { - assert.Success(t, err) - assert.Equal(t, tc.ce, ce, "CloseError") + assert.Success(t, "parse err", err) + assert.Equal(t, "ce", tc.ce, ce) } else { - assert.Error(t, err) + assert.Error(t, "parse err", err) } }) } @@ -153,7 +153,7 @@ func Test_validWireCloseCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tc.valid, validWireCloseCode(tc.code), "validWireCloseCode") + assert.Equal(t, "valid", tc.valid, validWireCloseCode(tc.code)) }) } } @@ -190,7 +190,7 @@ func TestCloseStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tc.exp, CloseStatus(tc.in), "CloseStatus") + assert.Equal(t, "closeStatus", tc.exp, CloseStatus(tc.in)) }) } } diff --git a/conn.go b/conn.go index 10fe2e1a..061c4517 100644 --- a/conn.go +++ b/conn.go @@ -63,7 +63,7 @@ type Conn struct { closed chan struct{} closeMu sync.Mutex closeErr error - wroteClose int64 + wroteClose bool pingCounter int32 activePingsMu sync.Mutex @@ -244,7 +244,9 @@ func (m *mu) Lock(ctx context.Context) error { case <-m.c.closed: return m.c.closeErr case <-ctx.Done(): - return ctx.Err() + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err case m.ch <- struct{}{}: return nil } diff --git a/conn_test.go b/conn_test.go index 9b628cfe..9b311a87 100644 --- a/conn_test.go +++ b/conn_test.go @@ -13,8 +13,9 @@ import ( "testing" "time" + "cdr.dev/slog/sloggers/slogtest/assert" + "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" ) func TestConn(t *testing.T) { @@ -27,7 +28,7 @@ func TestConn(t *testing.T) { InsecureSkipVerify: true, CompressionMode: websocket.CompressionNoContextTakeover, }) - assert.Success(t, err) + assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") err = echoLoop(r.Context(), c) @@ -47,7 +48,7 @@ func TestConn(t *testing.T) { opts.HTTPClient = s.Client() c, _, err := websocket.Dial(ctx, wsURL, opts) - assert.Success(t, err) + assert.Success(t, "dial", err) assertJSONEcho(t, ctx, c, 2) }) } diff --git a/doc.go b/doc.go index 54b7e1ea..6847d537 100644 --- a/doc.go +++ b/doc.go @@ -4,14 +4,17 @@ // // https://tools.ietf.org/html/rfc6455 // -// Use Dial to dial a WebSocket server and Accept to accept a WebSocket client. +// Use Dial to dial a WebSocket server. +// +// Accept to accept a WebSocket client. +// // Conn represents the resulting WebSocket connection. // // The examples are the best way to understand how to correctly use the library. // // The wsjson and wspb subpackages contain helpers for JSON and Protobuf messages. // -// See https://nhooyr.io/websocket for further information. +// More documentation at https://nhooyr.io/websocket. // // Wasm // @@ -23,7 +26,7 @@ // Some important caveats to be aware of: // // - Conn.Ping is no-op -// - HTTPClient, HTTPHeader and CompressionOptions in DialOptions are no-op +// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op // - *http.Response from Dial is &http.Response{} on success // package websocket // import "nhooyr.io/websocket" diff --git a/frame_test.go b/frame_test.go index 68455cfa..323ea991 100644 --- a/frame_test.go +++ b/frame_test.go @@ -13,10 +13,9 @@ import ( "time" _ "unsafe" + "cdr.dev/slog/sloggers/slogtest/assert" "github.com/gobwas/ws" _ "github.com/gorilla/websocket" - - "nhooyr.io/websocket/internal/assert" ) func TestHeader(t *testing.T) { @@ -81,14 +80,14 @@ func testHeader(t *testing.T, h header) { r := bufio.NewReader(b) err := writeFrameHeader(h, w) - assert.Success(t, err) + assert.Success(t, "writeFrameHeader", err) err = w.Flush() - assert.Success(t, err) + assert.Success(t, "flush", err) h2, err := readFrameHeader(r) - assert.Success(t, err) + assert.Success(t, "readFrameHeader", err) - assert.Equal(t, h, h2, "header") + assert.Equal(t, "header", h, h2) } func Test_mask(t *testing.T) { @@ -99,8 +98,8 @@ func Test_mask(t *testing.T) { p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(key32, p) - assert.Equal(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "mask") - assert.Equal(t, bits.RotateLeft32(key32, -8), gotKey32, "mask key") + assert.Equal(t, "mask", []byte{0, 0, 0, 0x0d, 0x6}, p) + assert.Equal(t, "maskKey", bits.RotateLeft32(key32, -8), gotKey32) } func basicMask(maskKey [4]byte, pos int, b []byte) int { diff --git a/go.mod b/go.mod index 6cd368b4..06098485 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,13 @@ module nhooyr.io/websocket go 1.13 require ( + cdr.dev/slog v1.3.0 github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 github.com/golang/protobuf v1.3.2 - github.com/google/go-cmp v0.3.1 github.com/gorilla/websocket v1.4.1 + github.com/mattn/goveralls v0.0.4 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 + golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea // indirect ) diff --git a/go.sum b/go.sum index c639eb64..df11eba9 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,257 @@ +cdr.dev/slog v1.3.0 h1:MYN1BChIaVEGxdS7I5cpdyMC0+WfJfK8BETAfzfLUGQ= +cdr.dev/slog v1.3.0/go.mod h1:C5OL99WyuOK8YHZdYY57dAPN1jK2WJlCdq2VP6xeQns= +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= +cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= +cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= +cloud.google.com/go v0.49.0 h1:CH+lkubJzcPYB1Ggupcq0+k8Ni2ILdG2lYjDIgavDBQ= +cloud.google.com/go v0.49.0/go.mod h1:hGvAdzcWNbyuxS3nWhD7H2cIJxjRRTRLQVB0bdputVY= +cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0= +github.com/GeertJohan/go.rice v1.0.0/go.mod h1:eH6gbSOAUv07dQuZVnBmoDP8mgsM1rtixis4Tib9if0= +github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= +github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 h1:smF2tmSOzy2Mm+0dGI2AIUHY+w0BUc+4tn40djz7+6U= +github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38/go.mod h1:r7bzyVFMNntcxPZXK3/+KdruV1H5KSlyVY0gc+NgInI= +github.com/alecthomas/chroma v0.7.0 h1:z+0HgTUmkpRDRz0SRSdMaqOLfJV4F+N1FPDZUZIDUzw= +github.com/alecthomas/chroma v0.7.0/go.mod h1:1U/PfCsTALWWYHDnsIQkxEBM0+6LLe0v8+RSVMOwxeY= +github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721 h1:JHZL0hZKJ1VENNfmXvHbgYlbUOvpzYzvy2aZU5gXVeo= +github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721/go.mod h1:QO9JBoKquHd+jz9nshCh40fOfO+JzsoXy8qTHF68zU0= +github.com/alecthomas/kong v0.1.17-0.20190424132513-439c674f7ae0/go.mod h1:+inYUSluD+p4L8KdviBSgzcqEjUQOfC5fQDRFuc36lI= +github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUSluD+p4L8KdviBSgzcqEjUQOfC5fQDRFuc36lI= +github.com/alecthomas/kong-hcl v0.1.8-0.20190615233001-b21fea9723c8/go.mod h1:MRgZdU3vrFd05IQ89AxUZ0aYdF39BYoNFa324SodPCA= +github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY= +github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CLnGVd57E= +github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ= +github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.1.6 h1:CqB4MjHw0MFCDj+PHHjiESmHX+N7t0tJzKvC6M97BRg= +github.com/dlclark/regexp2 v1.1.6/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= +github.com/dlclark/regexp2 v1.2.0 h1:8sAhBGEM0dRWogWqWyQeIJnxjWO6oIjl8FKqREDsGfk= +github.com/dlclark/regexp2 v1.2.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= +github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE= +github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.2-0.20191216170541-340f1ebe299e h1:4WfjkTUTsO6siF8ghDQQk6t7x/FPsv3w6MXkc47do7Q= +github.com/google/go-cmp v0.3.2-0.20191216170541-340f1ebe299e/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/gorilla/csrf v1.6.0/go.mod h1:7tSf8kmjNYr7IWDCYhd3U8Ck34iQ/Yw5CJu7bAkHEGI= +github.com/gorilla/handlers v1.4.1/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/mattn/goveralls v0.0.4 h1:/mdWfiU2y8kZ48EtgByYev/XT3W4dkTuKLOJJsh/r+o= +github.com/mattn/goveralls v0.0.4/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/nkovacs/streamquote v0.0.0-20170412213628-49af9bddb229/go.mod h1:0aYXnNPJ8l7uZxf45rWW1a/uME32OF0rhiYGNQ2oF2E= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= +github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= +go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= +go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= +go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= +go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= +golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= +golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 h1:gSbV7h1NRL2G1xTg/owz62CST1oJBmxy4QpMMregXVQ= +golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2 h1:EtTFh6h4SAKemS+CURDMTDIANuduG5zKEXShyy18bGA= +golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea h1:mtRJM/ln5qwEigajtnZtuARALEPOooGf5lwkM5a9tt4= +golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= +google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1 h1:aQktFqmDE2yjveXJlVIfslDFmFnUXSqG0i6KRcJAeMc= +google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.1 h1:j6XxA85m/6txkUCHvzlV5f+HBNl/1r5cZ2A/3IEFOO8= +google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= +google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/internal/assert/assert.go b/internal/assert/assert.go deleted file mode 100644 index b20d9420..00000000 --- a/internal/assert/assert.go +++ /dev/null @@ -1,50 +0,0 @@ -// Package assert contains helpers for test assertions. -package assert - -import ( - "strings" - "testing" -) - -// Equal asserts exp == act. -func Equal(t testing.TB, exp, act interface{}, name string) { - t.Helper() - diff := cmpDiff(exp, act) - if diff != "" { - t.Fatalf("unexpected %v: %v", name, diff) - } -} - -// NotEqual asserts exp != act. -func NotEqual(t testing.TB, exp, act interface{}, name string) { - t.Helper() - if cmpDiff(exp, act) == "" { - t.Fatalf("expected different %v: %+v", name, act) - } -} - -// Success asserts exp == nil. -func Success(t testing.TB, err error) { - t.Helper() - if err != nil { - t.Fatalf("unexpected error: %+v", err) - } -} - -// Error asserts exp != nil. -func Error(t testing.TB, err error) { - t.Helper() - if err == nil { - t.Fatal("expected error") - } -} - -// ErrorContains asserts the error string from err contains sub. -func ErrorContains(t testing.TB, err error, sub string) { - t.Helper() - Error(t, err) - errs := err.Error() - if !strings.Contains(errs, sub) { - t.Fatalf("error string %q does not contain %q", errs, sub) - } -} diff --git a/internal/assert/cmp.go b/internal/assert/cmp.go deleted file mode 100644 index 39be1f4a..00000000 --- a/internal/assert/cmp.go +++ /dev/null @@ -1,53 +0,0 @@ -package assert - -import ( - "reflect" - - "github.com/google/go-cmp/cmp" -) - -// https://github.com/google/go-cmp/issues/40#issuecomment-328615283 -func cmpDiff(exp, act interface{}) string { - return cmp.Diff(exp, act, deepAllowUnexported(exp, act)) -} - -func deepAllowUnexported(vs ...interface{}) cmp.Option { - m := make(map[reflect.Type]struct{}) - for _, v := range vs { - structTypes(reflect.ValueOf(v), m) - } - var typs []interface{} - for t := range m { - typs = append(typs, reflect.New(t).Elem().Interface()) - } - return cmp.AllowUnexported(typs...) -} - -func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { - if !v.IsValid() { - return - } - switch v.Kind() { - case reflect.Ptr: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Interface: - if !v.IsNil() { - structTypes(v.Elem(), m) - } - case reflect.Slice, reflect.Array: - for i := 0; i < v.Len(); i++ { - structTypes(v.Index(i), m) - } - case reflect.Map: - for _, k := range v.MapKeys() { - structTypes(v.MapIndex(k), m) - } - case reflect.Struct: - m[v.Type()] = struct{}{} - for i := 0; i < v.NumField(); i++ { - structTypes(v.Field(i), m) - } - } -} diff --git a/read.go b/read.go index c72b6c17..dc59f9f4 100644 --- a/read.go +++ b/read.go @@ -69,7 +69,7 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // // When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { - c.msgReader.limitReader.setLimit(n) + c.msgReader.limitReader.limit.Store(n) } func newMsgReader(c *Conn) *msgReader { @@ -87,15 +87,17 @@ func newMsgReader(c *Conn) *msgReader { } func (mr *msgReader) initFlateReader() { - mr.flateReader = getFlateReader(readerFunc(mr.read)) - mr.limitReader.reset(mr.flateReader) + mr.deflateReader = getFlateReader(readerFunc(mr.read)) + mr.limitReader.r = mr.deflateReader } func (mr *msgReader) close() { - if mr.c.deflate() && mr.contextTakeover() { - mr.c.readMu.Lock(context.Background()) - putFlateReader(mr.flateReader) - mr.c.readMu.Unlock() + mr.c.readMu.Lock(context.Background()) + defer mr.c.readMu.Unlock() + + if mr.deflateReader != nil { + putFlateReader(mr.deflateReader) + mr.deflateReader = nil } } @@ -266,7 +268,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) - c.writeControl(context.Background(), opClose, ce.bytes()) + c.writeClose(ce.Code, ce.Reason) return err } @@ -302,36 +304,35 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro type msgReader struct { c *Conn - ctx context.Context - - deflate bool - flateReader io.Reader - deflateTail strings.Reader - - limitReader *limitReader + ctx context.Context + deflate bool + deflateReader io.Reader + deflateTail strings.Reader + limitReader *limitReader + fin bool payloadLength int64 maskKey uint32 - fin bool } func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.deflate = h.rsv1 if mr.deflate { - mr.deflateTail.Reset(deflateMessageTail) if !mr.contextTakeover() { mr.initFlateReader() } + mr.deflateTail.Reset(deflateMessageTail) } + + mr.limitReader.reset() mr.setFrame(h) - mr.fin = false } func (mr *msgReader) setFrame(h header) { + mr.fin = h.fin mr.payloadLength = h.payloadLength mr.maskKey = h.maskKey - mr.fin = h.fin } func (mr *msgReader) Read(p []byte) (_ int, err error) { @@ -350,9 +351,9 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { if mr.payloadLength == 0 && mr.fin { if mr.c.deflate() && !mr.contextTakeover() { - if mr.flateReader != nil { - putFlateReader(mr.flateReader) - mr.flateReader = nil + if mr.deflateReader != nil { + putFlateReader(mr.deflateReader) + mr.deflateReader = nil } } return 0, io.EOF @@ -363,12 +364,9 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { - if mr.fin { - if mr.deflate { - n, _ := mr.deflateTail.Read(p) - return n, nil - } - return 0, io.EOF + if mr.fin && mr.deflate { + n, _ := mr.deflateTail.Read(p) + return n, nil } h, err := mr.c.readLoop(mr.ctx) @@ -413,17 +411,13 @@ func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { c: c, } lr.limit.Store(limit) - lr.reset(r) + lr.r = r + lr.reset() return lr } -func (lr *limitReader) reset(r io.Reader) { +func (lr *limitReader) reset() { lr.n = lr.limit.Load() - lr.r = r -} - -func (lr *limitReader) setLimit(limit int64) { - lr.limit.Store(limit) } func (lr *limitReader) Read(p []byte) (int, error) { From 6f6fa430a6e88699b3b8aef5d1b8499100f3e8b9 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 30 Dec 2019 22:03:21 -0500 Subject: [PATCH 20/55] Refactor autobahn --- accept.go | 2 - autobahn_test.go | 319 ++++++++++++++++++++++++++--------------------- close.go | 2 + compress.go | 18 +-- conn.go | 7 +- go.mod | 2 - go.sum | 7 -- read.go | 74 ++++++----- write.go | 45 ++++--- 9 files changed, 260 insertions(+), 216 deletions(-) diff --git a/accept.go b/accept.go index ea7beebd..f16180f0 100644 --- a/accept.go +++ b/accept.go @@ -37,8 +37,6 @@ type AcceptOptions struct { // If used incorrectly your WebSocket server will be open to CSRF attacks. InsecureSkipVerify bool - // CompressionMode sets the compression mode. - // See the docs on CompressionMode. CompressionMode CompressionMode } diff --git a/autobahn_test.go b/autobahn_test.go index 6b3b5b72..16384b27 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -9,7 +9,6 @@ import ( "io/ioutil" "net" "net/http" - "net/http/httptest" "os" "os/exec" "strconv" @@ -17,9 +16,27 @@ import ( "testing" "time" + "cdr.dev/slog/sloggers/slogtest/assert" + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" ) +var excludedAutobahnCases = []string{ + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. + "6.*", "7.5.1", + + // We skip the tests related to requestMaxWindowBits as that is unimplemented due + // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 + "13.3.*", "13.4.*", "13.5.*", "13.6.*", + + "12.*", + "13.*", +} + +var autobahnCases = []string{"*"} + // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahn(t *testing.T) { t.Parallel() @@ -35,19 +52,17 @@ func TestAutobahn(t *testing.T) { func testServerAutobahn(t *testing.T) { t.Parallel() - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) - if err != nil { - t.Logf("server handshake failed: %+v", err) - return - } - echoLoop(r.Context(), c) - })) - defer s.Close() + assert.Success(t, "accept", err) + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() - spec := map[string]interface{}{ + specFile, err := tempJSONFile(map[string]interface{}{ "outdir": "ci/out/wstestServerReports", "servers": []interface{}{ map[string]interface{}{ @@ -55,92 +70,105 @@ func testServerAutobahn(t *testing.T) { "url": strings.Replace(s.URL, "http", "ws", 1), }, }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - "exclude-cases": []string{"6.*", "7.5.1"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } + "cases": autobahnCases, + "exclude-cases": excludedAutobahnCases, + }) + assert.Success(t, "tempJSONFile", err) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) defer cancel() - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + args := []string{"--mode", "fuzzingclient", "--spec", specFile} wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() - if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) - } + _, err = wstest.CombinedOutput() + assert.Success(t, "wstest", err) checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err - } - l.Close() - return l.Addr().String(), nil -} - func testClientAutobahn(t *testing.T) { t.Parallel() - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) - } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() - wsServerURL := "ws://" + serverAddr + wstestURL, closeFn, err := wstestClientServer(ctx) + assert.Success(t, "wstestClient", err) + defer closeFn() - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) + err = waitWS(ctx, wstestURL) + assert.Success(t, "waitWS", err) + + cases, err := wstestCaseCount(ctx, wstestURL) + assert.Success(t, "wstestCaseCount", err) + + t.Run("cases", func(t *testing.T) { + for i := 1; i <= cases; i++ { + i := i + t.Run("", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(ctx, time.Second*45) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) + assert.Success(t, "autobahn dial", err) + + err = echoLoop(ctx, c) + t.Logf("echoLoop: %+v", err) + }) + } + }) + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) + assert.Success(t, "dial", err) + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func waitWS(ctx context.Context, url string) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + for ctx.Err() == nil { + c, _, err := websocket.Dial(ctx, url, nil) + if err != nil { + continue + } + c.Close(websocket.StatusNormalClosure, "") + return nil } - defer specFile.Close() - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) + return ctx.Err() +} + +func wstestClientServer(ctx context.Context) (url string, closeFn func(), err error) { + serverAddr, err := unusedListenAddr() if err != nil { - t.Fatalf("failed to write spec: %v", err) + return "", nil, err } - err = specFile.Close() + url = "ws://" + serverAddr + + specFile, err := tempJSONFile(map[string]interface{}{ + "url": url, + "outdir": "ci/out/wstestClientReports", + "cases": autobahnCases, + "exclude-cases": excludedAutobahnCases, + }) if err != nil { - t.Fatalf("failed to close file: %v", err) + return "", nil, fmt.Errorf("failed to write spec: %w", err) } - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer func() { + if err != nil { + cancel() + } + }() - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), + args := []string{"--mode", "fuzzingserver", "--spec", specFile, // Disables some server that runs as part of fuzzingserver mode. // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 "--webport=0", @@ -148,101 +176,104 @@ func testClientAutobahn(t *testing.T) { wstest := exec.CommandContext(ctx, "wstest", args...) err = wstest.Start() if err != nil { - t.Fatal(err) + return "", nil, fmt.Errorf("failed to start wstest: %w", err) } - defer func() { - err := wstest.Process.Kill() - if err != nil { - t.Error(err) - } - }() - - // Let it come up. - time.Sleep(time.Second * 5) - - var cases int - func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) - } - cases, err = strconv.Atoi(string(b)) - if err != nil { - t.Fatal(err) - } - c.Close(websocket.StatusNormalClosure, "") - }() + return url, func() { + wstest.Process.Kill() + }, nil +} - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) - defer cancel() +func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { + defer errd.Wrap(&err, "failed to get case count") - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } - echoLoop(ctx, c) - }() + c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil) + if err != nil { + return 0, err } + defer c.Close(websocket.StatusInternalError, "") - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) + _, r, err := c.Reader(ctx) + if err != nil { + return 0, err + } + b, err := ioutil.ReadAll(r) + if err != nil { + return 0, err + } + cases, err = strconv.Atoi(string(b)) if err != nil { - t.Fatal(err) + return 0, err } + c.Close(websocket.StatusNormalClosure, "") - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") + return cases, nil } func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } + assert.Success(t, "ioutil.ReadFile", err) var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } + assert.Success(t, "json.Unmarshal", err) - var failed bool for _, tests := range indexJSON { for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) - } + t.Run(test, func(t *testing.T) { + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + t.Errorf("bad close behaviour") + } + + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + t.Errorf("failed") + } + }) } } - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) - } + if t.Failed() { + htmlPath := strings.Replace(path, ".json", ".html", 1) + t.Errorf("detected autobahn violation, see %q", htmlPath) } } + +func unusedListenAddr() (_ string, err error) { + defer errd.Wrap(&err, "failed to get unused listen address") + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +func tempJSONFile(v interface{}) (string, error) { + f, err := ioutil.TempFile("", "temp.json") + if err != nil { + return "", fmt.Errorf("temp file: %w", err) + } + defer f.Close() + + e := json.NewEncoder(f) + e.SetIndent("", "\t") + err = e.Encode(v) + if err != nil { + return "", fmt.Errorf("json encode: %w", err) + } + + err = f.Close() + if err != nil { + return "", fmt.Errorf("close temp file: %w", err) + } + + return f.Name(), nil +} diff --git a/close.go b/close.go index 7ccdb173..c5c51c6e 100644 --- a/close.go +++ b/close.go @@ -147,6 +147,8 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { } func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() diff --git a/compress.go b/compress.go index 2410cb4e..8c4dbe28 100644 --- a/compress.go +++ b/compress.go @@ -9,6 +9,14 @@ import ( "sync" ) +type CompressionOptions struct { + // Mode controls the compression mode. + Mode CompressionMode + + // Threshold controls the minimum size of a message before compression is applied. + Threshold int +} + // CompressionMode controls the modes available RFC 7692's deflate extension. // See https://tools.ietf.org/html/rfc7692 // @@ -29,14 +37,8 @@ const ( // The message will only be compressed if greater than 512 bytes. CompressionNoContextTakeover CompressionMode = iota - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. - // As most WebSocket protocols are repetitive, this can be very efficient. - // - // The message will only be compressed if greater than 128 bytes. - // - // If the peer negotiates NoContextTakeover on the client or server side, it will be - // used instead as this is required by the RFC. + // Unimplemented for now due to limitations in compress/flate. + // See https://github.com/golang/go/issues/31514#issuecomment-569668619 CompressionContextTakeover // CompressionDisabled disables the deflate extension. diff --git a/conn.go b/conn.go index 061c4517..5ccf9f91 100644 --- a/conn.go +++ b/conn.go @@ -176,7 +176,7 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) deflate() bool { +func (c *Conn) flate() bool { return c.copts != nil } @@ -262,5 +262,8 @@ func (m *mu) TryLock() bool { } func (m *mu) Unlock() { - <-m.ch + select { + case <-m.ch: + default: + } } diff --git a/go.mod b/go.mod index 06098485..01ec18f7 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,5 @@ require ( github.com/gobwas/ws v1.0.2 github.com/golang/protobuf v1.3.2 github.com/gorilla/websocket v1.4.1 - github.com/mattn/goveralls v0.0.4 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 - golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea // indirect ) diff --git a/go.sum b/go.sum index df11eba9..864efaa7 100644 --- a/go.sum +++ b/go.sum @@ -102,8 +102,6 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/mattn/goveralls v0.0.4 h1:/mdWfiU2y8kZ48EtgByYev/XT3W4dkTuKLOJJsh/r+o= -github.com/mattn/goveralls v0.0.4/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/nkovacs/streamquote v0.0.0-20170412213628-49af9bddb229/go.mod h1:0aYXnNPJ8l7uZxf45rWW1a/uME32OF0rhiYGNQ2oF2E= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -129,7 +127,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -150,7 +147,6 @@ golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -209,11 +205,8 @@ golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2 h1:EtTFh6h4SAKemS+CURDMTDIANuduG5zKEXShyy18bGA= golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea h1:mtRJM/ln5qwEigajtnZtuARALEPOooGf5lwkM5a9tt4= -golang.org/x/tools v0.0.0-20191218225520-84f0c7cf60ea/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/read.go b/read.go index dc59f9f4..517022b5 100644 --- a/read.go +++ b/read.go @@ -79,7 +79,7 @@ func newMsgReader(c *Conn) *msgReader { } mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) - if c.deflate() && mr.contextTakeover() { + if c.flate() && mr.flateContextTakeover() { mr.initFlateReader() } @@ -87,30 +87,27 @@ func newMsgReader(c *Conn) *msgReader { } func (mr *msgReader) initFlateReader() { - mr.deflateReader = getFlateReader(readerFunc(mr.read)) - mr.limitReader.r = mr.deflateReader + mr.flateReader = getFlateReader(readerFunc(mr.read)) + mr.limitReader.r = mr.flateReader } func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) defer mr.c.readMu.Unlock() - if mr.deflateReader != nil { - putFlateReader(mr.deflateReader) - mr.deflateReader = nil - } + mr.returnFlateReader() } -func (mr *msgReader) contextTakeover() bool { +func (mr *msgReader) flateContextTakeover() bool { if mr.c.client { - return mr.c.copts.serverNoContextTakeover + return !mr.c.copts.serverNoContextTakeover } - return mr.c.copts.clientNoContextTakeover + return !mr.c.copts.clientNoContextTakeover } func (c *Conn) readRSV1Illegal(h header) bool { // If compression is enabled, rsv1 is always illegal. - if !c.deflate() { + if !c.flate() { return true } // rsv1 is only allowed on data frames beginning messages. @@ -269,6 +266,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { err = fmt.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) + c.close(err) return err } @@ -304,11 +302,11 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro type msgReader struct { c *Conn - ctx context.Context - deflate bool - deflateReader io.Reader - deflateTail strings.Reader - limitReader *limitReader + ctx context.Context + deflate bool + flateReader io.Reader + deflateTail strings.Reader + limitReader *limitReader fin bool payloadLength int64 @@ -319,7 +317,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.deflate = h.rsv1 if mr.deflate { - if !mr.contextTakeover() { + if !mr.flateContextTakeover() { mr.initFlateReader() } mr.deflateTail.Reset(deflateMessageTail) @@ -335,8 +333,19 @@ func (mr *msgReader) setFrame(h header) { mr.maskKey = h.maskKey } -func (mr *msgReader) Read(p []byte) (_ int, err error) { +func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { + r := recover() + if r != nil { + if r != "ANMOL" { + panic(r) + } + err = io.EOF + if !mr.flateContextTakeover() { + mr.returnFlateReader() + } + } + errd.Wrap(&err, "failed to read") if errors.Is(err, io.EOF) { err = io.EOF @@ -349,24 +358,27 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { } defer mr.c.readMu.Unlock() - if mr.payloadLength == 0 && mr.fin { - if mr.c.deflate() && !mr.contextTakeover() { - if mr.deflateReader != nil { - putFlateReader(mr.deflateReader) - mr.deflateReader = nil - } - } - return 0, io.EOF - } - return mr.limitReader.Read(p) } +func (mr *msgReader) returnFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { - if mr.fin && mr.deflate { - n, _ := mr.deflateTail.Read(p) - return n, nil + if mr.fin { + if mr.deflate { + if mr.deflateTail.Len() == 0 { + panic("ANMOL") + } + n, _ := mr.deflateTail.Read(p) + return n, nil + } + return 0, io.EOF } h, err := mr.c.readLoop(mr.ctx) diff --git a/write.go b/write.go index 526b3b66..de20e041 100644 --- a/write.go +++ b/write.go @@ -55,7 +55,7 @@ func newMsgWriter(c *Conn) *msgWriter { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } - if c.deflate() && mw.deflateContextTakeover() { + if c.flate() && mw.flateContextTakeover() { mw.ensureFlateWriter() } @@ -63,14 +63,16 @@ func newMsgWriter(c *Conn) *msgWriter { } func (mw *msgWriter) ensureFlateWriter() { - mw.flateWriter = getFlateWriter(mw.trimWriter) + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } } -func (mw *msgWriter) deflateContextTakeover() bool { +func (mw *msgWriter) flateContextTakeover() bool { if mw.c.client { - return mw.c.copts.clientNoContextTakeover + return !mw.c.copts.clientNoContextTakeover } - return mw.c.copts.serverNoContextTakeover + return !mw.c.copts.serverNoContextTakeover } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { @@ -87,7 +89,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return 0, err } - if !c.deflate() { + if !c.flate() { // Fast single frame path. defer c.msgWriter.mu.Unlock() return c.writeFrame(ctx, true, c.msgWriter.opcode, p) @@ -107,11 +109,12 @@ type msgWriter struct { mu *mu - deflate bool - ctx context.Context - opcode opcode - closed bool + ctx context.Context + opcode opcode + closed bool + // TODO pass down into writeFrame + flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer } @@ -125,7 +128,7 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { mw.closed = false mw.ctx = ctx mw.opcode = opcode(typ) - mw.deflate = false + mw.flate = false return nil } @@ -137,13 +140,14 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, errors.New("cannot use closed writer") } - if mw.c.deflate() { - if !mw.deflate { - if !mw.deflateContextTakeover() { + if mw.c.flate() { + if !mw.flate { + mw.flate = true + + if !mw.flateContextTakeover() { mw.ensureFlateWriter() } mw.trimWriter.reset() - mw.deflate = true } return mw.flateWriter.Write(p) @@ -170,7 +174,7 @@ func (mw *msgWriter) Close() (err error) { } mw.closed = true - if mw.c.deflate() { + if mw.flate { err = mw.flateWriter.Flush() if err != nil { return fmt.Errorf("failed to flush flate writer: %w", err) @@ -182,9 +186,9 @@ func (mw *msgWriter) Close() (err error) { return fmt.Errorf("failed to write fin frame: %w", err) } - if mw.deflate && !mw.deflateContextTakeover() { + if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil { putFlateWriter(mw.flateWriter) - mw.deflate = false + mw.flateWriter = nil } mw.mu.Unlock() @@ -192,9 +196,10 @@ func (mw *msgWriter) Close() (err error) { } func (mw *msgWriter) close() { - if mw.c.deflate() && mw.deflateContextTakeover() { + if mw.flateWriter != nil && mw.flateContextTakeover() { mw.mu.Lock(context.Background()) putFlateWriter(mw.flateWriter) + mw.flateWriter = nil } } @@ -236,7 +241,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } c.writeHeader.rsv1 = false - if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) { + if c.flate() && (opcode == opText || opcode == opBinary) { c.writeHeader.rsv1 = true } From 8c87970e1fbf809deab88d9a9637822616a6f676 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 4 Jan 2020 02:19:14 -0500 Subject: [PATCH 21/55] Add slidingWindowReader --- accept.go | 4 +++- autobahn_test.go | 8 ++++++++ compress.go | 52 ++++++++++++++++++++++++++++++++++++++++++++---- dial.go | 6 +++--- 4 files changed, 62 insertions(+), 8 deletions(-) diff --git a/accept.go b/accept.go index f16180f0..f030e4aa 100644 --- a/accept.go +++ b/accept.go @@ -37,7 +37,9 @@ type AcceptOptions struct { // If used incorrectly your WebSocket server will be open to CSRF attacks. InsecureSkipVerify bool - CompressionMode CompressionMode + // CompressionOptions controls the compression options. + // See docs on the CompressionOptions type. + CompressionOptions CompressionOptions } // Accept accepts a WebSocket handshake from a client and upgrades the diff --git a/autobahn_test.go b/autobahn_test.go index 16384b27..1c39887c 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -103,11 +103,19 @@ func testClientAutobahn(t *testing.T) { assert.Success(t, "wstestCaseCount", err) t.Run("cases", func(t *testing.T) { + // Max 8 cases running at a time. + mu := make(chan struct{}, 8) + for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { t.Parallel() + mu <- struct{}{} + defer func() { + <-mu + }() + ctx, cancel := context.WithTimeout(ctx, time.Second*45) defer cancel() diff --git a/compress.go b/compress.go index 8c4dbe28..62cc9cd3 100644 --- a/compress.go +++ b/compress.go @@ -37,8 +37,14 @@ const ( // The message will only be compressed if greater than 512 bytes. CompressionNoContextTakeover CompressionMode = iota - // Unimplemented for now due to limitations in compress/flate. - // See https://github.com/golang/go/issues/31514#issuecomment-569668619 + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // + // The message will only be compressed if greater than 128 bytes. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. CompressionContextTakeover // CompressionDisabled disables the deflate extension. @@ -151,10 +157,10 @@ func putFlateReader(fr io.Reader) { var flateWriterPool sync.Pool -func getFlateWriter(w io.Writer) *flate.Writer { +func getFlateWriter(w io.Writer, dict []byte) *flate.Writer { fw, ok := flateWriterPool.Get().(*flate.Writer) if !ok { - fw, _ = flate.NewWriter(w, flate.BestSpeed) + fw, _ = flate.NewWriterDict(w, flate.BestSpeed, dict) return fw } fw.Reset(w) @@ -164,3 +170,41 @@ func getFlateWriter(w io.Writer) *flate.Writer { func putFlateWriter(w *flate.Writer) { flateWriterPool.Put(w) } + +type slidingWindowReader struct { + window []byte + + r io.Reader +} + +func (r slidingWindowReader) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + p = p[:n] + + r.append(p) + + return n, err +} + +func (r slidingWindowReader) append(p []byte) { + if len(r.window) <= cap(r.window) { + r.window = append(r.window, p...) + } + + if len(p) > cap(r.window) { + p = p[len(p)-cap(r.window):] + } + + // p now contains at max the last window bytes + // so we need to be able to append all of it to r.window. + // Shift as many bytes from r.window as needed. + + // Maximum window size minus current window minus extra gives + // us the number of bytes that need to be shifted. + off := len(r.window) + len(p) - cap(r.window) + + r.window = append(r.window[:0], r.window[off:]...) + copy(r.window, r.window[off:]) + copy(r.window[len(r.window)-len(p):], p) + return +} diff --git a/dial.go b/dial.go index 6cde30e7..43408f20 100644 --- a/dial.go +++ b/dial.go @@ -33,9 +33,9 @@ type DialOptions struct { // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string - // CompressionMode sets the compression mode. - // See the docs on CompressionMode. - CompressionMode CompressionMode + // CompressionOptions controls the compression options. + // See docs on the CompressionOptions type. + CompressionOptions CompressionOptions } // Dial performs a WebSocket handshake on url. From aaf4b458c6a66df98da8375425cb54ec47e9540b Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 25 Jan 2020 20:58:09 -0600 Subject: [PATCH 22/55] Up test coverage of accept.go to 100% --- accept.go | 6 ++- accept_test.go | 140 +++++++++++++++++++++++++++++++++++++++++++++++++ compress.go | 14 +++-- conn_test.go | 10 ++-- dial.go | 4 +- write.go | 2 +- 6 files changed, 164 insertions(+), 12 deletions(-) diff --git a/accept.go b/accept.go index f030e4aa..d9b4bf90 100644 --- a/accept.go +++ b/accept.go @@ -92,7 +92,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionMode) + copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode) if err != nil { return nil, err } @@ -201,7 +201,9 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi case "server_no_context_takeover": copts.serverNoContextTakeover = true continue - case "client_max_window_bits", "server-max-window-bits": + } + + if strings.HasPrefix(p, "client_max_window_bits") || strings.HasPrefix(p, "server_max_window_bits") { continue } diff --git a/accept_test.go b/accept_test.go index 2a784d19..8a9e9198 100644 --- a/accept_test.go +++ b/accept_test.go @@ -3,6 +3,10 @@ package websocket import ( + "bufio" + "errors" + "net" + "net/http" "net/http/httptest" "strings" "testing" @@ -23,6 +27,38 @@ func TestAccept(t *testing.T) { assert.ErrorContains(t, "Accept", err, "protocol violation") }) + t.Run("badOrigin", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Origin", "harhar.com") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host") + }) + + t.Run("badCompression", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + } + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter") + }) + t.Run("requireHttpHijacker", func(t *testing.T) { t.Parallel() @@ -36,6 +72,26 @@ func TestAccept(t *testing.T) { _, err := Accept(w, r, nil) assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") }) + + t.Run("badHijack", func(t *testing.T) { + t.Parallel() + + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { + return nil, nil, errors.New("haha") + }, + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, nil) + assert.ErrorContains(t, "Accept", err, "failed to hijack connection") + }) } func Test_verifyClientHandshake(t *testing.T) { @@ -243,5 +299,89 @@ func Test_authenticateOrigin(t *testing.T) { } func Test_acceptCompression(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + mode CompressionMode + reqSecWebSocketExtensions string + respSecWebSocketExtensions string + expCopts *compressionOptions + error bool + }{ + { + name: "disabled", + mode: CompressionDisabled, + expCopts: nil, + }, + { + name: "noClientSupport", + mode: CompressionNoContextTakeover, + expCopts: nil, + }, + { + name: "permessage-deflate", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", + respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + }, + { + name: "permessage-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "permessage-deflate; meow", + error: true, + }, + { + name: "x-webkit-deflate-frame", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover", + expCopts: &compressionOptions{ + clientNoContextTakeover: true, + serverNoContextTakeover: true, + }, + }, + { + name: "x-webkit-deflate/error", + mode: CompressionNoContextTakeover, + reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits", + error: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) + + w := httptest.NewRecorder() + copts, err := acceptCompression(r, w, tc.mode) + if tc.error { + assert.Error(t, "acceptCompression", err) + return + } + + assert.Success(t, "acceptCompression", err) + assert.Equal(t, "compresssionOpts", tc.expCopts, copts) + assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) + }) + } +} + +type mockHijacker struct { + http.ResponseWriter + hijack func() (net.Conn, *bufio.ReadWriter, error) +} + +var _ http.Hijacker = mockHijacker{} +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return mj.hijack() } diff --git a/compress.go b/compress.go index 62cc9cd3..fd2535cc 100644 --- a/compress.go +++ b/compress.go @@ -9,15 +9,22 @@ import ( "sync" ) +// CompressionOptions represents the available deflate extension options. +// See https://tools.ietf.org/html/rfc7692 type CompressionOptions struct { // Mode controls the compression mode. + // + // See docs on CompressionMode. Mode CompressionMode // Threshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes + // for CompressionContextTakeover. Threshold int } -// CompressionMode controls the modes available RFC 7692's deflate extension. +// CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // // A compatibility layer is implemented for the older deflate-frame extension used @@ -31,7 +38,7 @@ const ( // for every message. This applies to both server and client side. // // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be much lower if the connections + // will not be used but the memory overhead will be lower if the connections // are long lived and seldom used. // // The message will only be compressed if greater than 512 bytes. @@ -40,8 +47,7 @@ const ( // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. - // - // The message will only be compressed if greater than 128 bytes. + // It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. diff --git a/conn_test.go b/conn_test.go index 9b311a87..c8663b47 100644 --- a/conn_test.go +++ b/conn_test.go @@ -26,7 +26,9 @@ func TestConn(t *testing.T) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, - CompressionMode: websocket.CompressionNoContextTakeover, + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionNoContextTakeover, + }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") @@ -42,8 +44,10 @@ func TestConn(t *testing.T) { defer cancel() opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - CompressionMode: websocket.CompressionNoContextTakeover, + Subprotocols: []string{"echo"}, + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionNoContextTakeover, + }, } opts.HTTPClient = s.Client() diff --git a/dial.go b/dial.go index 43408f20..af945011 100644 --- a/dial.go +++ b/dial.go @@ -136,8 +136,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - if opts.CompressionMode != CompressionDisabled { - copts := opts.CompressionMode.opts() + if opts.CompressionOptions.Mode != CompressionDisabled { + copts := opts.CompressionOptions.Mode.opts() copts.setHeader(req.Header) } diff --git a/write.go b/write.go index de20e041..33d20c1d 100644 --- a/write.go +++ b/write.go @@ -64,7 +64,7 @@ func newMsgWriter(c *Conn) *msgWriter { func (mw *msgWriter) ensureFlateWriter() { if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) + mw.flateWriter = getFlateWriter(mw.trimWriter, nil) } } From 6b765363d1e5ce21e6ca3bdb7bde03ecba1a2a98 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 29 Jan 2020 22:08:29 -0600 Subject: [PATCH 23/55] Up dial coverage to 100% --- .github/ISSUE_TEMPLATE.md | 3 - ci/image/Dockerfile | 2 +- conn.go | 2 +- dial.go | 13 +-- dial_test.go | 165 +++++++++++++++++++++++++++++--------- doc.go | 3 +- internal/bpool/bpool.go | 6 +- ws_js.go | 2 +- wspb/wspb.go | 16 ++-- 9 files changed, 151 insertions(+), 61 deletions(-) delete mode 100644 .github/ISSUE_TEMPLATE.md diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md deleted file mode 100644 index 7b580937..00000000 --- a/.github/ISSUE_TEMPLATE.md +++ /dev/null @@ -1,3 +0,0 @@ - diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index bfc05fc8..070c50e6 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -6,7 +6,7 @@ RUN apt-get install -y chromium ENV GOFLAGS="-mod=readonly" ENV PAGER=cat ENV CI=true -ENV MAKEFLAGS="--jobs=8 --output-sync=target" +ENV MAKEFLAGS="--jobs=16 --output-sync=target" RUN npm install -g prettier RUN go get golang.org/x/tools/cmd/stringer diff --git a/conn.go b/conn.go index 5ccf9f91..a0176495 100644 --- a/conn.go +++ b/conn.go @@ -22,7 +22,7 @@ type MessageType int const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. + // MessageBinary is for binary messages like protobufs. MessageBinary ) diff --git a/dial.go b/dial.go index af945011..58c0a9c5 100644 --- a/dial.go +++ b/dial.go @@ -50,10 +50,10 @@ type DialOptions struct { // in net/http to perform WebSocket handshakes. // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - return dial(ctx, u, opts) + return dial(ctx, u, opts, nil) } -func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { +func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { defer errd.Wrap(&err, "failed to WebSocket dial") if opts == nil { @@ -67,7 +67,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions) (_ *Conn, _ *http opts.HTTPHeader = http.Header{} } - secWebSocketKey, err := secWebSocketKey() + secWebSocketKey, err := secWebSocketKey(rand) if err != nil { return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } @@ -148,9 +148,12 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe return resp, nil } -func secWebSocketKey() (string, error) { +func secWebSocketKey(rr io.Reader) (string, error) { + if rr == nil { + rr = rand.Reader + } b := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, b) + _, err := io.ReadFull(rr, b) if err != nil { return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) } diff --git a/dial_test.go b/dial_test.go index 6286f0ff..4314f98e 100644 --- a/dial_test.go +++ b/dial_test.go @@ -4,58 +4,117 @@ package websocket import ( "context" + "crypto/rand" + "io" + "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" "time" + + "cdr.dev/slog/sloggers/slogtest/assert" ) func TestBadDials(t *testing.T) { t.Parallel() - testCases := []struct { - name string - url string - opts *DialOptions - }{ - { - name: "badURL", - url: "://noscheme", - }, - { - name: "badURLScheme", - url: "ftp://nhooyr.io", - }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, + t.Run("badReq", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + rand readerFunc + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, }, }, - }, - { - name: "badTLS", - url: "wss://totallyfake.nhooyr.io", - }, - } + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + { + name: "badReader", + rand: func(p []byte) (int, error) { + return 0, io.EOF + }, + }, + } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - _, _, err := Dial(ctx, tc.url, tc.opts) - if err == nil { - t.Fatalf("expected non nil error: %+v", err) - } + if tc.rand == nil { + tc.rand = rand.Reader.Read + } + + _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) + assert.Error(t, "dial", err) + }) + } + }) + + t.Run("badResponse", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("hi")), + }, nil + }), }) - } + assert.ErrorContains(t, "dial", err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") + }) + + t.Run("badBody", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + rt := func(r *http.Request) (*http.Response, error) { + h := http.Header{} + h.Set("Connection", "Upgrade") + h.Set("Upgrade", "websocket") + h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key"))) + + return &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + Header: h, + Body: ioutil.NopCloser(strings.NewReader("hi")), + }, nil + } + + _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ + HTTPClient: mockHTTPClient(rt), + }) + assert.ErrorContains(t, "dial", err, "response body is not a io.ReadWriteCloser") + }) } func Test_verifyServerHandshake(t *testing.T) { @@ -110,6 +169,26 @@ func Test_verifyServerHandshake(t *testing.T) { }, success: false, }, + { + name: "unsupportedExtension", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Extensions", "meow") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "unsupportedDeflateParam", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, { name: "success", response: func(w http.ResponseWriter) { @@ -131,7 +210,7 @@ func Test_verifyServerHandshake(t *testing.T) { resp := w.Result() r := httptest.NewRequest("GET", "/", nil) - key, err := secWebSocketKey() + key, err := secWebSocketKey(rand.Reader) if err != nil { t.Fatal(err) } @@ -151,3 +230,15 @@ func Test_verifyServerHandshake(t *testing.T) { }) } } + +func mockHTTPClient(fn roundTripperFunc) *http.Client { + return &http.Client{ + Transport: fn, + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} diff --git a/doc.go b/doc.go index 6847d537..c8f5550b 100644 --- a/doc.go +++ b/doc.go @@ -12,7 +12,7 @@ // // The examples are the best way to understand how to correctly use the library. // -// The wsjson and wspb subpackages contain helpers for JSON and Protobuf messages. +// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. // // More documentation at https://nhooyr.io/websocket. // @@ -28,5 +28,4 @@ // - Conn.Ping is no-op // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op // - *http.Response from Dial is &http.Response{} on success -// package websocket // import "nhooyr.io/websocket" diff --git a/internal/bpool/bpool.go b/internal/bpool/bpool.go index e2c5f76a..aa826fba 100644 --- a/internal/bpool/bpool.go +++ b/internal/bpool/bpool.go @@ -5,12 +5,12 @@ import ( "sync" ) -var pool sync.Pool +var bpool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { - b := pool.Get() + b := bpool.Get() if b == nil { return &bytes.Buffer{} } @@ -20,5 +20,5 @@ func Get() *bytes.Buffer { // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() - pool.Put(b) + bpool.Put(b) } diff --git a/ws_js.go b/ws_js.go index 950aa01b..2aaef738 100644 --- a/ws_js.go +++ b/ws_js.go @@ -23,7 +23,7 @@ type MessageType int const ( // MessageText is for UTF-8 encoded text messages like JSON. MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. + // MessageBinary is for binary messages like protobufs. MessageBinary ) diff --git a/wspb/wspb.go b/wspb/wspb.go index 666c6fa5..e43042d5 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -13,14 +13,14 @@ import ( "nhooyr.io/websocket/internal/errd" ) -// Read reads a Protobuf message from c into v. +// Read reads a protobuf message from c into v. // It will reuse buffers in between calls to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return read(ctx, c, v) } func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to read Protobuf message") + defer errd.Wrap(&err, "failed to read protobuf message") typ, r, err := c.Reader(ctx) if err != nil { @@ -29,7 +29,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { if typ != websocket.MessageBinary { c.Close(websocket.StatusUnsupportedData, "expected binary message") - return fmt.Errorf("expected binary message for Protobuf but got: %v", typ) + return fmt.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() @@ -42,21 +42,21 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { err = proto.Unmarshal(b.Bytes(), v) if err != nil { - c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal Protobuf") - return fmt.Errorf("failed to unmarshal Protobuf: %w", err) + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") + return fmt.Errorf("failed to unmarshal protobuf: %w", err) } return nil } -// Write writes the Protobuf message v to c. +// Write writes the protobuf message v to c. // It will reuse buffers in between calls to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { return write(ctx, c, v) } func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { - defer errd.Wrap(&err, "failed to write Protobuf message") + defer errd.Wrap(&err, "failed to write protobuf message") b := bpool.Get() pb := proto.NewBuffer(b.Bytes()) @@ -66,7 +66,7 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) err = pb.Marshal(v) if err != nil { - return fmt.Errorf("failed to marshal Protobuf: %w", err) + return fmt.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) From 0f115ed9aa51110aa5dad556673d9dab606bd479 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 30 Jan 2020 22:19:15 -0600 Subject: [PATCH 24/55] Add Go 1.12 support Closes #182 --- accept.go | 28 ++++++++++++++-------------- accept_test.go | 4 ++-- autobahn_test.go | 11 ++++++----- close.go | 23 ++++++++++++----------- conn.go | 18 +++++++++--------- conn_test.go | 10 ++++------ dial.go | 32 ++++++++++++++++---------------- dial_test.go | 4 +--- example_echo_test.go | 8 ++++---- go.mod | 3 ++- internal/errd/wrap.go | 8 +++----- netconn.go | 5 +++-- read.go | 38 ++++++++++++++++++-------------------- write.go | 41 ++++++++++++++++++----------------------- ws_js.go | 36 ++++++++++++++++++------------------ wsjson/wsjson.go | 9 +++++---- wspb/wspb.go | 8 ++++---- 17 files changed, 139 insertions(+), 147 deletions(-) diff --git a/accept.go b/accept.go index d9b4bf90..ac7f2de1 100644 --- a/accept.go +++ b/accept.go @@ -6,14 +6,14 @@ import ( "bytes" "crypto/sha1" "encoding/base64" - "errors" - "fmt" "io" "net/http" "net/textproto" "net/url" "strings" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -76,7 +76,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con hj, ok := w.(http.Hijacker) if !ok { - err = errors.New("http.ResponseWriter does not implement http.Hijacker") + err = xerrors.New("http.ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) return nil, err } @@ -101,7 +101,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con netConn, brw, err := hj.Hijack() if err != nil { - err = fmt.Errorf("failed to hijack connection: %w", err) + err = xerrors.Errorf("failed to hijack connection: %w", err) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } @@ -122,27 +122,27 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con func verifyClientRequest(r *http.Request) error { if !r.ProtoAtLeast(1, 1) { - return fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + return xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsToken(r.Header, "Connection", "Upgrade") { - return fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + return xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsToken(r.Header, "Upgrade", "websocket") { - return fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + return xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { - return fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + return xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { - return fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + return xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { - return errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") + return xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } return nil @@ -153,10 +153,10 @@ func authenticateOrigin(r *http.Request) error { if origin != "" { u, err := url.Parse(origin) if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err) } if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } } return nil @@ -207,7 +207,7 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi continue } - err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + err := xerrors.Errorf("unsupported permessage-deflate parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } @@ -237,7 +237,7 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com // // Either way, we're only implementing this for webkit which never sends the max_window_bits // parameter so we don't need to worry about it. - err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + err := xerrors.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) http.Error(w, err.Error(), http.StatusBadRequest) return nil, err } diff --git a/accept_test.go b/accept_test.go index 8a9e9198..3e8b1f46 100644 --- a/accept_test.go +++ b/accept_test.go @@ -4,7 +4,6 @@ package websocket import ( "bufio" - "errors" "net" "net/http" "net/http/httptest" @@ -12,6 +11,7 @@ import ( "testing" "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/xerrors" ) func TestAccept(t *testing.T) { @@ -79,7 +79,7 @@ func TestAccept(t *testing.T) { w := mockHijacker{ ResponseWriter: httptest.NewRecorder(), hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) { - return nil, nil, errors.New("haha") + return nil, nil, xerrors.New("haha") }, } diff --git a/autobahn_test.go b/autobahn_test.go index 1c39887c..bcbf8671 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -17,6 +17,7 @@ import ( "time" "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" @@ -166,7 +167,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er "exclude-cases": excludedAutobahnCases, }) if err != nil { - return "", nil, fmt.Errorf("failed to write spec: %w", err) + return "", nil, xerrors.Errorf("failed to write spec: %w", err) } ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) @@ -184,7 +185,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er wstest := exec.CommandContext(ctx, "wstest", args...) err = wstest.Start() if err != nil { - return "", nil, fmt.Errorf("failed to start wstest: %w", err) + return "", nil, xerrors.Errorf("failed to start wstest: %w", err) } return url, func() { @@ -267,7 +268,7 @@ func unusedListenAddr() (_ string, err error) { func tempJSONFile(v interface{}) (string, error) { f, err := ioutil.TempFile("", "temp.json") if err != nil { - return "", fmt.Errorf("temp file: %w", err) + return "", xerrors.Errorf("temp file: %w", err) } defer f.Close() @@ -275,12 +276,12 @@ func tempJSONFile(v interface{}) (string, error) { e.SetIndent("", "\t") err = e.Encode(v) if err != nil { - return "", fmt.Errorf("json encode: %w", err) + return "", xerrors.Errorf("json encode: %w", err) } err = f.Close() if err != nil { - return "", fmt.Errorf("close temp file: %w", err) + return "", xerrors.Errorf("close temp file: %w", err) } return f.Name(), nil diff --git a/close.go b/close.go index c5c51c6e..931160e6 100644 --- a/close.go +++ b/close.go @@ -5,11 +5,12 @@ package websocket import ( "context" "encoding/binary" - "errors" "fmt" "log" "time" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -60,7 +61,7 @@ const ( // CloseError is returned when the connection is closed with a status and reason. // -// Use Go 1.13's errors.As to check for this error. +// Use Go 1.13's xerrors.As to check for this error. // Also see the CloseStatus helper. type CloseError struct { Code StatusCode @@ -71,13 +72,13 @@ func (ce CloseError) Error() string { return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } -// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// CloseStatus is a convenience wrapper around Go 1.13's xerrors.As to grab // the status code from a CloseError. // // -1 will be returned if the passed error is nil or not a CloseError. func CloseStatus(err error) StatusCode { var ce CloseError - if errors.As(err, &ce) { + if xerrors.As(err, &ce) { return ce.Code } return -1 @@ -128,7 +129,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { c.wroteClose = true c.closeMu.Unlock() if closing { - return errors.New("already wrote close") + return xerrors.New("already wrote close") } ce := CloseError{ @@ -136,7 +137,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { Reason: reason, } - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) var p []byte if ce.Code != StatusNoStatusRcvd { @@ -185,7 +186,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ @@ -194,7 +195,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) } return ce, nil @@ -234,11 +235,11 @@ const maxCloseReason = maxControlPayload - 2 func (ce CloseError) bytesErr() ([]byte, error) { if len(ce.Reason) > maxCloseReason { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) } if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) @@ -255,7 +256,7 @@ func (c *Conn) setCloseErr(err error) { func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) } } diff --git a/conn.go b/conn.go index a0176495..ab93e4e6 100644 --- a/conn.go +++ b/conn.go @@ -5,13 +5,13 @@ package websocket import ( "bufio" "context" - "errors" - "fmt" "io" "runtime" "strconv" "sync" "sync/atomic" + + "golang.org/x/xerrors" ) // MessageType represents the type of a WebSocket message. @@ -108,7 +108,7 @@ func newConn(cfg connConfig) *Conn { } runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) + c.close(xerrors.New("connection garbage collected")) }) go c.timeoutLoop() @@ -167,10 +167,10 @@ func (c *Conn) timeoutLoop() { case readCtx = <-c.readTimeout: case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, errors.New("timed out")) + c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) return } } @@ -192,7 +192,7 @@ func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { - return fmt.Errorf("failed to ping: %w", err) + return xerrors.Errorf("failed to ping: %w", err) } return nil } @@ -219,7 +219,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { case <-c.closed: return c.closeErr case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) c.close(err) return err case <-pong: @@ -244,7 +244,7 @@ func (m *mu) Lock(ctx context.Context) error { case <-m.c.closed: return m.c.closeErr case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) m.c.close(err) return err case m.ch <- struct{}{}: diff --git a/conn_test.go b/conn_test.go index c8663b47..a65c332c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,7 +4,6 @@ package websocket_test import ( "context" - "fmt" "io" "net/http" "net/http/httptest" @@ -14,6 +13,7 @@ import ( "time" "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/xerrors" "nhooyr.io/websocket" ) @@ -67,9 +67,7 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), closeFn2 := wsgrace(s.Config) return s, func() { err := closeFn2() - if err != nil { - tb.Fatal(err) - } + assert.Success(tb, "closeFn", err) } } @@ -96,7 +94,7 @@ func wsgrace(s *http.Server) (closeFn func() error) { err := s.Shutdown(ctx) if err != nil { - return fmt.Errorf("server shutdown failed: %v", err) + return xerrors.Errorf("server shutdown failed: %v", err) } t := time.NewTicker(time.Millisecond * 10) @@ -108,7 +106,7 @@ func wsgrace(s *http.Server) (closeFn func() error) { return nil } case <-ctx.Done(): - return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) + return xerrors.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) } } } diff --git a/dial.go b/dial.go index 58c0a9c5..f53d30ee 100644 --- a/dial.go +++ b/dial.go @@ -8,8 +8,6 @@ import ( "context" "crypto/rand" "encoding/base64" - "errors" - "fmt" "io" "io/ioutil" "net/http" @@ -17,6 +15,8 @@ import ( "strings" "sync" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -69,7 +69,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( secWebSocketKey, err := secWebSocketKey(rand) if err != nil { - return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + return nil, nil, xerrors.Errorf("failed to generate Sec-WebSocket-Key: %w", err) } resp, err := handshakeRequest(ctx, urls, opts, secWebSocketKey) @@ -95,7 +95,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( rwc, ok := respBody.(io.ReadWriteCloser) if !ok { - return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) + return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) } return newConn(connConfig{ @@ -110,12 +110,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWebSocketKey string) (*http.Response, error) { if opts.HTTPClient.Timeout > 0 { - return nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + return nil, xerrors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } u, err := url.Parse(urls) if err != nil { - return nil, fmt.Errorf("failed to parse url: %w", err) + return nil, xerrors.Errorf("failed to parse url: %w", err) } switch u.Scheme { @@ -124,7 +124,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe case "wss": u.Scheme = "https" default: - return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) + return nil, xerrors.Errorf("unexpected url scheme: %q", u.Scheme) } req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) @@ -143,7 +143,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe resp, err := opts.HTTPClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to send handshake request: %w", err) + return nil, xerrors.Errorf("failed to send handshake request: %w", err) } return resp, nil } @@ -155,26 +155,26 @@ func secWebSocketKey(rr io.Reader) (string, error) { b := make([]byte, 16) _, err := io.ReadFull(rr, b) if err != nil { - return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + return "", xerrors.Errorf("failed to read random data from rand.Reader: %w", err) } return base64.StdEncoding.EncodeToString(b), nil } func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + return nil, xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } if !headerContainsToken(resp.Header, "Connection", "Upgrade") { - return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + return nil, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { - return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + return nil, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { - return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + return nil, xerrors.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", resp.Header.Get("Sec-WebSocket-Accept"), secWebSocketKey, ) @@ -200,7 +200,7 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error { } } - return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + return xerrors.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } func verifyServerExtensions(h http.Header) (*compressionOptions, error) { @@ -211,7 +211,7 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { ext := exts[0] if ext.name != "permessage-deflate" || len(exts) > 1 { - return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) + return nil, xerrors.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) } copts := &compressionOptions{} @@ -222,7 +222,7 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) { case "server_no_context_takeover": copts.serverNoContextTakeover = true default: - return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + return nil, xerrors.Errorf("unsupported permessage-deflate parameter: %q", p) } } diff --git a/dial_test.go b/dial_test.go index 4314f98e..3be52208 100644 --- a/dial_test.go +++ b/dial_test.go @@ -211,9 +211,7 @@ func Test_verifyServerHandshake(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) key, err := secWebSocketKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + assert.Success(t, "secWebSocketKey", err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { diff --git a/example_echo_test.go b/example_echo_test.go index cd195d2e..1daec8a5 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -4,7 +4,6 @@ package websocket_test import ( "context" - "errors" "fmt" "io" "log" @@ -13,6 +12,7 @@ import ( "time" "golang.org/x/time/rate" + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" @@ -78,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") - return errors.New("client does not speak echo sub protocol") + return xerrors.New("client does not speak echo sub protocol") } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) @@ -88,7 +88,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { return nil } if err != nil { - return fmt.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) + return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) } } } @@ -117,7 +117,7 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { _, err = io.Copy(w, r) if err != nil { - return fmt.Errorf("failed to io.Copy: %w", err) + return xerrors.Errorf("failed to io.Copy: %w", err) } err = w.Close() diff --git a/go.mod b/go.mod index 01ec18f7..5dc9b261 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module nhooyr.io/websocket -go 1.13 +go 1.12 require ( cdr.dev/slog v1.3.0 @@ -10,4 +10,5 @@ require ( github.com/golang/protobuf v1.3.2 github.com/gorilla/websocket v1.4.1 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 + golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go index 6e779131..20de7743 100644 --- a/internal/errd/wrap.go +++ b/internal/errd/wrap.go @@ -1,14 +1,12 @@ package errd -import ( - "fmt" -) +import "golang.org/x/xerrors" -// Wrap wraps err with fmt.Errorf if err is non nil. +// Wrap wraps err with xerrors.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { - *err = fmt.Errorf(f+": %w", append(v, *err)...) + *err = xerrors.Errorf(f+": %w", append(v, *err)...) } } diff --git a/netconn.go b/netconn.go index 64aadf0b..a2d8f4f3 100644 --- a/netconn.go +++ b/netconn.go @@ -2,12 +2,13 @@ package websocket import ( "context" - "fmt" "io" "math" "net" "sync" "time" + + "golang.org/x/xerrors" ) // NetConn converts a *websocket.Conn into a net.Conn. @@ -107,7 +108,7 @@ func (c *netConn) Read(p []byte) (int, error) { return 0, err } if typ != c.msgType { - err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) + err := xerrors.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) c.c.Close(StatusUnsupportedData, err.Error()) return 0, err } diff --git a/read.go b/read.go index 517022b5..4b94f067 100644 --- a/read.go +++ b/read.go @@ -4,14 +4,14 @@ package websocket import ( "context" - "errors" - "fmt" "io" "io/ioutil" "strings" "sync/atomic" "time" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -79,10 +79,6 @@ func newMsgReader(c *Conn) *msgReader { } mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) - if c.flate() && mr.flateContextTakeover() { - mr.initFlateReader() - } - return mr } @@ -125,13 +121,13 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { } if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { - err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + err := xerrors.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.writeError(StatusProtocolError, err) return header{}, err } if !c.client && !h.masked { - return header{}, errors.New("received unmasked frame from client") + return header{}, xerrors.New("received unmasked frame from client") } switch h.opcode { @@ -142,12 +138,12 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) { if h.opcode == opClose && CloseStatus(err) != -1 { return header{}, err } - return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) + return header{}, xerrors.Errorf("failed to handle control frame %v: %w", h.opcode, err) } case opContinuation, opText, opBinary: return h, nil default: - err := fmt.Errorf("received unknown opcode %v", h.opcode) + err := xerrors.Errorf("received unknown opcode %v", h.opcode) c.writeError(StatusProtocolError, err) return header{}, err } @@ -198,7 +194,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { case <-ctx.Done(): return n, ctx.Err() default: - err = fmt.Errorf("failed to read frame payload: %w", err) + err = xerrors.Errorf("failed to read frame payload: %w", err) c.close(err) return n, err } @@ -215,13 +211,13 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { func (c *Conn) handleControl(ctx context.Context, h header) (err error) { if h.payloadLength < 0 || h.payloadLength > maxControlPayload { - err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + err := xerrors.Errorf("received control frame payload with invalid length: %d", h.payloadLength) c.writeError(StatusProtocolError, err) return err } if !h.fin { - err := errors.New("received fragmented control frame") + err := xerrors.New("received fragmented control frame") c.writeError(StatusProtocolError, err) return err } @@ -258,12 +254,12 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { ce, err := parseClosePayload(b) if err != nil { - err = fmt.Errorf("received invalid close payload: %w", err) + err = xerrors.Errorf("received invalid close payload: %w", err) c.writeError(StatusProtocolError, err) return err } - err = fmt.Errorf("received close frame: %w", ce) + err = xerrors.Errorf("received close frame: %w", ce) c.setCloseErr(err) c.writeClose(ce.Code, ce.Reason) c.close(err) @@ -280,7 +276,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro defer c.readMu.Unlock() if !c.msgReader.fin { - return 0, nil, errors.New("previous message not read to completion") + return 0, nil, xerrors.New("previous message not read to completion") } h, err := c.readLoop(ctx) @@ -289,7 +285,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro } if h.opcode == opContinuation { - err := errors.New("received continuation frame without text or binary frame") + err := xerrors.New("received continuation frame without text or binary frame") c.writeError(StatusProtocolError, err) return 0, nil, err } @@ -347,7 +343,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } errd.Wrap(&err, "failed to read") - if errors.Is(err, io.EOF) { + if xerrors.Is(err, io.EOF) { err = io.EOF } }() @@ -386,7 +382,7 @@ func (mr *msgReader) read(p []byte) (int, error) { return 0, err } if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") + err := xerrors.New("received new data message without finishing the previous message") mr.c.writeError(StatusProtocolError, err) return 0, err } @@ -434,7 +430,7 @@ func (lr *limitReader) reset() { func (lr *limitReader) Read(p []byte) (int, error) { if lr.n <= 0 { - err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) + err := xerrors.Errorf("read limited at %v bytes", lr.limit.Load()) lr.c.writeError(StatusMessageTooBig, err) return 0, err } @@ -448,6 +444,8 @@ func (lr *limitReader) Read(p []byte) (int, error) { } type atomicInt64 struct { + // We do not use atomic.Load/StoreInt64 since it does not + // work on 32 bit computers but we need 64 bit integers. i atomic.Value } diff --git a/write.go b/write.go index 33d20c1d..db47ddbc 100644 --- a/write.go +++ b/write.go @@ -8,11 +8,11 @@ import ( "context" "crypto/rand" "encoding/binary" - "errors" - "fmt" "io" "time" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -28,7 +28,7 @@ import ( func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { - return nil, fmt.Errorf("failed to get writer: %w", err) + return nil, xerrors.Errorf("failed to get writer: %w", err) } return w, nil } @@ -42,7 +42,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { - return fmt.Errorf("failed to write msg: %w", err) + return xerrors.Errorf("failed to write msg: %w", err) } return nil } @@ -55,10 +55,6 @@ func newMsgWriter(c *Conn) *msgWriter { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } - if c.flate() && mw.flateContextTakeover() { - mw.ensureFlateWriter() - } - return mw } @@ -92,7 +88,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error if !c.flate() { // Fast single frame path. defer c.msgWriter.mu.Unlock() - return c.writeFrame(ctx, true, c.msgWriter.opcode, p) + return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } n, err := mw.Write(p) @@ -113,7 +109,6 @@ type msgWriter struct { opcode opcode closed bool - // TODO pass down into writeFrame flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer @@ -137,7 +132,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") if mw.closed { - return 0, errors.New("cannot use closed writer") + return 0, xerrors.New("cannot use closed writer") } if mw.c.flate() { @@ -157,9 +152,9 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { } func (mw *msgWriter) write(p []byte) (int, error) { - n, err := mw.c.writeFrame(mw.ctx, false, mw.opcode, p) + n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { - return n, fmt.Errorf("failed to write data frame: %w", err) + return n, xerrors.Errorf("failed to write data frame: %w", err) } mw.opcode = opContinuation return n, nil @@ -170,20 +165,20 @@ func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") if mw.closed { - return errors.New("cannot use closed writer") + return xerrors.New("cannot use closed writer") } mw.closed = true if mw.flate { err = mw.flateWriter.Flush() if err != nil { - return fmt.Errorf("failed to flush flate writer: %w", err) + return xerrors.Errorf("failed to flush flate writer: %w", err) } } - _, err = mw.c.writeFrame(mw.ctx, true, mw.opcode, nil) + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { - return fmt.Errorf("failed to write fin frame: %w", err) + return xerrors.Errorf("failed to write fin frame: %w", err) } if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil { @@ -207,15 +202,15 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - _, err := c.writeFrame(ctx, true, opcode, p) + _, err := c.writeFrame(ctx, true, false, opcode, p) if err != nil { - return fmt.Errorf("failed to write control frame %v: %w", opcode, err) + return xerrors.Errorf("failed to write control frame %v: %w", opcode, err) } return nil } // frame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { +func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) { err := c.writeFrameMu.Lock(ctx) if err != nil { return 0, err @@ -236,12 +231,12 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte c.writeHeader.masked = true err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) if err != nil { - return 0, fmt.Errorf("failed to generate masking key: %w", err) + return 0, xerrors.Errorf("failed to generate masking key: %w", err) } } c.writeHeader.rsv1 = false - if c.flate() && (opcode == opText || opcode == opBinary) { + if flate && (opcode == opText || opcode == opBinary) { c.writeHeader.rsv1 = true } @@ -258,7 +253,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte if c.writeHeader.fin { err = c.bw.Flush() if err != nil { - return n, fmt.Errorf("failed to flush: %w", err) + return n, xerrors.Errorf("failed to flush: %w", err) } } diff --git a/ws_js.go b/ws_js.go index 2aaef738..3ce6f34d 100644 --- a/ws_js.go +++ b/ws_js.go @@ -3,14 +3,14 @@ package websocket // import "nhooyr.io/websocket" import ( "bytes" "context" - "errors" - "fmt" "io" "reflect" "runtime" "sync" "syscall/js" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" ) @@ -55,7 +55,7 @@ func (c *Conn) close(err error, wasClean bool) { runtime.SetFinalizer(c, nil) if !wasClean { - err = fmt.Errorf("unclean connection close: %w", err) + err = xerrors.Errorf("unclean connection close: %w", err) } c.setCloseErr(err) c.closeWasClean = wasClean @@ -100,7 +100,7 @@ func (c *Conn) init() { }) runtime.SetFinalizer(c, func(c *Conn) { - c.setCloseErr(errors.New("connection garbage collected")) + c.setCloseErr(xerrors.New("connection garbage collected")) c.closeWithInternal() }) } @@ -113,15 +113,15 @@ func (c *Conn) closeWithInternal() { // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, errors.New("WebSocket connection read closed") + return 0, nil, xerrors.New("WebSocket connection read closed") } typ, p, err := c.read(ctx) if err != nil { - return 0, nil, fmt.Errorf("failed to read: %w", err) + return 0, nil, xerrors.Errorf("failed to read: %w", err) } if int64(len(p)) > c.msgReadLimit.Load() { - err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit) + err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } @@ -174,7 +174,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { // to match the Go API. It can only error if the message type // is unexpected or the passed bytes contain invalid UTF-8 for // MessageText. - err := fmt.Errorf("failed to write: %w", err) + err := xerrors.Errorf("failed to write: %w", err) c.setCloseErr(err) c.closeWithInternal() return err @@ -192,7 +192,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { case MessageText: return c.ws.SendText(string(p)) default: - return fmt.Errorf("unexpected message type: %v", typ) + return xerrors.Errorf("unexpected message type: %v", typ) } } @@ -203,7 +203,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return fmt.Errorf("failed to close WebSocket: %w", err) + return xerrors.Errorf("failed to close WebSocket: %w", err) } return nil } @@ -212,13 +212,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { c.closingMu.Lock() defer c.closingMu.Unlock() - ce := fmt.Errorf("sent close: %w", CloseError{ + ce := xerrors.Errorf("sent close: %w", CloseError{ Code: code, Reason: reason, }) if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) + return xerrors.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } c.setCloseErr(ce) @@ -253,7 +253,7 @@ type DialOptions struct { func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { c, err := dial(ctx, url, opts) if err != nil { - return nil, resp, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) + return nil, resp, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) } return c, nil } @@ -325,25 +325,25 @@ type writer struct { func (w writer) Write(p []byte) (int, error) { if w.closed { - return 0, errors.New("cannot write to closed writer") + return 0, xerrors.New("cannot write to closed writer") } n, err := w.b.Write(p) if err != nil { - return n, fmt.Errorf("failed to write message: %w", err) + return n, xerrors.Errorf("failed to write message: %w", err) } return n, nil } func (w writer) Close() error { if w.closed { - return errors.New("cannot close closed writer") + return xerrors.New("cannot close closed writer") } w.closed = true defer bpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { - return fmt.Errorf("failed to close writer: %w", err) + return xerrors.Errorf("failed to close writer: %w", err) } return nil } @@ -368,7 +368,7 @@ func (c *Conn) SetReadLimit(n int64) { func (c *Conn) setCloseErr(err error) { c.closeErrOnce.Do(func() { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) }) } diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 99996a69..e6f06a2f 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -4,7 +4,8 @@ package wsjson // import "nhooyr.io/websocket/wsjson" import ( "context" "encoding/json" - "fmt" + + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" @@ -27,7 +28,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { if typ != websocket.MessageText { c.Close(websocket.StatusUnsupportedData, "expected text message") - return fmt.Errorf("expected text message for JSON but got: %v", typ) + return xerrors.Errorf("expected text message for JSON but got: %v", typ) } b := bpool.Get() @@ -41,7 +42,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") - return fmt.Errorf("failed to unmarshal JSON: %w", err) + return xerrors.Errorf("failed to unmarshal JSON: %w", err) } return nil @@ -65,7 +66,7 @@ func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { // a copy of the byte slice but Encoder does as it directly writes to w. err = json.NewEncoder(w).Encode(v) if err != nil { - return fmt.Errorf("failed to marshal JSON: %w", err) + return xerrors.Errorf("failed to marshal JSON: %w", err) } return w.Close() diff --git a/wspb/wspb.go b/wspb/wspb.go index e43042d5..06ac3368 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -4,9 +4,9 @@ package wspb // import "nhooyr.io/websocket/wspb" import ( "bytes" "context" - "fmt" "github.com/golang/protobuf/proto" + "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" @@ -29,7 +29,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { if typ != websocket.MessageBinary { c.Close(websocket.StatusUnsupportedData, "expected binary message") - return fmt.Errorf("expected binary message for protobuf but got: %v", typ) + return xerrors.Errorf("expected binary message for protobuf but got: %v", typ) } b := bpool.Get() @@ -43,7 +43,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) { err = proto.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf") - return fmt.Errorf("failed to unmarshal protobuf: %w", err) + return xerrors.Errorf("failed to unmarshal protobuf: %w", err) } return nil @@ -66,7 +66,7 @@ func write(ctx context.Context, c *websocket.Conn, v proto.Message) (err error) err = pb.Marshal(v) if err != nil { - return fmt.Errorf("failed to marshal protobuf: %w", err) + return xerrors.Errorf("failed to marshal protobuf: %w", err) } return c.Write(ctx, websocket.MessageBinary, pb.Bytes()) From b6b56b7499ee09561b87ad3de17709a59f839952 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 5 Feb 2020 00:21:26 -0600 Subject: [PATCH 25/55] Both modes seem to work :) --- accept.go | 14 ++++---- assert_test.go | 3 +- compress.go | 58 +++++++++++++++------------------ compress_test.go | 45 ++++++++++++++++++++++++++ conn.go | 41 ++++++++++++++---------- conn_test.go | 7 ++-- dial.go | 13 ++++---- read.go | 74 ++++++++++++++++++++++-------------------- write.go | 83 ++++++++++++++++++++++++------------------------ 9 files changed, 196 insertions(+), 142 deletions(-) create mode 100644 compress_test.go diff --git a/accept.go b/accept.go index ac7f2de1..0394fa6d 100644 --- a/accept.go +++ b/accept.go @@ -111,12 +111,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) return newConn(connConfig{ - subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), - rwc: netConn, - client: false, - copts: copts, - br: brw.Reader, - bw: brw.Writer, + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, + flateThreshold: opts.CompressionOptions.Threshold, + + br: brw.Reader, + bw: brw.Writer, }), nil } diff --git a/assert_test.go b/assert_test.go index cd78fbb3..5307ee8e 100644 --- a/assert_test.go +++ b/assert_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "cdr.dev/slog" "cdr.dev/slog/sloggers/slogtest/assert" "nhooyr.io/websocket" @@ -33,7 +34,7 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { - t.Helper() + slog.Helper() var act interface{} err := wsjson.Read(ctx, c, &act) diff --git a/compress.go b/compress.go index fd2535cc..efd89b33 100644 --- a/compress.go +++ b/compress.go @@ -148,12 +148,12 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { var flateReaderPool sync.Pool -func getFlateReader(r io.Reader) io.Reader { +func getFlateReader(r io.Reader, dict []byte) io.Reader { fr, ok := flateReaderPool.Get().(io.Reader) if !ok { - return flate.NewReader(r) + return flate.NewReaderDict(r, dict) } - fr.(flate.Resetter).Reset(r, nil) + fr.(flate.Resetter).Reset(r, dict) return fr } @@ -163,10 +163,10 @@ func putFlateReader(fr io.Reader) { var flateWriterPool sync.Pool -func getFlateWriter(w io.Writer, dict []byte) *flate.Writer { +func getFlateWriter(w io.Writer) *flate.Writer { fw, ok := flateWriterPool.Get().(*flate.Writer) if !ok { - fw, _ = flate.NewWriterDict(w, flate.BestSpeed, dict) + fw, _ = flate.NewWriter(w, flate.BestSpeed) return fw } fw.Reset(w) @@ -177,40 +177,32 @@ func putFlateWriter(w *flate.Writer) { flateWriterPool.Put(w) } -type slidingWindowReader struct { - window []byte - - r io.Reader +type slidingWindow struct { + r io.Reader + buf []byte } -func (r slidingWindowReader) Read(p []byte) (int, error) { - n, err := r.r.Read(p) - p = p[:n] - - r.append(p) - - return n, err +func newSlidingWindow(n int) *slidingWindow { + return &slidingWindow{ + buf: make([]byte, 0, n), + } } -func (r slidingWindowReader) append(p []byte) { - if len(r.window) <= cap(r.window) { - r.window = append(r.window, p...) +func (w *slidingWindow) write(p []byte) { + if len(p) >= cap(w.buf) { + w.buf = w.buf[:cap(w.buf)] + p = p[len(p)-cap(w.buf):] + copy(w.buf, p) + return } - if len(p) > cap(r.window) { - p = p[len(p)-cap(r.window):] + left := cap(w.buf) - len(w.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(w.buf, w.buf[spaceNeeded:]) + w.buf = w.buf[:len(w.buf)-spaceNeeded] } - // p now contains at max the last window bytes - // so we need to be able to append all of it to r.window. - // Shift as many bytes from r.window as needed. - - // Maximum window size minus current window minus extra gives - // us the number of bytes that need to be shifted. - off := len(r.window) + len(p) - cap(r.window) - - r.window = append(r.window[:0], r.window[off:]...) - copy(r.window, r.window[off:]) - copy(r.window[len(r.window)-len(p):], p) - return + w.buf = append(w.buf, p...) } diff --git a/compress_test.go b/compress_test.go new file mode 100644 index 00000000..6edfcb1a --- /dev/null +++ b/compress_test.go @@ -0,0 +1,45 @@ +package websocket + +import ( + "crypto/rand" + "encoding/base64" + "math/big" + "strings" + "testing" + + "cdr.dev/slog/sloggers/slogtest/assert" +) + +func Test_slidingWindow(t *testing.T) { + t.Parallel() + + const testCount = 99 + const maxWindow = 99999 + for i := 0; i < testCount; i++ { + input := randStr(t, maxWindow) + windowLength := randInt(t, maxWindow) + r := newSlidingWindow(windowLength) + r.write([]byte(input)) + + if cap(r.buf) != windowLength { + t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) + } + assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf))) + } +} + +func randStr(t *testing.T, max int) string { + n := randInt(t, max) + + b := make([]byte, n) + _, err := rand.Read(b) + assert.Success(t, "rand.Read", err) + + return base64.StdEncoding.EncodeToString(b) +} + +func randInt(t *testing.T, max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + assert.Success(t, "rand.Int", err) + return int(x.Int64()) +} diff --git a/conn.go b/conn.go index ab93e4e6..2d36123f 100644 --- a/conn.go +++ b/conn.go @@ -38,12 +38,13 @@ const ( // On any error from any method, the connection is closed // with an appropriate reason. type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - br *bufio.Reader - bw *bufio.Writer + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer readTimeout chan context.Context writeTimeout chan context.Context @@ -71,10 +72,11 @@ type Conn struct { } type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int br *bufio.Reader bw *bufio.Writer @@ -82,10 +84,11 @@ type connConfig struct { func newConn(cfg connConfig) *Conn { c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, br: cfg.br, bw: cfg.bw, @@ -96,6 +99,12 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } + if c.flateThreshold == 0 { + c.flateThreshold = 256 + if c.writeNoContextTakeOver() { + c.flateThreshold = 512 + } + } c.readMu = newMu(c) c.writeFrameMu = newMu(c) @@ -145,12 +154,10 @@ func (c *Conn) close(err error) { } c.msgWriter.close() + c.msgReader.close() if c.client { - c.readMu.Lock(context.Background()) putBufioReader(c.br) - c.readMu.Unlock() } - c.msgReader.close() }() } diff --git a/conn_test.go b/conn_test.go index a65c332c..7186da8a 100644 --- a/conn_test.go +++ b/conn_test.go @@ -27,13 +27,15 @@ func TestConn(t *testing.T) { Subprotocols: []string{"echo"}, InsecureSkipVerify: true, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionNoContextTakeover, + Mode: websocket.CompressionContextTakeover, + Threshold: 1, }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") err = echoLoop(r.Context(), c) + t.Logf("server: %v", err) assertCloseStatus(t, websocket.StatusNormalClosure, err) }, false) defer closeFn() @@ -46,7 +48,8 @@ func TestConn(t *testing.T) { opts := &websocket.DialOptions{ Subprotocols: []string{"echo"}, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionNoContextTakeover, + Mode: websocket.CompressionContextTakeover, + Threshold: 1, }, } opts.HTTPClient = s.Client() diff --git a/dial.go b/dial.go index f53d30ee..4557602e 100644 --- a/dial.go +++ b/dial.go @@ -99,12 +99,13 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( } return newConn(connConfig{ - subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - rwc: rwc, - client: true, - copts: copts, - br: getBufioReader(rwc), - bw: getBufioWriter(rwc), + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + rwc: rwc, + client: true, + copts: copts, + flateThreshold: opts.CompressionOptions.Threshold, + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), }), resp, nil } diff --git a/read.go b/read.go index 4b94f067..73ec0b32 100644 --- a/read.go +++ b/read.go @@ -72,25 +72,40 @@ func (c *Conn) SetReadLimit(n int64) { c.msgReader.limitReader.limit.Store(n) } +const defaultReadLimit = 32768 + func newMsgReader(c *Conn) *msgReader { mr := &msgReader{ c: c, fin: true, } - mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) + mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit) return mr } -func (mr *msgReader) initFlateReader() { - mr.flateReader = getFlateReader(readerFunc(mr.read)) +func (mr *msgReader) ensureFlate() { + if mr.flateContextTakeover() && mr.dict == nil { + mr.dict = newSlidingWindow(32768) + } + + if mr.flateContextTakeover() { + mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) + } else { + mr.flateReader = getFlateReader(readerFunc(mr.read), nil) + } mr.limitReader.r = mr.flateReader } +func (mr *msgReader) returnFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) - defer mr.c.readMu.Unlock() - mr.returnFlateReader() } @@ -299,10 +314,11 @@ type msgReader struct { c *Conn ctx context.Context - deflate bool + flate bool flateReader io.Reader - deflateTail strings.Reader + flateTail strings.Reader limitReader *limitReader + dict *slidingWindow fin bool payloadLength int64 @@ -311,12 +327,10 @@ type msgReader struct { func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx - mr.deflate = h.rsv1 - if mr.deflate { - if !mr.flateContextTakeover() { - mr.initFlateReader() - } - mr.deflateTail.Reset(deflateMessageTail) + mr.flate = h.rsv1 + if mr.flate { + mr.ensureFlate() + mr.flateTail.Reset(deflateMessageTail) } mr.limitReader.reset() @@ -331,18 +345,10 @@ func (mr *msgReader) setFrame(h header) { func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { - r := recover() - if r != nil { - if r != "ANMOL" { - panic(r) - } + errd.Wrap(&err, "failed to read") + if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { err = io.EOF - if !mr.flateContextTakeover() { - mr.returnFlateReader() - } } - - errd.Wrap(&err, "failed to read") if xerrors.Is(err, io.EOF) { err = io.EOF } @@ -354,25 +360,23 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } defer mr.c.readMu.Unlock() - return mr.limitReader.Read(p) -} - -func (mr *msgReader) returnFlateReader() { - if mr.flateReader != nil { - putFlateReader(mr.flateReader) - mr.flateReader = nil + n, err = mr.limitReader.Read(p) + if mr.flateContextTakeover() { + p = p[:n] + mr.dict.write(p) } + return n, err } func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { if mr.fin { - if mr.deflate { - if mr.deflateTail.Len() == 0 { - panic("ANMOL") + if mr.flate { + n, err := mr.flateTail.Read(p) + if xerrors.Is(err, io.EOF) { + mr.returnFlateReader() } - n, _ := mr.deflateTail.Read(p) - return n, nil + return n, err } return 0, io.EOF } diff --git a/write.go b/write.go index db47ddbc..a7fa5f5a 100644 --- a/write.go +++ b/write.go @@ -37,8 +37,8 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // // See the Writer method if you want to stream a message. // -// If compression is disabled, then it is guaranteed to write the message -// in a single frame. +// If compression is disabled or the threshold is not met, then it +// will write the message in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { _, err := c.write(ctx, typ, p) if err != nil { @@ -47,20 +47,38 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { return nil } +type msgWriter struct { + c *Conn + + mu *mu + + ctx context.Context + opcode opcode + closed bool + flate bool + + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer +} + func newMsgWriter(c *Conn) *msgWriter { mw := &msgWriter{ c: c, mu: newMu(c), } - mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), - } return mw } -func (mw *msgWriter) ensureFlateWriter() { +func (mw *msgWriter) ensureFlate() { if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter, nil) + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + } + + mw.flateWriter = getFlateWriter(mw.trimWriter) + mw.flate = true } } @@ -85,8 +103,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return 0, err } - if !c.flate() { - // Fast single frame path. + if !c.flate() || len(p) < c.flateThreshold { defer c.msgWriter.mu.Unlock() return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } @@ -100,20 +117,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -type msgWriter struct { - c *Conn - - mu *mu - - ctx context.Context - opcode opcode - closed bool - - flate bool - trimWriter *trimLastFourBytesWriter - flateWriter *flate.Writer -} - func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.Lock(ctx) if err != nil { @@ -127,6 +130,13 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { return nil } +func (mw *msgWriter) returnFlateWriter() { + if mw.flateWriter != nil { + putFlateWriter(mw.flateWriter) + mw.flateWriter = nil + } +} + // Write writes the given bytes to the WebSocket connection. func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") @@ -135,16 +145,10 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - if mw.c.flate() { - if !mw.flate { - mw.flate = true - - if !mw.flateContextTakeover() { - mw.ensureFlateWriter() - } - mw.trimWriter.reset() - } - + // TODO can make threshold detection robust across writes by writing to buffer + if mw.flate || + mw.c.flate() && len(p) >= mw.c.flateThreshold { + mw.ensureFlate() return mw.flateWriter.Write(p) } @@ -181,21 +185,16 @@ func (mw *msgWriter) Close() (err error) { return xerrors.Errorf("failed to write fin frame: %w", err) } - if mw.c.flate() && !mw.flateContextTakeover() && mw.flateWriter != nil { - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil + if mw.c.flate() && !mw.flateContextTakeover() { + mw.returnFlateWriter() } - mw.mu.Unlock() return nil } func (mw *msgWriter) close() { - if mw.flateWriter != nil && mw.flateContextTakeover() { - mw.mu.Lock(context.Background()) - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil - } + mw.mu.Lock(context.Background()) + mw.returnFlateWriter() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { From 9e32354f05c6a12cfbdf1256f43c7e05551116e7 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 5 Feb 2020 21:53:03 -0600 Subject: [PATCH 26/55] Fix randString method in tests --- assert_test.go | 1 + ci/image/Dockerfile | 2 +- conn_test.go | 52 ++++++++++++++++++++++++++++++++++--------- internal/errd/wrap.go | 34 ++++++++++++++++++++++++++-- write.go | 2 +- 5 files changed, 76 insertions(+), 15 deletions(-) diff --git a/assert_test.go b/assert_test.go index 5307ee8e..6cfd9264 100644 --- a/assert_test.go +++ b/assert_test.go @@ -45,6 +45,7 @@ func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp in func randString(t *testing.T, n int) string { s := strings.ToValidUTF8(string(randBytes(t, n)), "_") + s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 070c50e6..88c96502 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -1,7 +1,7 @@ FROM golang:1 RUN apt-get update -RUN apt-get install -y chromium +RUN apt-get install -y chromium npm ENV GOFLAGS="-mod=readonly" ENV PAGER=cat diff --git a/conn_test.go b/conn_test.go index 7186da8a..4720cba9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -18,45 +18,71 @@ import ( "nhooyr.io/websocket" ) +func TestFuzz(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionContextTakeover, + }, + }) + assert.Success(t, "accept", err) + defer c.Close(websocket.StatusInternalError, "") + + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + opts := &websocket.DialOptions{ + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionContextTakeover, + }, + } + opts.HTTPClient = s.Client() + + c, _, err := websocket.Dial(ctx, wsURL(s), opts) + assert.Success(t, "dial", err) + assertJSONEcho(t, ctx, c, 8393) +} + func TestConn(t *testing.T) { t.Parallel() t.Run("json", func(t *testing.T) { s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, + Subprotocols: []string{"echo"}, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - Threshold: 1, + Mode: websocket.CompressionContextTakeover, }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") err = echoLoop(r.Context(), c) - t.Logf("server: %v", err) assertCloseStatus(t, websocket.StatusNormalClosure, err) }, false) defer closeFn() - wsURL := strings.Replace(s.URL, "http", "ws", 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() opts := &websocket.DialOptions{ Subprotocols: []string{"echo"}, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - Threshold: 1, + Mode: websocket.CompressionContextTakeover, }, } opts.HTTPClient = s.Client() - c, _, err := websocket.Dial(ctx, wsURL, opts) + c, _, err := websocket.Dial(ctx, wsURL(s), opts) assert.Success(t, "dial", err) - assertJSONEcho(t, ctx, c, 2) + assertJSONEcho(t, ctx, c, 8393) }) } @@ -149,3 +175,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } } + +func wsURL(s *httptest.Server) string { + return strings.Replace(s.URL, "http", "ws", 1) +} diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go index 20de7743..ed0b7754 100644 --- a/internal/errd/wrap.go +++ b/internal/errd/wrap.go @@ -1,12 +1,42 @@ package errd -import "golang.org/x/xerrors" +import ( + "fmt" + + "golang.org/x/xerrors" +) + +type wrapError struct { + msg string + err error + frame xerrors.Frame +} + +func (e *wrapError) Error() string { + return fmt.Sprint(e) +} + +func (e *wrapError) Format(s fmt.State, v rune) { xerrors.FormatError(e, s, v) } + +func (e *wrapError) FormatError(p xerrors.Printer) (next error) { + p.Print(e.msg) + e.frame.Format(p) + return e.err +} + +func (e *wrapError) Unwrap() error { + return e.err +} // Wrap wraps err with xerrors.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { - *err = xerrors.Errorf(f+": %w", append(v, *err)...) + *err = &wrapError{ + msg: fmt.Sprintf(f, v...), + err: *err, + frame: xerrors.Caller(1), + } } } diff --git a/write.go b/write.go index a7fa5f5a..4a756fa9 100644 --- a/write.go +++ b/write.go @@ -145,7 +145,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - // TODO can make threshold detection robust across writes by writing to buffer + // TODO can make threshold detection robust across writes by writing to bufio writer if mw.flate || mw.c.flate() && len(p) >= mw.c.flateThreshold { mw.ensureFlate() From 78da35ec5b221d5ec664ee9cbf0a8fb034d46f4c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Fri, 7 Feb 2020 00:58:57 -0600 Subject: [PATCH 27/55] Get test with multiple messages working --- README.md | 2 +- assert_test.go | 79 +++++++++++++++++++++++++++++++----------- autobahn_test.go | 16 +++++---- conn_test.go | 89 +++++++++++++++++------------------------------- example_test.go | 2 -- read.go | 11 +++--- write.go | 1 + ws_js_test.go | 2 +- 8 files changed, 109 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index e958d2ab..2569383a 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ go get nhooyr.io/websocket - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression -- Can target [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) +- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) ## Roadmap diff --git a/assert_test.go b/assert_test.go index 6cfd9264..3727d995 100644 --- a/assert_test.go +++ b/assert_test.go @@ -3,10 +3,15 @@ package websocket_test import ( "context" "crypto/rand" + "fmt" + "net/http" + "net/http/httptest" "strings" "testing" + "time" "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest/assert" "nhooyr.io/websocket" @@ -20,26 +25,31 @@ func randBytes(t *testing.T, n int) []byte { return b } -func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { - t.Helper() - defer c.Close(websocket.StatusInternalError, "") +func echoJSON(t *testing.T, c *websocket.Conn, n int) { + slog.Helper() - exp := randString(t, n) - err := wsjson.Write(ctx, c, exp) - assert.Success(t, "wsjson.Write", err) + s := randString(t, n) + writeJSON(t, c, s) + readJSON(t, c, s) +} - assertJSONRead(t, ctx, c, exp) +func writeJSON(t *testing.T, c *websocket.Conn, v interface{}) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() - c.Close(websocket.StatusNormalClosure, "") + err := wsjson.Write(ctx, c, v) + assert.Success(t, "wsjson.Write", err) } -func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { +func readJSON(t *testing.T, c *websocket.Conn, exp interface{}) { slog.Helper() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + var act interface{} err := wsjson.Read(ctx, c, &act) assert.Success(t, "wsjson.Read", err) - assert.Equal(t, "json", exp, act) } @@ -58,7 +68,7 @@ func randString(t *testing.T, n int) string { } func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { - t.Helper() + slog.Helper() p := randBytes(t, n) err := c.Write(ctx, typ, p) @@ -72,17 +82,46 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { - t.Helper() + slog.Helper() assert.Equal(t, "subprotocol", exp, c.Subprotocol()) } -func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { - t.Helper() - defer func() { - if t.Failed() { - t.Logf("error: %+v", err) - } - }() - assert.Equal(t, "closeStatus", exp, websocket.CloseStatus(err)) +func assertCloseStatus(t testing.TB, exp websocket.StatusCode, err error) { + slog.Helper() + + if websocket.CloseStatus(err) == -1 { + slogtest.Fatal(t, "expected websocket.CloseError", slogType(err), slog.Error(err)) + } + if websocket.CloseStatus(err) != exp { + slogtest.Error(t, "unexpected close status", + slog.F("exp", exp), + slog.F("act", err), + ) + } + +} + +func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts *websocket.AcceptOptions) *websocket.Conn { + c, err := websocket.Accept(w, r, opts) + assert.Success(t, "websocket.Accept", err) + return c +} + +func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + if opts == nil { + opts = &websocket.DialOptions{} + } + opts.HTTPClient = s.Client() + + c, resp, err := websocket.Dial(ctx, wsURL(s), opts) + assert.Success(t, "websocket.Dial", err) + return c, resp +} + +func slogType(v interface{}) slog.Field { + return slog.F("type", fmt.Sprintf("%T", v)) } diff --git a/autobahn_test.go b/autobahn_test.go index bcbf8671..dd9887f6 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "net" "net/http" + "net/http/httptest" "os" "os/exec" "strconv" @@ -53,15 +54,18 @@ func TestAutobahn(t *testing.T) { func testServerAutobahn(t *testing.T) { t.Parallel() - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) - assert.Success(t, "accept", err) - err = echoLoop(r.Context(), c) + err := echoLoop(r.Context(), c) assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) - defer closeFn() + })) + closeFn := wsgrace(s.Config) + defer func() { + err := closeFn() + assert.Success(t, "closeFn", err) + }() specFile, err := tempJSONFile(map[string]interface{}{ "outdir": "ci/out/wstestServerReports", diff --git a/conn_test.go b/conn_test.go index 4720cba9..6f6b8d5d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,7 +4,9 @@ package websocket_test import ( "context" + "crypto/rand" "io" + "math/big" "net/http" "net/http/httptest" "strings" @@ -18,77 +20,32 @@ import ( "nhooyr.io/websocket" ) -func TestFuzz(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, - }) - assert.Success(t, "accept", err) - defer c.Close(websocket.StatusInternalError, "") - - err = echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) - defer closeFn() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - opts := &websocket.DialOptions{ - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, - } - opts.HTTPClient = s.Client() - - c, _, err := websocket.Dial(ctx, wsURL(s), opts) - assert.Success(t, "dial", err) - assertJSONEcho(t, ctx, c, 8393) -} - func TestConn(t *testing.T) { t.Parallel() t.Run("json", func(t *testing.T) { - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, - }) - assert.Success(t, "accept", err) - defer c.Close(websocket.StatusInternalError, "") - - err = echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }, false) + t.Parallel() + + s, closeFn := testEchoLoop(t) defer closeFn() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() + c, _ := dialWebSocket(t, s, nil) + defer c.Close(websocket.StatusInternalError, "") - opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - }, + c.SetReadLimit(1 << 30) + + for i := 0; i < 10; i++ { + n := randInt(t, 1_048_576) + echoJSON(t, c, n) } - opts.HTTPClient = s.Client() - c, _, err := websocket.Dial(ctx, wsURL(s), opts) - assert.Success(t, "dial", err) - assertJSONEcho(t, ctx, c, 8393) + c.Close(websocket.StatusNormalClosure, "") }) } -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) { +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) { h := http.HandlerFunc(fn) - if tls { + if randInt(tb, 2) == 1 { s = httptest.NewTLSServer(h) } else { s = httptest.NewServer(h) @@ -179,3 +136,19 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { func wsURL(s *httptest.Server) string { return strings.Replace(s.URL, "http", "ws", 1) } + +func testEchoLoop(t testing.TB) (*httptest.Server, func()) { + return testServer(t, func(w http.ResponseWriter, r *http.Request) { + c := acceptWebSocket(t, r, w, nil) + defer c.Close(websocket.StatusInternalError, "") + + err := echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }) +} + +func randInt(t testing.TB, max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + assert.Success(t, "rand.Int", err) + return int(x.Int64()) +} diff --git a/example_test.go b/example_test.go index bc603aff..1842b765 100644 --- a/example_test.go +++ b/example_test.go @@ -33,8 +33,6 @@ func ExampleAccept() { return } - log.Printf("received: %v", v) - c.Close(websocket.StatusNormalClosure, "") }) diff --git a/read.go b/read.go index 73ec0b32..7e74894a 100644 --- a/read.go +++ b/read.go @@ -95,6 +95,7 @@ func (mr *msgReader) ensureFlate() { mr.flateReader = getFlateReader(readerFunc(mr.read), nil) } mr.limitReader.r = mr.flateReader + mr.flateTail.Reset(deflateMessageTail) } func (mr *msgReader) returnFlateReader() { @@ -328,12 +329,12 @@ type msgReader struct { func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 + mr.limitReader.reset(readerFunc(mr.read)) + if mr.flate { mr.ensureFlate() - mr.flateTail.Reset(deflateMessageTail) } - mr.limitReader.reset() mr.setFrame(h) } @@ -423,13 +424,13 @@ func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { c: c, } lr.limit.Store(limit) - lr.r = r - lr.reset() + lr.reset(r) return lr } -func (lr *limitReader) reset() { +func (lr *limitReader) reset(r io.Reader) { lr.n = lr.limit.Load() + lr.r = r } func (lr *limitReader) Read(p []byte) (int, error) { diff --git a/write.go b/write.go index 4a756fa9..34543486 100644 --- a/write.go +++ b/write.go @@ -76,6 +76,7 @@ func (mw *msgWriter) ensureFlate() { w: writerFunc(mw.write), } } + mw.trimWriter.reset() mw.flateWriter = getFlateWriter(mw.trimWriter) mw.flate = true diff --git a/ws_js_test.go b/ws_js_test.go index 6e87480b..9f725a57 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -24,7 +24,7 @@ func TestEcho(t *testing.T) { assertSubprotocol(t, c, "echo") assert.Equalf(t, &http.Response{}, resp, "http.Response") - assertJSONEcho(t, ctx, c, 1024) + echoJSON(t, ctx, c, 1024) assertEcho(t, ctx, c, websocket.MessageBinary, 1024) err = c.Close(websocket.StatusNormalClosure, "") From d09268649e33ce5b3afde49006d39508a28cbe12 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 8 Feb 2020 15:29:08 -0500 Subject: [PATCH 28/55] Autobahn tests fully pass :) --- assert_test.go | 15 ---- autobahn_test.go | 76 ++------------------ conn.go | 2 +- conn_test.go | 177 ++++++++++++++++++++++++++--------------------- read.go | 6 +- write.go | 31 +++++---- 6 files changed, 127 insertions(+), 180 deletions(-) diff --git a/assert_test.go b/assert_test.go index 3727d995..22814e3b 100644 --- a/assert_test.go +++ b/assert_test.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "fmt" "net/http" - "net/http/httptest" "strings" "testing" "time" @@ -108,20 +107,6 @@ func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts return c } -func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - if opts == nil { - opts = &websocket.DialOptions{} - } - opts.HTTPClient = s.Client() - - c, resp, err := websocket.Dial(ctx, wsURL(s), opts) - assert.Success(t, "websocket.Dial", err) - return c, resp -} - func slogType(v interface{}) slog.Field { return slog.F("type", fmt.Sprintf("%T", v)) } diff --git a/autobahn_test.go b/autobahn_test.go index dd9887f6..71d22be7 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -8,9 +8,6 @@ import ( "fmt" "io/ioutil" "net" - "net/http" - "net/http/httptest" - "os" "os/exec" "strconv" "strings" @@ -32,69 +29,14 @@ var excludedAutobahnCases = []string{ // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 "13.3.*", "13.4.*", "13.5.*", "13.6.*", - - "12.*", - "13.*", } var autobahnCases = []string{"*"} -// https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN") == "" { - t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite") - } - - t.Run("server", testServerAutobahn) - t.Run("client", testClientAutobahn) -} - -func testServerAutobahn(t *testing.T) { - t.Parallel() - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - err := echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - })) - closeFn := wsgrace(s.Config) - defer func() { - err := closeFn() - assert.Success(t, "closeFn", err) - }() - - specFile, err := tempJSONFile(map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": autobahnCases, - "exclude-cases": excludedAutobahnCases, - }) - assert.Success(t, "tempJSONFile", err) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingclient", "--spec", specFile} - wstest := exec.CommandContext(ctx, "wstest", args...) - _, err = wstest.CombinedOutput() - assert.Success(t, "wstest", err) - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") -} - -func testClientAutobahn(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) @@ -108,27 +50,17 @@ func testClientAutobahn(t *testing.T) { assert.Success(t, "wstestCaseCount", err) t.Run("cases", func(t *testing.T) { - // Max 8 cases running at a time. - mu := make(chan struct{}, 8) - for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { - t.Parallel() - - mu <- struct{}{} - defer func() { - <-mu - }() - - ctx, cancel := context.WithTimeout(ctx, time.Second*45) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) assert.Success(t, "autobahn dial", err) err = echoLoop(ctx, c) - t.Logf("echoLoop: %+v", err) + t.Logf("echoLoop: %v", err) }) } }) @@ -174,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er return "", nil, xerrors.Errorf("failed to write spec: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer func() { if err != nil { cancel() diff --git a/conn.go b/conn.go index 2d36123f..163802bb 100644 --- a/conn.go +++ b/conn.go @@ -99,7 +99,7 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } - if c.flateThreshold == 0 { + if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 256 if c.writeNoContextTakeOver() { c.flateThreshold = 512 diff --git a/conn_test.go b/conn_test.go index 6f6b8d5d..aceac3fd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,99 +3,70 @@ package websocket_test import ( + "bufio" "context" "crypto/rand" "io" "math/big" + "net" "net/http" "net/http/httptest" - "strings" - "sync/atomic" "testing" "time" "cdr.dev/slog/sloggers/slogtest/assert" - "golang.org/x/xerrors" "nhooyr.io/websocket" ) +func goFn(fn func()) func() { + done := make(chan struct{}) + go func() { + defer close(done) + fn() + }() + + return func() { + <-done + } +} + func TestConn(t *testing.T) { t.Parallel() t.Run("json", func(t *testing.T) { t.Parallel() - s, closeFn := testEchoLoop(t) - defer closeFn() + for i := 0; i < 1; i++ { + t.Run("", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - c, _ := dialWebSocket(t, s, nil) - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) - - for i := 0; i < 10; i++ { - n := randInt(t, 1_048_576) - echoJSON(t, c, n) - } + c1, c2 := websocketPipe(t) - c.Close(websocket.StatusNormalClosure, "") - }) -} - -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) { - h := http.HandlerFunc(fn) - if randInt(tb, 2) == 1 { - s = httptest.NewTLSServer(h) - } else { - s = httptest.NewServer(h) - } - closeFn2 := wsgrace(s.Config) - return s, func() { - err := closeFn2() - assert.Success(tb, "closeFn", err) - } -} + wait := goFn(func() { + err := echoLoop(ctx, c1) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }) + defer wait() -// grace wraps s.Handler to gracefully shutdown WebSocket connections. -// The returned function must be used to close the server instead of s.Close. -func wsgrace(s *http.Server) (closeFn func() error) { - h := s.Handler - var conns int64 - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) + c2.SetReadLimit(1 << 30) - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - - r = r.WithContext(ctx) + for i := 0; i < 10; i++ { + n := randInt(t, 131_072) + echoJSON(t, c2, n) + } - h.ServeHTTP(w, r) + c2.Close(websocket.StatusNormalClosure, "") + }) + } }) +} - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - err := s.Shutdown(ctx) - if err != nil { - return xerrors.Errorf("server shutdown failed: %v", err) - } +type writerFunc func(p []byte) (int, error) - t := time.NewTicker(time.Millisecond * 10) - defer t.Stop() - for { - select { - case <-t.C: - if atomic.LoadInt64(&conns) == 0 { - return nil - } - case <-ctx.Done(): - return xerrors.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) - } - } - } +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) } // echoLoop echos every msg received from c until an error @@ -133,18 +104,8 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } -func wsURL(s *httptest.Server) string { - return strings.Replace(s.URL, "http", "ws", 1) -} - -func testEchoLoop(t testing.TB) (*httptest.Server, func()) { - return testServer(t, func(w http.ResponseWriter, r *http.Request) { - c := acceptWebSocket(t, r, w, nil) - defer c.Close(websocket.StatusInternalError, "") - - err := echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }) +func randBool(t testing.TB) bool { + return randInt(t, 2) == 1 } func randInt(t testing.TB, max int) int { @@ -152,3 +113,65 @@ func randInt(t testing.TB, max int) int { assert.Success(t, "rand.Int", err) return int(x.Int64()) } + +type testHijacker struct { + *httptest.ResponseRecorder + serverConn net.Conn + hijacked chan struct{} +} + +var _ http.Hijacker = testHijacker{} + +func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + close(hj.hijacked) + return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil +} + +func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) { + var serverConn *websocket.Conn + tt := testTransport{ + h: func(w http.ResponseWriter, r *http.Request) { + serverConn = acceptWebSocket(t, r, w, nil) + }, + } + + dialOpts := &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: tt, + }, + } + + clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) + assert.Success(t, "websocket.Dial", err) + + if randBool(t) { + return serverConn, clientConn + } + return clientConn, serverConn +} + +type testTransport struct { + h http.HandlerFunc +} + +func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) { + clientConn, serverConn := net.Pipe() + + hj := testHijacker{ + ResponseRecorder: httptest.NewRecorder(), + serverConn: serverConn, + hijacked: make(chan struct{}), + } + + done := make(chan struct{}) + t.h.ServeHTTP(hj, r) + + select { + case <-hj.hijacked: + resp := hj.ResponseRecorder.Result() + resp.Body = clientConn + return resp, nil + case <-done: + return hj.ResponseRecorder.Result(), nil + } +} diff --git a/read.go b/read.go index 7e74894a..b681a944 100644 --- a/read.go +++ b/read.go @@ -84,7 +84,7 @@ func newMsgReader(c *Conn) *msgReader { return mr } -func (mr *msgReader) ensureFlate() { +func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() && mr.dict == nil { mr.dict = newSlidingWindow(32768) } @@ -332,7 +332,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) { mr.limitReader.reset(readerFunc(mr.read)) if mr.flate { - mr.ensureFlate() + mr.resetFlate() } mr.setFrame(h) @@ -362,7 +362,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { defer mr.c.readMu.Unlock() n, err = mr.limitReader.Read(p) - if mr.flateContextTakeover() { + if mr.flate && mr.flateContextTakeover() { p = p[:n] mr.dict.write(p) } diff --git a/write.go b/write.go index 34543486..70656b9f 100644 --- a/write.go +++ b/write.go @@ -70,17 +70,17 @@ func newMsgWriter(c *Conn) *msgWriter { } func (mw *msgWriter) ensureFlate() { - if mw.flateWriter == nil { - if mw.trimWriter == nil { - mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), - } + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), } - mw.trimWriter.reset() + } + if mw.flateWriter == nil { mw.flateWriter = getFlateWriter(mw.trimWriter) - mw.flate = true } + + mw.flate = true } func (mw *msgWriter) flateContextTakeover() bool { @@ -128,6 +128,11 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false + + if mw.trimWriter != nil { + mw.trimWriter.reset() + } + return nil } @@ -146,9 +151,8 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - // TODO can make threshold detection robust across writes by writing to bufio writer - if mw.flate || - mw.c.flate() && len(p) >= mw.c.flateThreshold { + // TODO Write to buffer to detect whether to enable flate or not for this message. + if mw.c.flate() { mw.ensureFlate() return mw.flateWriter.Write(p) } @@ -172,7 +176,6 @@ func (mw *msgWriter) Close() (err error) { if mw.closed { return xerrors.New("cannot use closed writer") } - mw.closed = true if mw.flate { err = mw.flateWriter.Flush() @@ -181,12 +184,16 @@ func (mw *msgWriter) Close() (err error) { } } + // We set closed after flushing the flate writer to ensure Write + // can succeed. + mw.closed = true + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } - if mw.c.flate() && !mw.flateContextTakeover() { + if mw.flate && !mw.flateContextTakeover() { mw.returnFlateWriter() } mw.mu.Unlock() From 6975801d4df4be481b0a76ae48928c402496df45 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 8 Feb 2020 22:14:00 -0500 Subject: [PATCH 29/55] Fix race in tests --- assert_test.go | 2 +- autobahn_test.go | 5 +++++ conn_test.go | 15 +++++---------- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/assert_test.go b/assert_test.go index 22814e3b..a51b2c3d 100644 --- a/assert_test.go +++ b/assert_test.go @@ -28,7 +28,7 @@ func echoJSON(t *testing.T, c *websocket.Conn, n int) { slog.Helper() s := randString(t, n) - writeJSON(t, c, s) + go writeJSON(t, c, s) readJSON(t, c, s) } diff --git a/autobahn_test.go b/autobahn_test.go index 71d22be7..d730cf4a 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "net" + "os" "os/exec" "strconv" "strings" @@ -36,6 +37,10 @@ var autobahnCases = []string{"*"} func TestAutobahn(t *testing.T) { t.Parallel() + if os.Getenv("AUTOBAHN_TEST") == "" { + t.SkipNow() + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer cancel() diff --git a/conn_test.go b/conn_test.go index aceac3fd..f1361adc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -39,7 +39,7 @@ func TestConn(t *testing.T) { for i := 0; i < 1; i++ { t.Run("", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() c1, c2 := websocketPipe(t) @@ -49,6 +49,7 @@ func TestConn(t *testing.T) { assertCloseStatus(t, websocket.StatusNormalClosure, err) }) defer wait() + defer cancel() c2.SetReadLimit(1 << 30) @@ -63,12 +64,6 @@ func TestConn(t *testing.T) { }) } -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - // echoLoop echos every msg received from c until an error // occurs or the context expires. // The read limit is set to 1 << 30. @@ -104,7 +99,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } -func randBool(t testing.TB) bool { +func randBool(t testing.TB) bool { return randInt(t, 2) == 1 } @@ -117,7 +112,7 @@ func randInt(t testing.TB, max int) int { type testHijacker struct { *httptest.ResponseRecorder serverConn net.Conn - hijacked chan struct{} + hijacked chan struct{} } var _ http.Hijacker = testHijacker{} @@ -154,7 +149,7 @@ type testTransport struct { h http.HandlerFunc } -func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) { +func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) { clientConn, serverConn := net.Pipe() hj := testHijacker{ From bbaf469750cf0996a4a7bd1b6ddcf01f88943c3d Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 8 Feb 2020 22:28:13 -0500 Subject: [PATCH 30/55] Fix test step --- .github/workflows/ci.yml | 2 +- go.mod | 12 ++++++++++-- go.sum | 21 +++++++++++++++++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 865c67f0..074e5246 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,7 +47,7 @@ jobs: with: args: make test env: - COVERALLS_TOKEN: ${{ secrets.github_token }} + COVERALLS_TOKEN: ${{ secrets.COVERALLS_TOKEN }} - name: Upload coverage.html uses: actions/upload-artifact@master with: diff --git a/go.mod b/go.mod index 5dc9b261..ee1708a2 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,19 @@ go 1.12 require ( cdr.dev/slog v1.3.0 + github.com/alecthomas/chroma v0.7.1 // indirect + github.com/fatih/color v1.9.0 // indirect github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 - github.com/golang/protobuf v1.3.2 + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect + github.com/golang/protobuf v1.3.3 + github.com/google/go-cmp v0.4.0 // indirect github.com/gorilla/websocket v1.4.1 - golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 + github.com/mattn/go-isatty v0.0.12 // indirect + go.opencensus.io v0.22.3 // indirect + golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 // indirect + golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index 864efaa7..1d1dc3a6 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 h1:smF2tmSOzy2Mm github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38/go.mod h1:r7bzyVFMNntcxPZXK3/+KdruV1H5KSlyVY0gc+NgInI= github.com/alecthomas/chroma v0.7.0 h1:z+0HgTUmkpRDRz0SRSdMaqOLfJV4F+N1FPDZUZIDUzw= github.com/alecthomas/chroma v0.7.0/go.mod h1:1U/PfCsTALWWYHDnsIQkxEBM0+6LLe0v8+RSVMOwxeY= +github.com/alecthomas/chroma v0.7.1 h1:G1i02OhUbRi2nJxcNkwJaY/J1gHXj9tt72qN6ZouLFQ= +github.com/alecthomas/chroma v0.7.1/go.mod h1:gHw09mkX1Qp80JlYbmN9L3+4R5o6DJJ3GRShh+AICNc= github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721 h1:JHZL0hZKJ1VENNfmXvHbgYlbUOvpzYzvy2aZU5gXVeo= github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721/go.mod h1:QO9JBoKquHd+jz9nshCh40fOfO+JzsoXy8qTHF68zU0= github.com/alecthomas/kong v0.1.17-0.20190424132513-439c674f7ae0/go.mod h1:+inYUSluD+p4L8KdviBSgzcqEjUQOfC5fQDRFuc36lI= @@ -46,6 +48,8 @@ github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymF github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= @@ -59,6 +63,8 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -66,12 +72,16 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.2-0.20191216170541-340f1ebe299e h1:4WfjkTUTsO6siF8ghDQQk6t7x/FPsv3w6MXkc47do7Q= github.com/google/go-cmp v0.3.2-0.20191216170541-340f1ebe299e/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -102,6 +112,8 @@ github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNx github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/nkovacs/streamquote v0.0.0-20170412213628-49af9bddb229/go.mod h1:0aYXnNPJ8l7uZxf45rWW1a/uME32OF0rhiYGNQ2oF2E= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -123,12 +135,16 @@ go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.3 h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8= +go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 h1:KOcEaR10tFr7gdJV2GCKw8Os5yED1u1aOqHjOAb6d2Y= +golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -181,6 +197,9 @@ golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 h1:gSbV7h1NRL2G1xTg/owz62CST1oJBmxy4QpMMregXVQ= golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= @@ -188,6 +207,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= From faadcc9613d9e663ef39dd9d71196e033f3f2901 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 8 Feb 2020 23:14:03 -0500 Subject: [PATCH 31/55] Simplify tests --- assert_test.go | 112 ------------------------ compress_test.go | 25 +----- conn_test.go | 165 +++++++++++++++-------------------- dial.go | 1 + go.mod | 2 +- internal/test/cmp/cmp.go | 22 +++++ internal/test/doc.go | 2 + internal/test/wstest/pipe.go | 82 +++++++++++++++++ internal/test/xrand/xrand.go | 47 ++++++++++ ws_js_test.go | 12 ++- 10 files changed, 234 insertions(+), 236 deletions(-) delete mode 100644 assert_test.go create mode 100644 internal/test/cmp/cmp.go create mode 100644 internal/test/doc.go create mode 100644 internal/test/wstest/pipe.go create mode 100644 internal/test/xrand/xrand.go diff --git a/assert_test.go b/assert_test.go deleted file mode 100644 index a51b2c3d..00000000 --- a/assert_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package websocket_test - -import ( - "context" - "crypto/rand" - "fmt" - "net/http" - "strings" - "testing" - "time" - - "cdr.dev/slog" - "cdr.dev/slog/sloggers/slogtest" - "cdr.dev/slog/sloggers/slogtest/assert" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" -) - -func randBytes(t *testing.T, n int) []byte { - b := make([]byte, n) - _, err := rand.Reader.Read(b) - assert.Success(t, "readRandBytes", err) - return b -} - -func echoJSON(t *testing.T, c *websocket.Conn, n int) { - slog.Helper() - - s := randString(t, n) - go writeJSON(t, c, s) - readJSON(t, c, s) -} - -func writeJSON(t *testing.T, c *websocket.Conn, v interface{}) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := wsjson.Write(ctx, c, v) - assert.Success(t, "wsjson.Write", err) -} - -func readJSON(t *testing.T, c *websocket.Conn, exp interface{}) { - slog.Helper() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - var act interface{} - err := wsjson.Read(ctx, c, &act) - assert.Success(t, "wsjson.Read", err) - assert.Equal(t, "json", exp, act) -} - -func randString(t *testing.T, n int) string { - s := strings.ToValidUTF8(string(randBytes(t, n)), "_") - s = strings.ReplaceAll(s, "\x00", "_") - if len(s) > n { - return s[:n] - } - if len(s) < n { - // Pad with = - extra := n - len(s) - return s + strings.Repeat("=", extra) - } - return s -} - -func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { - slog.Helper() - - p := randBytes(t, n) - err := c.Write(ctx, typ, p) - assert.Success(t, "write", err) - - typ2, p2, err := c.Read(ctx) - assert.Success(t, "read", err) - - assert.Equal(t, "dataType", typ, typ2) - assert.Equal(t, "payload", p, p2) -} - -func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { - slog.Helper() - - assert.Equal(t, "subprotocol", exp, c.Subprotocol()) -} - -func assertCloseStatus(t testing.TB, exp websocket.StatusCode, err error) { - slog.Helper() - - if websocket.CloseStatus(err) == -1 { - slogtest.Fatal(t, "expected websocket.CloseError", slogType(err), slog.Error(err)) - } - if websocket.CloseStatus(err) != exp { - slogtest.Error(t, "unexpected close status", - slog.F("exp", exp), - slog.F("act", err), - ) - } - -} - -func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts *websocket.AcceptOptions) *websocket.Conn { - c, err := websocket.Accept(w, r, opts) - assert.Success(t, "websocket.Accept", err) - return c -} - -func slogType(v interface{}) slog.Field { - return slog.F("type", fmt.Sprintf("%T", v)) -} diff --git a/compress_test.go b/compress_test.go index 6edfcb1a..15d334d6 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,13 +1,12 @@ package websocket import ( - "crypto/rand" - "encoding/base64" - "math/big" "strings" "testing" "cdr.dev/slog/sloggers/slogtest/assert" + + "nhooyr.io/websocket/internal/test/xrand" ) func Test_slidingWindow(t *testing.T) { @@ -16,8 +15,8 @@ func Test_slidingWindow(t *testing.T) { const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { - input := randStr(t, maxWindow) - windowLength := randInt(t, maxWindow) + input := xrand.String(maxWindow) + windowLength := xrand.Int(maxWindow) r := newSlidingWindow(windowLength) r.write([]byte(input)) @@ -27,19 +26,3 @@ func Test_slidingWindow(t *testing.T) { assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf))) } } - -func randStr(t *testing.T, max int) string { - n := randInt(t, max) - - b := make([]byte, n) - _, err := rand.Read(b) - assert.Success(t, "rand.Read", err) - - return base64.StdEncoding.EncodeToString(b) -} - -func randInt(t *testing.T, max int) int { - x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) - assert.Success(t, "rand.Int", err) - return int(x.Int64()) -} diff --git a/conn_test.go b/conn_test.go index f1361adc..d246f719 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,59 +3,96 @@ package websocket_test import ( - "bufio" "context" - "crypto/rand" "io" - "math/big" - "net" - "net/http" - "net/http/httptest" "testing" "time" - "cdr.dev/slog/sloggers/slogtest/assert" + "golang.org/x/xerrors" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/wstest" + "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/wsjson" ) -func goFn(fn func()) func() { - done := make(chan struct{}) +func goFn(fn func() error) chan error { + errs := make(chan error) go func() { - defer close(done) - fn() + defer close(errs) + errs <- fn() }() - return func() { - <-done - } + return errs } func TestConn(t *testing.T) { t.Parallel() - t.Run("json", func(t *testing.T) { + t.Run("data", func(t *testing.T) { t.Parallel() - for i := 0; i < 1; i++ { + for i := 0; i < 10; i++ { t.Run("", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - c1, c2 := websocketPipe(t) + copts := websocket.CompressionOptions{ + Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled))), + Threshold: xrand.Int(9999), + } + + c1, c2, err := wstest.Pipe(&websocket.DialOptions{ + CompressionOptions: copts, + }, &websocket.AcceptOptions{ + CompressionOptions: copts, + }) + if err != nil { + t.Fatal(err) + } + defer c1.Close(websocket.StatusInternalError, "") + defer c2.Close(websocket.StatusInternalError, "") - wait := goFn(func() { + echoLoopErr := goFn(func() error { err := echoLoop(ctx, c1) - assertCloseStatus(t, websocket.StatusNormalClosure, err) + return assertCloseStatus(websocket.StatusNormalClosure, err) }) - defer wait() + defer func() { + err := <-echoLoopErr + if err != nil { + t.Errorf("echo loop error: %v", err) + } + }() defer cancel() c2.SetReadLimit(1 << 30) for i := 0; i < 10; i++ { - n := randInt(t, 131_072) - echoJSON(t, c2, n) + n := xrand.Int(131_072) + + msg := xrand.String(n) + + writeErr := goFn(func() error { + return wsjson.Write(ctx, c2, msg) + }) + + var act interface{} + err := wsjson.Read(ctx, c2, &act) + if err != nil { + t.Fatal(err) + } + + err = <-writeErr + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal(msg, act) { + t.Fatalf("unexpected msg read: %v", cmp.Diff(msg, act)) + } } c2.Close(websocket.StatusNormalClosure, "") @@ -64,6 +101,16 @@ func TestConn(t *testing.T) { }) } +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return xerrors.Errorf("unexpected close status (%v):%v", exp, err) + } + return nil +} + // echoLoop echos every msg received from c until an error // occurs or the context expires. // The read limit is set to 1 << 30. @@ -98,75 +145,3 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } } - -func randBool(t testing.TB) bool { - return randInt(t, 2) == 1 -} - -func randInt(t testing.TB, max int) int { - x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) - assert.Success(t, "rand.Int", err) - return int(x.Int64()) -} - -type testHijacker struct { - *httptest.ResponseRecorder - serverConn net.Conn - hijacked chan struct{} -} - -var _ http.Hijacker = testHijacker{} - -func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { - close(hj.hijacked) - return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil -} - -func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) { - var serverConn *websocket.Conn - tt := testTransport{ - h: func(w http.ResponseWriter, r *http.Request) { - serverConn = acceptWebSocket(t, r, w, nil) - }, - } - - dialOpts := &websocket.DialOptions{ - HTTPClient: &http.Client{ - Transport: tt, - }, - } - - clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) - assert.Success(t, "websocket.Dial", err) - - if randBool(t) { - return serverConn, clientConn - } - return clientConn, serverConn -} - -type testTransport struct { - h http.HandlerFunc -} - -func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) { - clientConn, serverConn := net.Pipe() - - hj := testHijacker{ - ResponseRecorder: httptest.NewRecorder(), - serverConn: serverConn, - hijacked: make(chan struct{}), - } - - done := make(chan struct{}) - t.h.ServeHTTP(hj, r) - - select { - case <-hj.hijacked: - resp := hj.ResponseRecorder.Result() - resp.Body = clientConn - return resp, nil - case <-done: - return hj.ResponseRecorder.Result(), nil - } -} diff --git a/dial.go b/dial.go index 4557602e..a1509ab5 100644 --- a/dial.go +++ b/dial.go @@ -35,6 +35,7 @@ type DialOptions struct { // CompressionOptions controls the compression options. // See docs on the CompressionOptions type. + // TODO make * CompressionOptions CompressionOptions } diff --git a/go.mod b/go.mod index ee1708a2..fc4ebb99 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/gobwas/ws v1.0.2 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.3.3 - github.com/google/go-cmp v0.4.0 // indirect + github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 github.com/mattn/go-isatty v0.0.12 // indirect go.opencensus.io v0.22.3 // indirect diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go new file mode 100644 index 00000000..d0eee6d0 --- /dev/null +++ b/internal/test/cmp/cmp.go @@ -0,0 +1,22 @@ +package cmp + +import ( + "reflect" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +// Equal checks if v1 and v2 are equal with go-cmp. +func Equal(v1, v2 interface{}) bool { + return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { + return true + })) +} + +// Diff returns a human readable diff between v1 and v2 +func Diff(v1, v2 interface{}) string { + return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { + return true + })) +} diff --git a/internal/test/doc.go b/internal/test/doc.go new file mode 100644 index 00000000..94b2e82d --- /dev/null +++ b/internal/test/doc.go @@ -0,0 +1,2 @@ +// Package test contains subpackages only used in tests. +package test diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go new file mode 100644 index 00000000..f3d25f55 --- /dev/null +++ b/internal/test/wstest/pipe.go @@ -0,0 +1,82 @@ +package wstest + +import ( + "bufio" + "context" + "net" + "net/http" + "net/http/httptest" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/xrand" +) + +// Pipe is used to create an in memory connection +// between two websockets analogous to net.Pipe. +func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (_ *websocket.Conn, _ *websocket.Conn, err error) { + defer errd.Wrap(&err, "failed to create ws pipe") + + var serverConn *websocket.Conn + var acceptErr error + tt := fakeTransport{ + h: func(w http.ResponseWriter, r *http.Request) { + serverConn, acceptErr = websocket.Accept(w, r, acceptOpts) + }, + } + + if dialOpts == nil { + dialOpts = &websocket.DialOptions{} + } + dialOpts.HTTPClient = &http.Client{ + Transport: tt, + } + + clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) + if err != nil { + return nil, nil, xerrors.Errorf("failed to dial with fake transport: %w", err) + } + + if serverConn == nil { + return nil, nil, xerrors.Errorf("failed to get server conn from fake transport: %w", acceptErr) + } + + if xrand.True() { + return serverConn, clientConn, nil + } + return clientConn, serverConn, nil +} + +type fakeTransport struct { + h http.HandlerFunc +} + +func (t fakeTransport) RoundTrip(r *http.Request) (*http.Response, error) { + clientConn, serverConn := net.Pipe() + + hj := testHijacker{ + ResponseRecorder: httptest.NewRecorder(), + serverConn: serverConn, + } + + t.h.ServeHTTP(hj, r) + + resp := hj.ResponseRecorder.Result() + if resp.StatusCode == http.StatusSwitchingProtocols { + resp.Body = clientConn + } + return resp, nil +} + +type testHijacker struct { + *httptest.ResponseRecorder + serverConn net.Conn +} + +var _ http.Hijacker = testHijacker{} + +func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil +} diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go new file mode 100644 index 00000000..2f3ad30f --- /dev/null +++ b/internal/test/xrand/xrand.go @@ -0,0 +1,47 @@ +package xrand + +import ( + "crypto/rand" + "fmt" + "math/big" + "strings" +) + +// Bytes generates random bytes with length n. +func Bytes(n int) []byte { + b := make([]byte, n) + _, err := rand.Reader.Read(b) + if err != nil { + panic(fmt.Sprintf("failed to generate rand bytes: %v", err)) + } + return b +} + +// String generates a random string with length n. +func String(n int) string { + s := strings.ToValidUTF8(string(Bytes(n)), "_") + s = strings.ReplaceAll(s, "\x00", "_") + if len(s) > n { + return s[:n] + } + if len(s) < n { + // Pad with = + extra := n - len(s) + return s + strings.Repeat("=", extra) + } + return s +} + +// True returns a randomly generated boolean. +func True() bool { + return Int(2) == 1 +} + +// Int returns a randomly generated integer between [0, max). +func Int(max int) int { + x, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + panic(fmt.Sprintf("failed to get random int: %v", err)) + } + return int(x.Int64()) +} diff --git a/ws_js_test.go b/ws_js_test.go index 9f725a57..65309bff 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -1,4 +1,4 @@ -package websocket_test +package websocket import ( "context" @@ -6,8 +6,6 @@ import ( "os" "testing" "time" - - "nhooyr.io/websocket" ) func TestEcho(t *testing.T) { @@ -16,17 +14,17 @@ func TestEcho(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ + c, resp, err := Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &DialOptions{ Subprotocols: []string{"echo"}, }) assert.Success(t, err) - defer c.Close(websocket.StatusInternalError, "") + defer c.Close(StatusInternalError, "") assertSubprotocol(t, c, "echo") assert.Equalf(t, &http.Response{}, resp, "http.Response") echoJSON(t, ctx, c, 1024) - assertEcho(t, ctx, c, websocket.MessageBinary, 1024) + assertEcho(t, ctx, c, MessageBinary, 1024) - err = c.Close(websocket.StatusNormalClosure, "") + err = c.Close(StatusNormalClosure, "") assert.Success(t, err) } From 3f2589ffa18b5e61a7786ad5308c5ccc87688cef Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 8 Feb 2020 23:36:06 -0500 Subject: [PATCH 32/55] Remove quite a bit of slog --- ci/test.mk | 2 +- close_test.go | 42 ++++++++++++++++++++++++++---------- conn_test.go | 30 ++++++++++++++++++-------- frame_test.go | 30 ++++++++++++++++++++------ internal/test/wstest/pipe.go | 2 +- internal/test/xrand/xrand.go | 4 ++-- 6 files changed, 79 insertions(+), 31 deletions(-) diff --git a/ci/test.mk b/ci/test.mk index 95e049b2..786a8d77 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -14,4 +14,4 @@ coveralls: gotest gotest: go test -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof - sed -i '/assert/d' ci/out/coverage.prof + sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof diff --git a/close_test.go b/close_test.go index 16b570d0..10a35b13 100644 --- a/close_test.go +++ b/close_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "cdr.dev/slog/sloggers/slogtest/assert" + "nhooyr.io/websocket/internal/test/cmp" ) func TestCloseError(t *testing.T) { @@ -51,13 +51,23 @@ func TestCloseError(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() - if tc.success { - assert.Success(t, "CloseError.bytesErr", err) - } else { - assert.Error(t, "CloseError.bytesErr", err) + if tc.success != (err == nil) { + t.Fatalf("unexpected error value (wanted err == nil == %v): %v", tc.success, err) } }) } + + t.Run("Error", func(t *testing.T) { + exp := `status = StatusInternalError and reason = "meow"` + act := CloseError{ + Code: StatusInternalError, + Reason: "meow", + }.Error() + + if (act) != exp { + t.Fatal(cmp.Diff(exp, act)) + } + }) } func Test_parseClosePayload(t *testing.T) { @@ -104,10 +114,14 @@ func Test_parseClosePayload(t *testing.T) { ce, err := parseClosePayload(tc.p) if tc.success { - assert.Success(t, "parse err", err) - assert.Equal(t, "ce", tc.ce, ce) - } else { - assert.Error(t, "parse err", err) + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(tc.ce, ce) { + t.Fatalf("expected %v but got %v", tc.ce, ce) + } + } else if err == nil { + t.Errorf("expected error: %v %v", ce, err) } }) } @@ -153,7 +167,10 @@ func Test_validWireCloseCode(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, "valid", tc.valid, validWireCloseCode(tc.code)) + act := validWireCloseCode(tc.code) + if !cmp.Equal(tc.valid, act) { + t.Fatalf("unexpected valid: %v", cmp.Diff(tc.valid, act)) + } }) } } @@ -190,7 +207,10 @@ func TestCloseStatus(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, "closeStatus", tc.exp, CloseStatus(tc.in)) + act := CloseStatus(tc.in) + if !cmp.Equal(tc.exp, act) { + t.Fatalf("unexpected closeStatus: %v", cmp.Diff(tc.exp, act)) + } }) } } diff --git a/conn_test.go b/conn_test.go index d246f719..02606ef5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -14,13 +14,17 @@ import ( "nhooyr.io/websocket/internal/test/cmp" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" - "nhooyr.io/websocket/wsjson" ) func goFn(fn func() error) chan error { errs := make(chan error) go func() { - defer close(errs) + defer func() { + r := recover() + if r != nil { + errs <- xerrors.Errorf("panic in gofn: %v", r) + } + }() errs <- fn() }() @@ -33,7 +37,7 @@ func TestConn(t *testing.T) { t.Run("data", func(t *testing.T) { t.Parallel() - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { t.Parallel() @@ -41,7 +45,7 @@ func TestConn(t *testing.T) { defer cancel() copts := websocket.CompressionOptions{ - Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled))), + Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), Threshold: xrand.Int(9999), } @@ -70,17 +74,21 @@ func TestConn(t *testing.T) { c2.SetReadLimit(1 << 30) - for i := 0; i < 10; i++ { + for i := 0; i < 5; i++ { n := xrand.Int(131_072) - msg := xrand.String(n) + msg := xrand.Bytes(n) + + expType := websocket.MessageBinary + if xrand.Bool() { + expType = websocket.MessageText + } writeErr := goFn(func() error { - return wsjson.Write(ctx, c2, msg) + return c2.Write(ctx, expType, msg) }) - var act interface{} - err := wsjson.Read(ctx, c2, &act) + actType, act, err := c2.Read(ctx) if err != nil { t.Fatal(err) } @@ -90,6 +98,10 @@ func TestConn(t *testing.T) { t.Fatal(err) } + if expType != actType { + t.Fatalf("unexpected message typ (%v): %v", expType, actType) + } + if !cmp.Equal(msg, act) { t.Fatalf("unexpected msg read: %v", cmp.Diff(msg, act)) } diff --git a/frame_test.go b/frame_test.go index 323ea991..0b770a4c 100644 --- a/frame_test.go +++ b/frame_test.go @@ -13,9 +13,10 @@ import ( "time" _ "unsafe" - "cdr.dev/slog/sloggers/slogtest/assert" "github.com/gobwas/ws" _ "github.com/gorilla/websocket" + + "nhooyr.io/websocket/internal/test/cmp" ) func TestHeader(t *testing.T) { @@ -80,14 +81,22 @@ func testHeader(t *testing.T, h header) { r := bufio.NewReader(b) err := writeFrameHeader(h, w) - assert.Success(t, "writeFrameHeader", err) + if err != nil { + t.Fatal(err) + } err = w.Flush() - assert.Success(t, "flush", err) + if err != nil { + t.Fatal(err) + } h2, err := readFrameHeader(r) - assert.Success(t, "readFrameHeader", err) + if err != nil { + t.Fatal(err) + } - assert.Equal(t, "header", h, h2) + if !cmp.Equal(h, h2) { + t.Fatal(cmp.Diff(h, h2)) + } } func Test_mask(t *testing.T) { @@ -98,8 +107,15 @@ func Test_mask(t *testing.T) { p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} gotKey32 := mask(key32, p) - assert.Equal(t, "mask", []byte{0, 0, 0, 0x0d, 0x6}, p) - assert.Equal(t, "maskKey", bits.RotateLeft32(key32, -8), gotKey32) + expP := []byte{0, 0, 0, 0x0d, 0x6} + if !cmp.Equal(expP, p) { + t.Fatal(cmp.Diff(expP, p)) + } + + expKey32 := bits.RotateLeft32(key32, -8) + if !cmp.Equal(expKey32, gotKey32) { + t.Fatal(cmp.Diff(expKey32, gotKey32)) + } } func basicMask(maskKey [4]byte, pos int, b []byte) int { diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index f3d25f55..e958aea4 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -43,7 +43,7 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) return nil, nil, xerrors.Errorf("failed to get server conn from fake transport: %w", acceptErr) } - if xrand.True() { + if xrand.Bool() { return serverConn, clientConn, nil } return clientConn, serverConn, nil diff --git a/internal/test/xrand/xrand.go b/internal/test/xrand/xrand.go index 2f3ad30f..8de1ede8 100644 --- a/internal/test/xrand/xrand.go +++ b/internal/test/xrand/xrand.go @@ -32,8 +32,8 @@ func String(n int) string { return s } -// True returns a randomly generated boolean. -func True() bool { +// Bool returns a randomly generated boolean. +func Bool() bool { return Int(2) == 1 } From b53f306c00debd46e5ed5debd2f9594ee8889f5c Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 00:47:34 -0500 Subject: [PATCH 33/55] Get Wasm tests working --- accept.go | 7 +- accept_test.go | 55 +++++--- autobahn_test.go | 33 +++-- close.go | 193 ------------------------- close_notjs.go | 199 ++++++++++++++++++++++++++ compress.go | 155 -------------------- compress_notjs.go | 156 +++++++++++++++++++++ compress_test.go | 29 ++-- conn.go | 261 ---------------------------------- conn_notjs.go | 264 +++++++++++++++++++++++++++++++++++ conn_test.go | 131 +++++++---------- dial.go | 7 +- dial_test.go | 2 +- example_test.go | 3 +- frame.go | 2 - internal/test/cmp/cmp.go | 9 ++ internal/test/wstest/echo.go | 90 ++++++++++++ internal/test/wstest/pipe.go | 3 + internal/test/wstest/url.go | 11 ++ internal/xsync/go.go | 25 ++++ internal/xsync/go_test.go | 20 +++ internal/xsync/int64.go | 23 +++ read.go | 19 +-- ws_js.go | 30 ++-- ws_js_test.go | 40 ++++-- 25 files changed, 985 insertions(+), 782 deletions(-) create mode 100644 close_notjs.go create mode 100644 compress_notjs.go create mode 100644 conn_notjs.go create mode 100644 internal/test/wstest/echo.go create mode 100644 internal/test/wstest/url.go create mode 100644 internal/xsync/go.go create mode 100644 internal/xsync/go_test.go create mode 100644 internal/xsync/int64.go diff --git a/accept.go b/accept.go index 0394fa6d..31f104b2 100644 --- a/accept.go +++ b/accept.go @@ -39,7 +39,7 @@ type AcceptOptions struct { // CompressionOptions controls the compression options. // See docs on the CompressionOptions type. - CompressionOptions CompressionOptions + CompressionOptions *CompressionOptions } // Accept accepts a WebSocket handshake from a client and upgrades the @@ -59,6 +59,11 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con if opts == nil { opts = &AcceptOptions{} } + opts = &*opts + + if opts.CompressionOptions == nil { + opts.CompressionOptions = &CompressionOptions{} + } err = verifyClientRequest(r) if err != nil { diff --git a/accept_test.go b/accept_test.go index 3e8b1f46..18302da5 100644 --- a/accept_test.go +++ b/accept_test.go @@ -10,8 +10,9 @@ import ( "strings" "testing" - "cdr.dev/slog/sloggers/slogtest/assert" "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/test/cmp" ) func TestAccept(t *testing.T) { @@ -24,7 +25,9 @@ func TestAccept(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "protocol violation") + if !cmp.ErrorContains(err, "protocol violation") { + t.Fatal(err) + } }) t.Run("badOrigin", func(t *testing.T) { @@ -39,7 +42,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host") + if !cmp.ErrorContains(err, `request Origin "harhar.com" is not authorized for Host`) { + t.Fatal(err) + } }) t.Run("badCompression", func(t *testing.T) { @@ -56,7 +61,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter") + if !cmp.ErrorContains(err, `unsupported permessage-deflate parameter`) { + t.Fatal(err) + } }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -70,7 +77,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") + if !cmp.ErrorContains(err, `http.ResponseWriter does not implement http.Hijacker`) { + t.Fatal(err) + } }) t.Run("badHijack", func(t *testing.T) { @@ -90,7 +99,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "failed to hijack connection") + if !cmp.ErrorContains(err, `failed to hijack connection`) { + t.Fatal(err) + } }) } @@ -182,10 +193,8 @@ func Test_verifyClientHandshake(t *testing.T) { } err := verifyClientRequest(r) - if tc.success { - assert.Success(t, "verifyClientRequest", err) - } else { - assert.Error(t, "verifyClientRequest", err) + if tc.success != (err == nil) { + t.Fatalf("unexpected error value: %v", err) } }) } @@ -235,7 +244,9 @@ func Test_selectSubprotocol(t *testing.T) { r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) - assert.Equal(t, "negotiated", tc.negotiated, negotiated) + if !cmp.Equal(tc.negotiated, negotiated) { + t.Fatalf("unexpected negotiated: %v", cmp.Diff(tc.negotiated, negotiated)) + } }) } } @@ -289,10 +300,8 @@ func Test_authenticateOrigin(t *testing.T) { r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r) - if tc.success { - assert.Success(t, "authenticateOrigin", err) - } else { - assert.Error(t, "authenticateOrigin", err) + if tc.success != (err == nil) { + t.Fatalf("unexpected error value: %v", err) } }) } @@ -364,13 +373,21 @@ func Test_acceptCompression(t *testing.T) { w := httptest.NewRecorder() copts, err := acceptCompression(r, w, tc.mode) if tc.error { - assert.Error(t, "acceptCompression", err) + if err == nil { + t.Fatalf("expected error: %v", copts) + } return } - assert.Success(t, "acceptCompression", err) - assert.Equal(t, "compresssionOpts", tc.expCopts, copts) - assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(tc.expCopts, copts) { + t.Fatalf("unexpected compression options: %v", cmp.Diff(tc.expCopts, copts)) + } + if !cmp.Equal(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) { + t.Fatalf("unexpected respHeader: %v", cmp.Diff(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))) + } }) } } diff --git a/autobahn_test.go b/autobahn_test.go index d730cf4a..4d0bd1b5 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -15,11 +15,11 @@ import ( "testing" "time" - "cdr.dev/slog/sloggers/slogtest/assert" "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/wstest" ) var excludedAutobahnCases = []string{ @@ -45,14 +45,20 @@ func TestAutobahn(t *testing.T) { defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) - assert.Success(t, "wstestClient", err) + if err != nil { + t.Fatal(err) + } defer closeFn() err = waitWS(ctx, wstestURL) - assert.Success(t, "waitWS", err) + if err != nil { + t.Fatal(err) + } cases, err := wstestCaseCount(ctx, wstestURL) - assert.Success(t, "wstestCaseCount", err) + if err != nil { + t.Fatal(err) + } t.Run("cases", func(t *testing.T) { for i := 1; i <= cases; i++ { @@ -62,16 +68,19 @@ func TestAutobahn(t *testing.T) { defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) - assert.Success(t, "autobahn dial", err) - - err = echoLoop(ctx, c) + if err != nil { + t.Fatal(err) + } + err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) }) } }) c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) - assert.Success(t, "dial", err) + if err != nil { + t.Fatal(err) + } c.Close(websocket.StatusNormalClosure, "") checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") @@ -163,14 +172,18 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) - assert.Success(t, "ioutil.ReadFile", err) + if err != nil { + t.Fatal(err) + } var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) - assert.Success(t, "json.Unmarshal", err) + if err != nil { + t.Fatal(err) + } for _, tests := range indexJSON { for test, result := range tests { diff --git a/close.go b/close.go index 931160e6..20073233 100644 --- a/close.go +++ b/close.go @@ -1,17 +1,9 @@ -// +build !js - package websocket import ( - "context" - "encoding/binary" "fmt" - "log" - "time" "golang.org/x/xerrors" - - "nhooyr.io/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. @@ -83,188 +75,3 @@ func CloseStatus(err error) StatusCode { } return -1 } - -// Close performs the WebSocket close handshake with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// All data messages received from the peer during the close handshake will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - return c.closeHandshake(code, reason) -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - err = c.writeClose(code, reason) - if err != nil { - return err - } - - err = c.waitCloseHandshake() - if CloseStatus(err) == -1 { - return err - } - return nil -} - -func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) - c.writeClose(code, err.Error()) - c.close(nil) -} - -func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - closing := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if closing { - return xerrors.New("already wrote close") - } - - ce := CloseError{ - Code: code, - Reason: reason, - } - - c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) - - var p []byte - if ce.Code != StatusNoStatusRcvd { - p = ce.bytes() - } - - return c.writeControl(context.Background(), opClose, p) -} - -func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := c.readMu.Lock(ctx) - if err != nil { - return err - } - defer c.readMu.Unlock() - - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - - for { - h, err := c.readLoop(ctx) - if err != nil { - return err - } - - for i := int64(0); i < h.payloadLength; i++ { - _, err := c.br.ReadByte() - if err != nil { - return err - } - } - } -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -func (ce CloseError) bytes() []byte { - p, err := ce.bytesErr() - if err != nil { - log.Printf("websocket: failed to marshal close frame: %+v", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytesErr() - } - return p -} - -const maxCloseReason = maxControlPayload - 2 - -func (ce CloseError) bytesErr() ([]byte, error) { - if len(ce.Reason) > maxCloseReason { - return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) - } - - if !validWireCloseCode(ce.Code) { - return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -func (c *Conn) setCloseErr(err error) { - c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil { - c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) - } -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/close_notjs.go b/close_notjs.go new file mode 100644 index 00000000..dd1b0e0d --- /dev/null +++ b/close_notjs.go @@ -0,0 +1,199 @@ +// +build !js + +package websocket + +import ( + "context" + "encoding/binary" + "log" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + return c.closeHandshake(code, reason) +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + err = c.writeClose(code, reason) + if err != nil { + return err + } + + err = c.waitCloseHandshake() + if CloseStatus(err) == -1 { + return err + } + return nil +} + +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.close(nil) +} + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + closing := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if closing { + return xerrors.New("already wrote close") + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) + + var p []byte + if ce.Code != StatusNoStatusRcvd { + p = ce.bytes() + } + + return c.writeControl(context.Background(), opClose, p) +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + err := c.readMu.Lock(ctx) + if err != nil { + return err + } + defer c.readMu.Unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() []byte { + p, err := ce.bytesErr() + if err != nil { + log.Printf("websocket: failed to marshal close frame: %v", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil { + c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/compress.go b/compress.go index efd89b33..918b3b49 100644 --- a/compress.go +++ b/compress.go @@ -1,14 +1,5 @@ -// +build !js - package websocket -import ( - "compress/flate" - "io" - "net/http" - "sync" -) - // CompressionOptions represents the available deflate extension options. // See https://tools.ietf.org/html/rfc7692 type CompressionOptions struct { @@ -60,149 +51,3 @@ const ( // important than bandwidth. CompressionDisabled ) - -func (m CompressionMode) opts() *compressionOptions { - if m == CompressionDisabled { - return nil - } - return &compressionOptions{ - clientNoContextTakeover: m == CompressionNoContextTakeover, - serverNoContextTakeover: m == CompressionNoContextTakeover, - } -} - -type compressionOptions struct { - clientNoContextTakeover bool - serverNoContextTakeover bool -} - -func (copts *compressionOptions) setHeader(h http.Header) { - s := "permessage-deflate" - if copts.clientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.serverNoContextTakeover { - s += "; server_no_context_takeover" - } - h.Set("Sec-WebSocket-Extensions", s) -} - -// These bytes are required to get flate.Reader to return. -// They are removed when sending to avoid the overhead as -// WebSocket framing tell's when the message has ended but then -// we need to add them back otherwise flate.Reader keeps -// trying to return more bytes. -const deflateMessageTail = "\x00\x00\xff\xff" - -func (c *Conn) writeNoContextTakeOver() bool { - return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover -} - -func (c *Conn) readNoContextTakeOver() bool { - return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover -} - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte -} - -func (tw *trimLastFourBytesWriter) reset() { - tw.tail = tw.tail[:0] -} - -func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { - extra := len(tw.tail) + len(p) - 4 - - if extra <= 0 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(tw.tail) { - extra = len(tw.tail) - } - if extra > 0 { - _, err := tw.w.Write(tw.tail[:extra]) - if err != nil { - return 0, err - } - tw.tail = tw.tail[extra:] - } - - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Otherwise, only the last 4 bytes are. - tw.tail = append(tw.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := tw.w.Write(p) - return n + 4, err -} - -var flateReaderPool sync.Pool - -func getFlateReader(r io.Reader, dict []byte) io.Reader { - fr, ok := flateReaderPool.Get().(io.Reader) - if !ok { - return flate.NewReaderDict(r, dict) - } - fr.(flate.Resetter).Reset(r, dict) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -var flateWriterPool sync.Pool - -func getFlateWriter(w io.Writer) *flate.Writer { - fw, ok := flateWriterPool.Get().(*flate.Writer) - if !ok { - fw, _ = flate.NewWriter(w, flate.BestSpeed) - return fw - } - fw.Reset(w) - return fw -} - -func putFlateWriter(w *flate.Writer) { - flateWriterPool.Put(w) -} - -type slidingWindow struct { - r io.Reader - buf []byte -} - -func newSlidingWindow(n int) *slidingWindow { - return &slidingWindow{ - buf: make([]byte, 0, n), - } -} - -func (w *slidingWindow) write(p []byte) { - if len(p) >= cap(w.buf) { - w.buf = w.buf[:cap(w.buf)] - p = p[len(p)-cap(w.buf):] - copy(w.buf, p) - return - } - - left := cap(w.buf) - len(w.buf) - if left < len(p) { - // We need to shift spaceNeeded bytes from the end to make room for p at the end. - spaceNeeded := len(p) - left - copy(w.buf, w.buf[spaceNeeded:]) - w.buf = w.buf[:len(w.buf)-spaceNeeded] - } - - w.buf = append(w.buf, p...) -} diff --git a/compress_notjs.go b/compress_notjs.go new file mode 100644 index 00000000..8bc2f87b --- /dev/null +++ b/compress_notjs.go @@ -0,0 +1,156 @@ +// +build !js + +package websocket + +import ( + "compress/flate" + "io" + "net/http" + "sync" +) + +func (m CompressionMode) opts() *compressionOptions { + if m == CompressionDisabled { + return nil + } + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +func (c *Conn) writeNoContextTakeOver() bool { + return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover +} + +func (c *Conn) readNoContextTakeOver() bool { + return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover +} + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + tw.tail = tw.tail[:0] +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + tw.tail = tw.tail[extra:] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} + +type slidingWindow struct { + r io.Reader + buf []byte +} + +func newSlidingWindow(n int) *slidingWindow { + return &slidingWindow{ + buf: make([]byte, 0, n), + } +} + +func (w *slidingWindow) write(p []byte) { + if len(p) >= cap(w.buf) { + w.buf = w.buf[:cap(w.buf)] + p = p[len(p)-cap(w.buf):] + copy(w.buf, p) + return + } + + left := cap(w.buf) - len(w.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(w.buf, w.buf[spaceNeeded:]) + w.buf = w.buf[:len(w.buf)-spaceNeeded] + } + + w.buf = append(w.buf, p...) +} diff --git a/compress_test.go b/compress_test.go index 15d334d6..51f658c8 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,11 +1,11 @@ +// +build !js + package websocket import ( "strings" "testing" - "cdr.dev/slog/sloggers/slogtest/assert" - "nhooyr.io/websocket/internal/test/xrand" ) @@ -15,14 +15,21 @@ func Test_slidingWindow(t *testing.T) { const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { - input := xrand.String(maxWindow) - windowLength := xrand.Int(maxWindow) - r := newSlidingWindow(windowLength) - r.write([]byte(input)) - - if cap(r.buf) != windowLength { - t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) - } - assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf))) + t.Run("", func(t *testing.T) { + t.Parallel() + + input := xrand.String(maxWindow) + windowLength := xrand.Int(maxWindow) + r := newSlidingWindow(windowLength) + r.write([]byte(input)) + + if cap(r.buf) != windowLength { + t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) + } + + if !strings.HasSuffix(input, string(r.buf)) { + t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf) + } + }) } } diff --git a/conn.go b/conn.go index 163802bb..e58a8748 100644 --- a/conn.go +++ b/conn.go @@ -2,18 +2,6 @@ package websocket -import ( - "bufio" - "context" - "io" - "runtime" - "strconv" - "sync" - "sync/atomic" - - "golang.org/x/xerrors" -) - // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int @@ -25,252 +13,3 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) - -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release associated resources. -// -// On any error from any method, the connection is closed -// with an appropriate reason. -type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - br *bufio.Reader - bw *bufio.Writer - - readTimeout chan context.Context - writeTimeout chan context.Context - - // Read state. - readMu *mu - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error - - // Write state. - msgWriter *msgWriter - writeFrameMu *mu - writeBuf []byte - writeHeader header - - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool - - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} -} - -type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - - br *bufio.Reader - bw *bufio.Writer -} - -func newConn(cfg connConfig) *Conn { - c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, - flateThreshold: cfg.flateThreshold, - - br: cfg.br, - bw: cfg.bw, - - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), - } - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 256 - if c.writeNoContextTakeOver() { - c.flateThreshold = 512 - } - } - - c.readMu = newMu(c) - c.writeFrameMu = newMu(c) - - c.msgReader = newMsgReader(c) - - c.msgWriter = newMsgWriter(c) - if c.client { - c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) - } - - runtime.SetFinalizer(c, func(c *Conn) { - c.close(xerrors.New("connection garbage collected")) - }) - - go c.timeoutLoop() - - return c -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close(err error) { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.isClosed() { - return - } - close(c.closed) - runtime.SetFinalizer(c, nil) - c.setCloseErrLocked(err) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.rwc.Close() - - go func() { - if c.client { - c.writeFrameMu.Lock(context.Background()) - putBufioWriter(c.bw) - } - c.msgWriter.close() - - c.msgReader.close() - if c.client { - putBufioReader(c.br) - } - }() -} - -func (c *Conn) timeoutLoop() { - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) - case <-writeCtx.Done(): - c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) - return - } - } -} - -func (c *Conn) flate() bool { - return c.copts != nil -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) - - err := c.ping(ctx, strconv.Itoa(int(p))) - if err != nil { - return xerrors.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err - case <-pong: - return nil - } -} - -type mu struct { - c *Conn - ch chan struct{} -} - -func newMu(c *Conn) *mu { - return &mu{ - c: c, - ch: make(chan struct{}, 1), - } -} - -func (m *mu) Lock(ctx context.Context) error { - select { - case <-m.c.closed: - return m.c.closeErr - case <-ctx.Done(): - err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err - case m.ch <- struct{}{}: - return nil - } -} - -func (m *mu) TryLock() bool { - select { - case m.ch <- struct{}{}: - return true - default: - return false - } -} - -func (m *mu) Unlock() { - select { - case <-m.ch: - default: - } -} diff --git a/conn_notjs.go b/conn_notjs.go new file mode 100644 index 00000000..d2fea4d4 --- /dev/null +++ b/conn_notjs.go @@ -0,0 +1,264 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "io" + "runtime" + "strconv" + "sync" + "sync/atomic" + + "golang.org/x/xerrors" +) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +type Conn struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu *mu + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriter *msgWriter + writeFrameMu *mu + writeBuf []byte + writeHeader header + + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 256 + if c.writeNoContextTakeOver() { + c.flateThreshold = 512 + } + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(xerrors.New("connection garbage collected")) + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + close(c.closed) + runtime.SetFinalizer(c, nil) + c.setCloseErrLocked(err) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + go func() { + if c.client { + c.writeFrameMu.Lock(context.Background()) + putBufioWriter(c.bw) + } + c.msgWriter.close() + + c.msgReader.close() + if c.client { + putBufioReader(c.br) + } + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) + case <-writeCtx.Done(): + c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return xerrors.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) Lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + return nil + } +} + +func (m *mu) TryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) Unlock() { + select { + case <-m.ch: + default: + } +} diff --git a/conn_test.go b/conn_test.go index 02606ef5..5c817a25 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,33 +4,23 @@ package websocket_test import ( "context" - "io" + "fmt" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "sync" "testing" "time" "golang.org/x/xerrors" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/cmp" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/internal/xsync" ) -func goFn(fn func() error) chan error { - errs := make(chan error) - go func() { - defer func() { - r := recover() - if r != nil { - errs <- xerrors.Errorf("panic in gofn: %v", r) - } - }() - errs <- fn() - }() - - return errs -} - func TestConn(t *testing.T) { t.Parallel() @@ -44,7 +34,7 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - copts := websocket.CompressionOptions{ + copts := &websocket.CompressionOptions{ Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), Threshold: xrand.Int(9999), } @@ -60,8 +50,8 @@ func TestConn(t *testing.T) { defer c1.Close(websocket.StatusInternalError, "") defer c2.Close(websocket.StatusInternalError, "") - echoLoopErr := goFn(func() error { - err := echoLoop(ctx, c1) + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c1) return assertCloseStatus(websocket.StatusNormalClosure, err) }) defer func() { @@ -72,39 +62,13 @@ func TestConn(t *testing.T) { }() defer cancel() - c2.SetReadLimit(1 << 30) + c2.SetReadLimit(131072) for i := 0; i < 5; i++ { - n := xrand.Int(131_072) - - msg := xrand.Bytes(n) - - expType := websocket.MessageBinary - if xrand.Bool() { - expType = websocket.MessageText - } - - writeErr := goFn(func() error { - return c2.Write(ctx, expType, msg) - }) - - actType, act, err := c2.Read(ctx) - if err != nil { - t.Fatal(err) - } - - err = <-writeErr + err := wstest.Echo(ctx, c2, 131072) if err != nil { t.Fatal(err) } - - if expType != actType { - t.Fatalf("unexpected message typ (%v): %v", expType, actType) - } - - if !cmp.Equal(msg, act) { - t.Fatalf("unexpected msg read: %v", cmp.Diff(msg, act)) - } } c2.Close(websocket.StatusNormalClosure, "") @@ -113,47 +77,50 @@ func TestConn(t *testing.T) { }) } -func assertCloseStatus(exp websocket.StatusCode, err error) error { - if websocket.CloseStatus(err) == -1 { - return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) - } - if websocket.CloseStatus(err) != exp { - return xerrors.Errorf("unexpected close status (%v):%v", exp, err) - } - return nil -} - -// echoLoop echos every msg received from c until an error -// occurs or the context expires. -// The read limit is set to 1 << 30. -func echoLoop(ctx context.Context, c *websocket.Conn) error { - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) +func TestWasm(t *testing.T) { + t.Parallel() - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() + var wg sync.WaitGroup + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() - b := make([]byte, 32<<10) - for { - typ, r, err := c.Reader(ctx) + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) if err != nil { - return err + t.Error(err) + return } + defer c.Close(websocket.StatusInternalError, "") - w, err := c.Writer(ctx, typ) - if err != nil { - return err + err = wstest.EchoLoop(r.Context(), c) + if websocket.CloseStatus(err) != websocket.StatusNormalClosure { + t.Errorf("echoLoop: %v", err) } + })) + defer wg.Wait() + defer s.Close() - _, err = io.CopyBuffer(w, r, b) - if err != nil { - return err - } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + defer cancel() - err = w.Close() - if err != nil { - return err - } + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wstest.URL(s))) + + b, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("wasm test binary failed: %v:\n%s", err, b) } } + +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return xerrors.Errorf("unexpected close status (%v):%v", exp, err) + } + return nil +} diff --git a/dial.go b/dial.go index a1509ab5..3e2042e5 100644 --- a/dial.go +++ b/dial.go @@ -35,8 +35,7 @@ type DialOptions struct { // CompressionOptions controls the compression options. // See docs on the CompressionOptions type. - // TODO make * - CompressionOptions CompressionOptions + CompressionOptions *CompressionOptions } // Dial performs a WebSocket handshake on url. @@ -60,6 +59,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( if opts == nil { opts = &DialOptions{} } + opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient @@ -67,6 +67,9 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } + if opts.CompressionOptions == nil { + opts.CompressionOptions = &CompressionOptions{} + } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { diff --git a/dial_test.go b/dial_test.go index 3be52208..e38e8f17 100644 --- a/dial_test.go +++ b/dial_test.go @@ -223,7 +223,7 @@ func Test_verifyServerHandshake(t *testing.T) { } _, err = verifyServerResponse(opts, key, resp) if (err == nil) != tc.success { - t.Fatalf("unexpected error: %+v", err) + t.Fatalf("unexpected error: %v", err) } }) } diff --git a/example_test.go b/example_test.go index 1842b765..075107b0 100644 --- a/example_test.go +++ b/example_test.go @@ -74,8 +74,7 @@ func ExampleCloseStatus() { _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %+v", err) - return + log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) } } diff --git a/frame.go b/frame.go index 47ff40f7..0257835e 100644 --- a/frame.go +++ b/frame.go @@ -1,5 +1,3 @@ -// +build !js - package websocket import ( diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go index d0eee6d0..cdbadf70 100644 --- a/internal/test/cmp/cmp.go +++ b/internal/test/cmp/cmp.go @@ -2,6 +2,7 @@ package cmp import ( "reflect" + "strings" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -20,3 +21,11 @@ func Diff(v1, v2 interface{}) string { return true })) } + +// ErrorContains returns whether err.Error() contains sub. +func ErrorContains(err error, sub string) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), sub) +} diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go new file mode 100644 index 00000000..70b2ba57 --- /dev/null +++ b/internal/test/wstest/echo.go @@ -0,0 +1,90 @@ +package wstest + +import ( + "context" + "io" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/internal/xsync" +) + +// EchoLoop echos every msg received from c until an error +// occurs or the context expires. +// The read limit is set to 1 << 30. +func EchoLoop(ctx context.Context, c *websocket.Conn) error { + defer c.Close(websocket.StatusInternalError, "") + + c.SetReadLimit(1 << 30) + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32<<10) + for { + typ, r, err := c.Reader(ctx) + if err != nil { + return err + } + + w, err := c.Writer(ctx, typ) + if err != nil { + return err + } + + _, err = io.CopyBuffer(w, r, b) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + } +} + +// Echo writes a message and ensures the same is sent back on c. +func Echo(ctx context.Context, c *websocket.Conn, max int) error { + expType := websocket.MessageBinary + if xrand.Bool() { + expType = websocket.MessageText + } + + msg := randMessage(expType, xrand.Int(max)) + + writeErr := xsync.Go(func() error { + return c.Write(ctx, expType, msg) + }) + + actType, act, err := c.Read(ctx) + if err != nil { + return err + } + + err = <-writeErr + if err != nil { + return err + } + + if expType != actType { + return xerrors.Errorf("unexpected message typ (%v): %v", expType, actType) + } + + if !cmp.Equal(msg, act) { + return xerrors.Errorf("unexpected msg read: %v", cmp.Diff(msg, act)) + } + + return nil +} + +func randMessage(typ websocket.MessageType, n int) []byte { + if typ == websocket.MessageBinary { + return xrand.Bytes(n) + } + return []byte(xrand.String(n)) +} diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index e958aea4..81705a8a 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -1,3 +1,5 @@ +// +build !js + package wstest import ( @@ -30,6 +32,7 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) if dialOpts == nil { dialOpts = &websocket.DialOptions{} } + dialOpts = &*dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, } diff --git a/internal/test/wstest/url.go b/internal/test/wstest/url.go new file mode 100644 index 00000000..a11c61b4 --- /dev/null +++ b/internal/test/wstest/url.go @@ -0,0 +1,11 @@ +package wstest + +import ( + "net/http/httptest" + "strings" +) + +// URL returns the ws url for s. +func URL(s *httptest.Server) string { + return strings.Replace(s.URL, "http", "ws", 1) +} diff --git a/internal/xsync/go.go b/internal/xsync/go.go new file mode 100644 index 00000000..96cf8103 --- /dev/null +++ b/internal/xsync/go.go @@ -0,0 +1,25 @@ +package xsync + +import ( + "golang.org/x/xerrors" +) + +// Go allows running a function in another goroutine +// and waiting for its error. +func Go(fn func() error) chan error { + errs := make(chan error, 1) + go func() { + defer func() { + r := recover() + if r != nil { + select { + case errs <- xerrors.Errorf("panic in go fn: %v", r): + default: + } + } + }() + errs <- fn() + }() + + return errs +} diff --git a/internal/xsync/go_test.go b/internal/xsync/go_test.go new file mode 100644 index 00000000..c0613e64 --- /dev/null +++ b/internal/xsync/go_test.go @@ -0,0 +1,20 @@ +package xsync + +import ( + "testing" + + "nhooyr.io/websocket/internal/test/cmp" +) + +func TestGoRecover(t *testing.T) { + t.Parallel() + + errs := Go(func() error { + panic("anmol") + }) + + err := <-errs + if !cmp.ErrorContains(err, "anmol") { + t.Fatalf("unexpected err: %v", err) + } +} diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go new file mode 100644 index 00000000..a0c40204 --- /dev/null +++ b/internal/xsync/int64.go @@ -0,0 +1,23 @@ +package xsync + +import ( + "sync/atomic" +) + +// Int64 represents an atomic int64. +type Int64 struct { + // We do not use atomic.Load/StoreInt64 since it does not + // work on 32 bit computers but we need 64 bit integers. + i atomic.Value +} + +// Load loads the int64. +func (v *Int64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +// Store stores the int64. +func (v *Int64) Store(i int64) { + v.i.Store(i) +} diff --git a/read.go b/read.go index b681a944..e723ef3c 100644 --- a/read.go +++ b/read.go @@ -7,12 +7,12 @@ import ( "io" "io/ioutil" "strings" - "sync/atomic" "time" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until until there is a WebSocket @@ -415,7 +415,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit atomicInt64 + limit xsync.Int64 n int64 } @@ -448,21 +448,6 @@ func (lr *limitReader) Read(p []byte) (int, error) { return n, err } -type atomicInt64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -func (v *atomicInt64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -func (v *atomicInt64) Store(i int64) { - v.i.Store(i) -} - type readerFunc func(p []byte) (int, error) func (f readerFunc) Read(p []byte) (int, error) { diff --git a/ws_js.go b/ws_js.go index 3ce6f34d..de76afa6 100644 --- a/ws_js.go +++ b/ws_js.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "net/http" "reflect" "runtime" "sync" @@ -13,6 +14,7 @@ import ( "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" + "nhooyr.io/websocket/internal/xsync" ) // MessageType represents the type of a WebSocket message. @@ -32,10 +34,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit atomicInt64 + msgReadLimit xsync.Int64 closingMu sync.Mutex - isReadClosed atomicInt64 + isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -67,11 +69,8 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &wssync.Int64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &wssync.Int64{} - c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), @@ -121,7 +120,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return 0, nil, xerrors.Errorf("failed to read: %w", err) } if int64(len(p)) > c.msgReadLimit.Load() { - err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit) + err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } @@ -248,17 +247,17 @@ type DialOptions struct { // Dial creates a new WebSocket connection to the given url with the given options. // The passed context bounds the maximum time spent waiting for the connection to open. -// The returned *http.Response is always nil or the zero value. It's only in the signature +// The returned *http.Response is always nil or a mock. It's only in the signature // to match the core API. -func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { - c, err := dial(ctx, url, opts) +func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + c, resp, err := dial(ctx, url, opts) if err != nil { - return nil, resp, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) + return nil, nil, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) } - return c, nil + return c, resp, nil } -func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { +func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { if opts == nil { opts = &DialOptions{} } @@ -284,11 +283,12 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: + return c, &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + }, nil case <-c.closed: - return c, nil, c.closeErr + return nil, nil, c.closeErr } - - return c, nil } // Reader attempts to read a message from the connection. diff --git a/ws_js_test.go b/ws_js_test.go index 65309bff..8d49af6b 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -1,4 +1,4 @@ -package websocket +package websocket_test import ( "context" @@ -6,25 +6,43 @@ import ( "os" "testing" "time" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/wstest" ) -func TestEcho(t *testing.T) { +func TestWasm(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - c, resp, err := Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &DialOptions{ + c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) - assert.Success(t, err) - defer c.Close(StatusInternalError, "") + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + if !cmp.Equal("echo", c.Subprotocol()) { + t.Fatalf("unexpected subprotocol: %v", cmp.Diff("echo", c.Subprotocol())) + } + if !cmp.Equal(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Fatalf("unexpected status code: %v", cmp.Diff(http.StatusSwitchingProtocols, resp.StatusCode)) + } - assertSubprotocol(t, c, "echo") - assert.Equalf(t, &http.Response{}, resp, "http.Response") - echoJSON(t, ctx, c, 1024) - assertEcho(t, ctx, c, MessageBinary, 1024) + c.SetReadLimit(65536) + for i := 0; i < 10; i++ { + err = wstest.Echo(ctx, c, 65536) + if err != nil { + t.Fatal(err) + } + } - err = c.Close(StatusNormalClosure, "") - assert.Success(t, err) + err = c.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatal(err) + } } From 69ff675fa5b55466a0a7ad8af391cec679b21216 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 01:27:32 -0500 Subject: [PATCH 34/55] More tests and fixes --- close_notjs.go | 21 ++-- compress_notjs.go | 11 -- conn_notjs.go | 22 ++-- conn_test.go | 141 +++++++++++++++++++++++-- dial_test.go | 18 +++- go.mod | 8 -- go.sum | 255 ---------------------------------------------- write.go | 6 ++ ws_js_test.go | 2 +- 9 files changed, 170 insertions(+), 314 deletions(-) diff --git a/close_notjs.go b/close_notjs.go index dd1b0e0d..08a1ea05 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -35,7 +35,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") err = c.writeClose(code, reason) - if err != nil { + if CloseStatus(err) == -1 { return err } @@ -46,12 +46,6 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { return nil } -func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) - c.writeClose(code, err.Error()) - c.close(nil) -} - func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() closing := c.wroteClose @@ -70,7 +64,12 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { var p []byte if ce.Code != StatusNoStatusRcvd { - p = ce.bytes() + var err error + p, err = ce.bytes() + if err != nil { + log.Printf("websocket: %v", err) + return err + } } return c.writeControl(context.Background(), opClose, p) @@ -148,16 +147,16 @@ func validWireCloseCode(code StatusCode) bool { return false } -func (ce CloseError) bytes() []byte { +func (ce CloseError) bytes() ([]byte, error) { p, err := ce.bytesErr() if err != nil { - log.Printf("websocket: failed to marshal close frame: %v", err) + err = xerrors.Errorf("failed to marshal close frame: %w", err) ce = CloseError{ Code: StatusInternalError, } p, _ = ce.bytesErr() } - return p + return p, err } const maxCloseReason = maxControlPayload - 2 diff --git a/compress_notjs.go b/compress_notjs.go index 8bc2f87b..6ab6e284 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -10,9 +10,6 @@ import ( ) func (m CompressionMode) opts() *compressionOptions { - if m == CompressionDisabled { - return nil - } return &compressionOptions{ clientNoContextTakeover: m == CompressionNoContextTakeover, serverNoContextTakeover: m == CompressionNoContextTakeover, @@ -42,14 +39,6 @@ func (copts *compressionOptions) setHeader(h http.Header) { // trying to return more bytes. const deflateMessageTail = "\x00\x00\xff\xff" -func (c *Conn) writeNoContextTakeOver() bool { - return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover -} - -func (c *Conn) readNoContextTakeOver() bool { - return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover -} - type trimLastFourBytesWriter struct { w io.Writer tail []byte diff --git a/conn_notjs.go b/conn_notjs.go index d2fea4d4..96d17b73 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -87,12 +87,6 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 256 - if c.writeNoContextTakeOver() { - c.flateThreshold = 512 - } - } c.readMu = newMu(c) c.writeFrameMu = newMu(c) @@ -104,6 +98,13 @@ func newConn(cfg connConfig) *Conn { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 256 + if !c.msgWriter.flateContextTakeover() { + c.flateThreshold = 512 + } + } + runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) }) @@ -247,15 +248,6 @@ func (m *mu) Lock(ctx context.Context) error { } } -func (m *mu) TryLock() bool { - select { - case m.ch <- struct{}{}: - return true - default: - return false - } -} - func (m *mu) Unlock() { select { case <-m.ch: diff --git a/conn_test.go b/conn_test.go index 5c817a25..098fc097 100644 --- a/conn_test.go +++ b/conn_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/xerrors" "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" @@ -31,9 +32,6 @@ func TestConn(t *testing.T) { t.Run("", func(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - copts := &websocket.CompressionOptions{ Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), Threshold: xrand.Int(9999), @@ -47,11 +45,14 @@ func TestConn(t *testing.T) { if err != nil { t.Fatal(err) } - defer c1.Close(websocket.StatusInternalError, "") defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c1) + err := wstest.EchoLoop(ctx, c2) return assertCloseStatus(websocket.StatusNormalClosure, err) }) defer func() { @@ -62,19 +63,143 @@ func TestConn(t *testing.T) { }() defer cancel() - c2.SetReadLimit(131072) + c1.SetReadLimit(131072) for i := 0; i < 5; i++ { - err := wstest.Echo(ctx, c2, 131072) + err := wstest.Echo(ctx, c1, 131072) if err != nil { t.Fatal(err) } } - c2.Close(websocket.StatusNormalClosure, "") + err = c1.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } }) } }) + + t.Run("badClose", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c1.Close(websocket.StatusInternalError, "") + defer c2.Close(websocket.StatusInternalError, "") + + err = c1.Close(-1, "") + if !cmp.ErrorContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("ping", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c1.Close(websocket.StatusInternalError, "") + defer c2.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + c2.CloseRead(ctx) + c1.CloseRead(ctx) + + for i := 0; i < 10; i++ { + err = c1.Ping(ctx) + if err != nil { + t.Fatal(err) + } + } + + err = c1.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("badPing", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c1.Close(websocket.StatusInternalError, "") + defer c2.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + c2.CloseRead(ctx) + + err = c1.Ping(ctx) + if !cmp.ErrorContains(err, "failed to wait for pong") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("concurrentWrite", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + discardLoopErr := xsync.Go(func() error { + for { + _, _, err := c2.Read(ctx) + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + return nil + } + if err != nil { + return err + } + } + }) + defer func() { + err := <-discardLoopErr + if err != nil { + t.Errorf("discard loop error: %v", err) + } + }() + defer cancel() + + msg := xrand.Bytes(xrand.Int(9999)) + const count = 100 + errs := make(chan error, count) + + for i := 0; i < count; i++ { + go func() { + errs <- c1.Write(ctx, websocket.MessageBinary, msg) + }() + } + + for i := 0; i < count; i++ { + err := <-errs + if err != nil { + t.Fatal(err) + } + } + + err = c1.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) } func TestWasm(t *testing.T) { diff --git a/dial_test.go b/dial_test.go index e38e8f17..c4657415 100644 --- a/dial_test.go +++ b/dial_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "cdr.dev/slog/sloggers/slogtest/assert" + "nhooyr.io/websocket/internal/test/cmp" ) func TestBadDials(t *testing.T) { @@ -70,7 +70,9 @@ func TestBadDials(t *testing.T) { } _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) - assert.Error(t, "dial", err) + if err == nil { + t.Fatalf("expected error") + } }) } }) @@ -88,7 +90,9 @@ func TestBadDials(t *testing.T) { }, nil }), }) - assert.ErrorContains(t, "dial", err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") + if !cmp.ErrorContains(err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") { + t.Fatal(err) + } }) t.Run("badBody", func(t *testing.T) { @@ -113,7 +117,9 @@ func TestBadDials(t *testing.T) { _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ HTTPClient: mockHTTPClient(rt), }) - assert.ErrorContains(t, "dial", err, "response body is not a io.ReadWriteCloser") + if !cmp.ErrorContains(err, "response body is not a io.ReadWriteCloser") { + t.Fatal(err) + } }) } @@ -211,7 +217,9 @@ func Test_verifyServerHandshake(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) key, err := secWebSocketKey(rand.Reader) - assert.Success(t, "secWebSocketKey", err) + if err != nil { + t.Fatal(err) + } r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { diff --git a/go.mod b/go.mod index fc4ebb99..cb372391 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,12 @@ module nhooyr.io/websocket go 1.12 require ( - cdr.dev/slog v1.3.0 - github.com/alecthomas/chroma v0.7.1 // indirect - github.com/fatih/color v1.9.0 // indirect github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect github.com/gobwas/pool v0.2.0 // indirect github.com/gobwas/ws v1.0.2 - github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.3.3 github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 - github.com/mattn/go-isatty v0.0.12 // indirect - go.opencensus.io v0.22.3 // indirect - golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 // indirect - golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 // indirect golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index 1d1dc3a6..8cbc66ce 100644 --- a/go.sum +++ b/go.sum @@ -1,271 +1,16 @@ -cdr.dev/slog v1.3.0 h1:MYN1BChIaVEGxdS7I5cpdyMC0+WfJfK8BETAfzfLUGQ= -cdr.dev/slog v1.3.0/go.mod h1:C5OL99WyuOK8YHZdYY57dAPN1jK2WJlCdq2VP6xeQns= -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= -cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= -cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= -cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= -cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= -cloud.google.com/go v0.49.0 h1:CH+lkubJzcPYB1Ggupcq0+k8Ni2ILdG2lYjDIgavDBQ= -cloud.google.com/go v0.49.0/go.mod h1:hGvAdzcWNbyuxS3nWhD7H2cIJxjRRTRLQVB0bdputVY= -cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= -cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= -cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= -cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/GeertJohan/go.incremental v1.0.0/go.mod h1:6fAjUhbVuX1KcMD3c8TEgVUqmo4seqhv0i0kdATSkM0= -github.com/GeertJohan/go.rice v1.0.0/go.mod h1:eH6gbSOAUv07dQuZVnBmoDP8mgsM1rtixis4Tib9if0= -github.com/akavel/rsrc v0.8.0/go.mod h1:uLoCtb9J+EyAqh+26kdrTgmzRBFPGOolLWKpdxkKq+c= -github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38 h1:smF2tmSOzy2Mm+0dGI2AIUHY+w0BUc+4tn40djz7+6U= -github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38/go.mod h1:r7bzyVFMNntcxPZXK3/+KdruV1H5KSlyVY0gc+NgInI= -github.com/alecthomas/chroma v0.7.0 h1:z+0HgTUmkpRDRz0SRSdMaqOLfJV4F+N1FPDZUZIDUzw= -github.com/alecthomas/chroma v0.7.0/go.mod h1:1U/PfCsTALWWYHDnsIQkxEBM0+6LLe0v8+RSVMOwxeY= -github.com/alecthomas/chroma v0.7.1 h1:G1i02OhUbRi2nJxcNkwJaY/J1gHXj9tt72qN6ZouLFQ= -github.com/alecthomas/chroma v0.7.1/go.mod h1:gHw09mkX1Qp80JlYbmN9L3+4R5o6DJJ3GRShh+AICNc= -github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721 h1:JHZL0hZKJ1VENNfmXvHbgYlbUOvpzYzvy2aZU5gXVeo= -github.com/alecthomas/colour v0.0.0-20160524082231-60882d9e2721/go.mod h1:QO9JBoKquHd+jz9nshCh40fOfO+JzsoXy8qTHF68zU0= -github.com/alecthomas/kong v0.1.17-0.20190424132513-439c674f7ae0/go.mod h1:+inYUSluD+p4L8KdviBSgzcqEjUQOfC5fQDRFuc36lI= -github.com/alecthomas/kong v0.2.1-0.20190708041108-0548c6b1afae/go.mod h1:+inYUSluD+p4L8KdviBSgzcqEjUQOfC5fQDRFuc36lI= -github.com/alecthomas/kong-hcl v0.1.8-0.20190615233001-b21fea9723c8/go.mod h1:MRgZdU3vrFd05IQ89AxUZ0aYdF39BYoNFa324SodPCA= -github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897 h1:p9Sln00KOTlrYkxI1zYWl1QLnEqAqEARBEYa8FQnQcY= -github.com/alecthomas/repr v0.0.0-20180818092828-117648cd9897/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CLnGVd57E= -github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ= -github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.1.6 h1:CqB4MjHw0MFCDj+PHHjiESmHX+N7t0tJzKvC6M97BRg= -github.com/dlclark/regexp2 v1.1.6/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= -github.com/dlclark/regexp2 v1.2.0 h1:8sAhBGEM0dRWogWqWyQeIJnxjWO6oIjl8FKqREDsGfk= -github.com/dlclark/regexp2 v1.2.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys= -github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= -github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= -github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9 h1:uHTyIjqVhYRhLbJ8nIiOJHkEZZ+5YoOsAbD3sk82NiE= -github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= -github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.2-0.20191216170541-340f1ebe299e h1:4WfjkTUTsO6siF8ghDQQk6t7x/FPsv3w6MXkc47do7Q= -github.com/google/go-cmp v0.3.2-0.20191216170541-340f1ebe299e/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= -github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= -github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= -github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/gorilla/csrf v1.6.0/go.mod h1:7tSf8kmjNYr7IWDCYhd3U8Ck34iQ/Yw5CJu7bAkHEGI= -github.com/gorilla/handlers v1.4.1/go.mod h1:Qkdc/uu4tH4g6mTK6auzZ766c4CA0Ng8+o/OAirnOIQ= -github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= -github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= -github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= -github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= -github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= -github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/nkovacs/streamquote v0.0.0-20170412213628-49af9bddb229/go.mod h1:0aYXnNPJ8l7uZxf45rWW1a/uME32OF0rhiYGNQ2oF2E= -github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ= -github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= -github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.0.1/go.mod h1:UQGH1tvbgY+Nz5t2n7tXsz52dQxojPUpymEIMZ47gx8= -go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= -go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= -go.opencensus.io v0.22.2 h1:75k/FF0Q2YM8QYo07VPddOLBslDt1MZOdEslOHvmzAs= -go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opencensus.io v0.22.3 h1:8sGtKOrtQqkN1bp2AtX+misvLIlOmsEsNd+9NIcPEm8= -go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5 h1:58fnuSXlxZmFdJyvtTFVmVhcMLU6v5fEb/ok4wyqtNU= -golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413 h1:ULYEB3JvPRE/IfO+9uO7vKV/xzVTO7XPAwm8xbf4w2g= -golang.org/x/crypto v0.0.0-20191206172530-e9b2fee46413/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 h1:KOcEaR10tFr7gdJV2GCKw8Os5yED1u1aOqHjOAb6d2Y= -golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= -golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= -golang.org/x/exp v0.0.0-20191030013958-a1ab85dbe136/go.mod h1:JXzH8nQsPlswgeRAPE3MuO9GYsAcnJvJ4vnMwN/5qkY= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwLmSJpwZ1yqXm8j0v2QI= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= -golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181128092732-4ed8d59d0b35/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191210023423-ac6580df4449 h1:gSbV7h1NRL2G1xTg/owz62CST1oJBmxy4QpMMregXVQ= -golang.org/x/sys v0.0.0-20191210023423-ac6580df4449/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0= -golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= -golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2 h1:EtTFh6h4SAKemS+CURDMTDIANuduG5zKEXShyy18bGA= -golang.org/x/tools v0.0.0-20191115202509-3a792d9c32b2/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= -google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= -google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.9.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= -google.golang.org/api v0.14.0/go.mod h1:iLdEw5Ide6rF15KTC1Kkl0iskquN2gFfn9o9XIsbkAI= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= -google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= -google.golang.org/genproto v0.0.0-20191115194625-c23dd37a84c9/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1 h1:aQktFqmDE2yjveXJlVIfslDFmFnUXSqG0i6KRcJAeMc= -google.golang.org/genproto v0.0.0-20191216164720-4f79533eabd1/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= -google.golang.org/grpc v1.21.1 h1:j6XxA85m/6txkUCHvzlV5f+HBNl/1r5cZ2A/3IEFOO8= -google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1 h1:wdKvqQk7IttEw92GoRyKG2IDrUIpgpj6H6m81yfeMW0= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= diff --git a/write.go b/write.go index 70656b9f..612e52cb 100644 --- a/write.go +++ b/write.go @@ -335,3 +335,9 @@ func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { return writeBuf } + +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.close(nil) +} diff --git a/ws_js_test.go b/ws_js_test.go index 8d49af6b..bda9c0a5 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -15,7 +15,7 @@ import ( func TestWasm(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) defer cancel() c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ From 085e6717e9e4982bfc3f235e2627f7aed7d69d04 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 01:45:02 -0500 Subject: [PATCH 35/55] Get coverage to 85% --- conn_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/conn_test.go b/conn_test.go index 098fc097..db5ec84d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,6 +5,8 @@ package websocket_test import ( "context" "fmt" + "io" + "io/ioutil" "net/http" "net/http/httptest" "os" @@ -200,6 +202,121 @@ func TestConn(t *testing.T) { t.Fatalf("unexpected error: %v", err) } }) + + t.Run("concurrentWriteError", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + _, err = c1.Writer(context.Background(), websocket.MessageText) + if err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + err = c1.Write(ctx, websocket.MessageText, []byte("x")) + if !xerrors.Is(err, context.DeadlineExceeded) { + t.Fatal(err) + } + }) + + t.Run("netConn", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(ctx, c2, websocket.MessageBinary) + + // Does not give any confidence but at least ensures no crashes. + d, _ := ctx.Deadline() + n1.SetDeadline(d) + n1.SetDeadline(time.Time{}) + + if n1.RemoteAddr() != n1.LocalAddr() { + t.Fatal() + } + if n1.RemoteAddr().String() != "websocket/unknown-addr" || n1.RemoteAddr().Network() != "websocket" { + t.Fatal(n1.RemoteAddr()) + } + + errs := xsync.Go(func() error { + _, err := n2.Write([]byte("hello")) + if err != nil { + return err + } + return n2.Close() + }) + + b, err := ioutil.ReadAll(n1) + if err != nil { + t.Fatal(err) + } + + _, err = n1.Read(nil) + if err != io.EOF { + t.Fatalf("expected EOF: %v", err) + } + + err = <-errs + if err != nil { + t.Fatal(err) + } + + if !cmp.Equal([]byte("hello"), b) { + t.Fatalf("unexpected msg: %v", cmp.Diff([]byte("hello"), b)) + } + }) + + t.Run("netConn", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(ctx, c2, websocket.MessageText) + + errs := xsync.Go(func() error { + _, err := n2.Write([]byte("hello")) + if err != nil { + return err + } + return nil + }) + + _, err = ioutil.ReadAll(n1) + if !cmp.ErrorContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) { + t.Fatal(err) + } + + err = <-errs + if err != nil { + t.Fatal(err) + } + }) } func TestWasm(t *testing.T) { From 51769b30952a33d6362a4efbccaf30ba64ea3a3b Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 01:54:31 -0500 Subject: [PATCH 36/55] Add wspb test --- conn_test.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++++- ws_js_test.go | 2 +- 2 files changed, 104 insertions(+), 3 deletions(-) diff --git a/conn_test.go b/conn_test.go index db5ec84d..a0edd8df 100644 --- a/conn_test.go +++ b/conn_test.go @@ -15,6 +15,9 @@ import ( "testing" "time" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes" + "github.com/golang/protobuf/ptypes/duration" "golang.org/x/xerrors" "nhooyr.io/websocket" @@ -22,12 +25,14 @@ import ( "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" + "nhooyr.io/websocket/wsjson" + "nhooyr.io/websocket/wspb" ) func TestConn(t *testing.T) { t.Parallel() - t.Run("data", func(t *testing.T) { + t.Run("fuzzData", func(t *testing.T) { t.Parallel() for i := 0; i < 5; i++ { @@ -317,6 +322,102 @@ func TestConn(t *testing.T) { t.Fatal(err) } }) + + t.Run("wsjson", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c2) + return assertCloseStatus(websocket.StatusNormalClosure, err) + }) + defer func() { + err := <-echoLoopErr + if err != nil { + t.Errorf("echo loop error: %v", err) + } + }() + defer cancel() + + c1.SetReadLimit(131072) + + exp := xrand.String(xrand.Int(131072)) + err = wsjson.Write(ctx, c1, exp) + if err != nil { + t.Fatal(err) + } + + var act interface{} + err = wsjson.Read(ctx, c1, &act) + if err != nil { + t.Fatal(err) + } + if exp != act { + t.Fatal(cmp.Diff(exp, act)) + } + + err = c1.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("wspb", func(t *testing.T) { + t.Parallel() + + c1, c2, err := wstest.Pipe(nil, nil) + if err != nil { + t.Fatal(err) + } + defer c2.Close(websocket.StatusInternalError, "") + defer c1.Close(websocket.StatusInternalError, "") + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c2) + return assertCloseStatus(websocket.StatusNormalClosure, err) + }) + defer func() { + err := <-echoLoopErr + if err != nil { + t.Errorf("echo loop error: %v", err) + } + }() + defer cancel() + + c1.SetReadLimit(131072) + + exp := ptypes.DurationProto(100) + err = wspb.Write(ctx, c1, exp) + if err != nil { + t.Fatal(err) + } + + act := &duration.Duration{} + err = wspb.Read(ctx, c1, act) + if err != nil { + t.Fatal(err) + } + if !proto.Equal(exp, act) { + t.Fatal(cmp.Diff(exp, act)) + } + + err = c1.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) } func TestWasm(t *testing.T) { @@ -345,7 +446,7 @@ func TestWasm(t *testing.T) { defer wg.Wait() defer s.Close() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") diff --git a/ws_js_test.go b/ws_js_test.go index bda9c0a5..8671dd21 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -15,7 +15,7 @@ import ( func TestWasm(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ From 670be052707b9505a51f0530535e50df06114b11 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 02:03:16 -0500 Subject: [PATCH 37/55] Merge in handshake improvements from master --- accept.go | 25 +++++++++++++++---------- accept_test.go | 2 +- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/accept.go b/accept.go index 31f104b2..cc9babb0 100644 --- a/accept.go +++ b/accept.go @@ -65,9 +65,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con opts.CompressionOptions = &CompressionOptions{} } - err = verifyClientRequest(r) + errCode, err := verifyClientRequest(w, r) if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + http.Error(w, err.Error(), errCode) return nil, err } @@ -127,32 +127,37 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con }), nil } -func verifyClientRequest(r *http.Request) error { +func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { if !r.ProtoAtLeast(1, 1) { - return xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) } if !headerContainsToken(r.Header, "Connection", "Upgrade") { - return xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) } if !headerContainsToken(r.Header, "Upgrade", "websocket") { - return xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, xerrors.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) } if r.Method != "GET" { - return xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + return http.StatusMethodNotAllowed, xerrors.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) } if r.Header.Get("Sec-WebSocket-Version") != "13" { - return xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + w.Header().Set("Sec-WebSocket-Version", "13") + return http.StatusBadRequest, xerrors.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) } if r.Header.Get("Sec-WebSocket-Key") == "" { - return xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") + return http.StatusBadRequest, xerrors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") } - return nil + return 0, nil } func authenticateOrigin(r *http.Request) error { diff --git a/accept_test.go b/accept_test.go index 18302da5..354e95ec 100644 --- a/accept_test.go +++ b/accept_test.go @@ -192,7 +192,7 @@ func Test_verifyClientHandshake(t *testing.T) { r.Header.Set(k, v) } - err := verifyClientRequest(r) + _, err := verifyClientRequest(httptest.NewRecorder(), r) if tc.success != (err == nil) { t.Fatalf("unexpected error value: %v", err) } From 3a526d8c78ea452584284b9b6da08e1679da2867 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 02:31:59 -0500 Subject: [PATCH 38/55] Fix bug in closeHandshake --- close_notjs.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/close_notjs.go b/close_notjs.go index 08a1ea05..160a1237 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -35,7 +35,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") err = c.writeClose(code, reason) - if CloseStatus(err) == -1 { + if err != nil && CloseStatus(err) == -1 { return err } @@ -63,16 +63,19 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) var p []byte + var err error if ce.Code != StatusNoStatusRcvd { - var err error p, err = ce.bytes() if err != nil { log.Printf("websocket: %v", err) - return err } } - return c.writeControl(context.Background(), opClose, p) + werr := c.writeControl(context.Background(), opClose, p) + if err != nil { + return err + } + return werr } func (c *Conn) waitCloseHandshake() error { From 999b812944250843206e83e2da9c8cb02d48f6f8 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 02:35:11 -0500 Subject: [PATCH 39/55] Fix race in msgReader --- read.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/read.go b/read.go index e723ef3c..0c3610d3 100644 --- a/read.go +++ b/read.go @@ -352,6 +352,8 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } if xerrors.Is(err, io.EOF) { err = io.EOF + + mr.returnFlateReader() } }() @@ -373,11 +375,7 @@ func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { if mr.fin { if mr.flate { - n, err := mr.flateTail.Read(p) - if xerrors.Is(err, io.EOF) { - mr.returnFlateReader() - } - return n, err + return mr.flateTail.Read(p) } return 0, io.EOF } From 4b84d25251ad9be731c0452f93cad4e48a893d6b Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 02:46:49 -0500 Subject: [PATCH 40/55] Fix a race with c.closed --- conn_notjs.go | 2 +- conn_test.go | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/conn_notjs.go b/conn_notjs.go index 96d17b73..4d8762bf 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -127,9 +127,9 @@ func (c *Conn) close(err error) { if c.isClosed() { return } + c.setCloseErrLocked(err) close(c.closed) runtime.SetFinalizer(c, nil) - c.setCloseErrLocked(err) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns diff --git a/conn_test.go b/conn_test.go index a0edd8df..8c00522e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -55,7 +55,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() echoLoopErr := xsync.Go(func() error { @@ -142,7 +142,7 @@ func TestConn(t *testing.T) { defer c1.Close(websocket.StatusInternalError, "") defer c2.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() c2.CloseRead(ctx) @@ -242,7 +242,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) @@ -298,7 +298,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) @@ -333,7 +333,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() echoLoopErr := xsync.Go(func() error { @@ -381,7 +381,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() echoLoopErr := xsync.Go(func() error { From 85f249d11aff22c45781c237ec783a328e199755 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 02:51:03 -0500 Subject: [PATCH 41/55] Up timeouts --- conn_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/conn_test.go b/conn_test.go index 8c00522e..5662344f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -55,7 +55,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() echoLoopErr := xsync.Go(func() error { @@ -163,7 +163,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() discardLoopErr := xsync.Go(func() error { @@ -288,7 +288,7 @@ func TestConn(t *testing.T) { } }) - t.Run("netConn", func(t *testing.T) { + t.Run("netConn/BadMsg", func(t *testing.T) { t.Parallel() c1, c2, err := wstest.Pipe(nil, nil) @@ -333,7 +333,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() echoLoopErr := xsync.Go(func() error { @@ -381,7 +381,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() echoLoopErr := xsync.Go(func() error { From 6b38ebbb43156c7dd421e7fd8c3a96cf8c0d5e5f Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sun, 9 Feb 2020 03:11:45 -0500 Subject: [PATCH 42/55] Test fixes --- conn_test.go | 25 +++++++++++++++---------- internal/test/wstest/echo.go | 3 ++- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/conn_test.go b/conn_test.go index 5662344f..6a5e6809 100644 --- a/conn_test.go +++ b/conn_test.go @@ -142,7 +142,7 @@ func TestConn(t *testing.T) { defer c1.Close(websocket.StatusInternalError, "") defer c2.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() c2.CloseRead(ctx) @@ -163,7 +163,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() discardLoopErr := xsync.Go(func() error { @@ -242,7 +242,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) @@ -298,7 +298,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) @@ -333,7 +333,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() echoLoopErr := xsync.Go(func() error { @@ -351,10 +351,10 @@ func TestConn(t *testing.T) { c1.SetReadLimit(131072) exp := xrand.String(xrand.Int(131072)) - err = wsjson.Write(ctx, c1, exp) - if err != nil { - t.Fatal(err) - } + + werr := xsync.Go(func() error { + return wsjson.Write(ctx, c1, exp) + }) var act interface{} err = wsjson.Read(ctx, c1, &act) @@ -365,6 +365,11 @@ func TestConn(t *testing.T) { t.Fatal(cmp.Diff(exp, act)) } + err = <-werr + if err != nil { + t.Fatal(err) + } + err = c1.Close(websocket.StatusNormalClosure, "") if err != nil { t.Fatalf("unexpected error: %v", err) @@ -381,7 +386,7 @@ func TestConn(t *testing.T) { defer c2.Close(websocket.StatusInternalError, "") defer c1.Close(websocket.StatusInternalError, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() echoLoopErr := xsync.Go(func() error { diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go index 70b2ba57..714767fc 100644 --- a/internal/test/wstest/echo.go +++ b/internal/test/wstest/echo.go @@ -1,6 +1,7 @@ package wstest import ( + "bytes" "context" "io" "time" @@ -75,7 +76,7 @@ func Echo(ctx context.Context, c *websocket.Conn, max int) error { return xerrors.Errorf("unexpected message typ (%v): %v", expType, actType) } - if !cmp.Equal(msg, act) { + if !bytes.Equal(msg, act) { return xerrors.Errorf("unexpected msg read: %v", cmp.Diff(msg, act)) } From 6770421bab627ff98b2ddb30a85e3a99a5cef9f3 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 11 Feb 2020 23:13:23 -0500 Subject: [PATCH 43/55] Fix goroutine leak from deadlock when closing --- conn_test.go | 4 +--- read.go | 6 ++++-- write.go | 22 ++++++++++++++++++---- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/conn_test.go b/conn_test.go index 6a5e6809..5abc9f46 100644 --- a/conn_test.go +++ b/conn_test.go @@ -348,7 +348,7 @@ func TestConn(t *testing.T) { }() defer cancel() - c1.SetReadLimit(131072) + c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) @@ -401,8 +401,6 @@ func TestConn(t *testing.T) { }() defer cancel() - c1.SetReadLimit(131072) - exp := ptypes.DurationProto(100) err = wspb.Write(ctx, c1, exp) if err != nil { diff --git a/read.go b/read.go index 0c3610d3..cb1fa229 100644 --- a/read.go +++ b/read.go @@ -69,7 +69,9 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // // When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { - c.msgReader.limitReader.limit.Store(n) + // We add read one more byte than the limit in case + // there is a fin frame that needs to be read. + c.msgReader.limitReader.limit.Store(n + 1) } const defaultReadLimit = 32768 @@ -80,7 +82,7 @@ func newMsgReader(c *Conn) *msgReader { fin: true, } - mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit) + mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit+1) return mr } diff --git a/write.go b/write.go index 612e52cb..245827a2 100644 --- a/write.go +++ b/write.go @@ -50,7 +50,8 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { type msgWriter struct { c *Conn - mu *mu + mu *mu + activeMu *mu ctx context.Context opcode opcode @@ -63,8 +64,9 @@ type msgWriter struct { func newMsgWriter(c *Conn) *msgWriter { mw := &msgWriter{ - c: c, - mu: newMu(c), + c: c, + mu: newMu(c), + activeMu: newMu(c), } return mw } @@ -147,6 +149,12 @@ func (mw *msgWriter) returnFlateWriter() { func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") + err = mw.activeMu.Lock(mw.ctx) + if err != nil { + return 0, err + } + defer mw.activeMu.Unlock() + if mw.closed { return 0, xerrors.New("cannot use closed writer") } @@ -173,6 +181,12 @@ func (mw *msgWriter) write(p []byte) (int, error) { func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") + err = mw.activeMu.Lock(mw.ctx) + if err != nil { + return err + } + defer mw.activeMu.Unlock() + if mw.closed { return xerrors.New("cannot use closed writer") } @@ -201,7 +215,7 @@ func (mw *msgWriter) Close() (err error) { } func (mw *msgWriter) close() { - mw.mu.Lock(context.Background()) + mw.activeMu.Lock(context.Background()) mw.returnFlateWriter() } From c7523658b7ecdfccd4408f90cd87b981a72c5dc4 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 11 Feb 2020 23:36:32 -0500 Subject: [PATCH 44/55] Make flateThreshold work Was noop before. --- write.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/write.go b/write.go index 245827a2..28f139af 100644 --- a/write.go +++ b/write.go @@ -106,7 +106,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return 0, err } - if !c.flate() || len(p) < c.flateThreshold { + if !c.flate() { defer c.msgWriter.mu.Unlock() return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) } @@ -159,9 +159,16 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - // TODO Write to buffer to detect whether to enable flate or not for this message. - if mw.c.flate() { - mw.ensureFlate() + if mw.opcode != opContinuation { + // First frame needs to be written. + if len(p) >= mw.c.flateThreshold { + // Only enables flate if the length crosses the + // threshold on the first write. + mw.ensureFlate() + } + } + + if mw.flate { return mw.flateWriter.Write(p) } From 0ea94666828da2849ce04a4cf7b5d5d8c9398fc2 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Tue, 11 Feb 2020 23:49:39 -0500 Subject: [PATCH 45/55] Cleanup writeMu and flateThreshold --- read.go | 2 +- write.go | 35 ++++++++++++++--------------------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/read.go b/read.go index cb1fa229..42d46b85 100644 --- a/read.go +++ b/read.go @@ -120,7 +120,7 @@ func (mr *msgReader) flateContextTakeover() bool { } func (c *Conn) readRSV1Illegal(h header) bool { - // If compression is enabled, rsv1 is always illegal. + // If compression is disabled, rsv1 is always illegal. if !c.flate() { return true } diff --git a/write.go b/write.go index 28f139af..9d4b670f 100644 --- a/write.go +++ b/write.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "encoding/binary" "io" + "sync" "time" "golang.org/x/xerrors" @@ -50,8 +51,8 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { type msgWriter struct { c *Conn - mu *mu - activeMu *mu + mu *mu + writeMu sync.Mutex ctx context.Context opcode opcode @@ -64,9 +65,8 @@ type msgWriter struct { func newMsgWriter(c *Conn) *msgWriter { mw := &msgWriter{ - c: c, - mu: newMu(c), - activeMu: newMu(c), + c: c, + mu: newMu(c), } return mw } @@ -149,21 +149,17 @@ func (mw *msgWriter) returnFlateWriter() { func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") - err = mw.activeMu.Lock(mw.ctx) - if err != nil { - return 0, err - } - defer mw.activeMu.Unlock() + mw.writeMu.Lock() + defer mw.writeMu.Unlock() if mw.closed { return 0, xerrors.New("cannot use closed writer") } - if mw.opcode != opContinuation { - // First frame needs to be written. - if len(p) >= mw.c.flateThreshold { - // Only enables flate if the length crosses the - // threshold on the first write. + if mw.c.flate() { + // Only enables flate if the length crosses the + // threshold on the first frame + if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { mw.ensureFlate() } } @@ -188,11 +184,8 @@ func (mw *msgWriter) write(p []byte) (int, error) { func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") - err = mw.activeMu.Lock(mw.ctx) - if err != nil { - return err - } - defer mw.activeMu.Unlock() + mw.writeMu.Lock() + defer mw.writeMu.Unlock() if mw.closed { return xerrors.New("cannot use closed writer") @@ -222,7 +215,7 @@ func (mw *msgWriter) Close() (err error) { } func (mw *msgWriter) close() { - mw.activeMu.Lock(context.Background()) + mw.writeMu.Lock() mw.returnFlateWriter() } From b33d48cb5b86653743c1605512f3049fb8c41958 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 12 Feb 2020 00:15:02 -0500 Subject: [PATCH 46/55] Minor cleanup --- ci/test.mk | 2 +- conn.go | 2 -- read.go | 2 +- ws_js.go | 12 ------------ 4 files changed, 2 insertions(+), 16 deletions(-) diff --git a/ci/test.mk b/ci/test.mk index 786a8d77..3fc34bbf 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -12,6 +12,6 @@ coveralls: gotest goveralls -coverprofile=ci/out/coverage.prof gotest: - go test -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... + go test -timeout=30m -covermode=count -coverprofile=ci/out/coverage.prof -coverpkg=./... $${GOTESTFLAGS-} ./... sed -i '/stringer\.go/d' ci/out/coverage.prof sed -i '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof diff --git a/conn.go b/conn.go index e58a8748..a41808be 100644 --- a/conn.go +++ b/conn.go @@ -1,5 +1,3 @@ -// +build !js - package websocket // MessageType represents the type of a WebSocket message. diff --git a/read.go b/read.go index 42d46b85..a9c291d1 100644 --- a/read.go +++ b/read.go @@ -120,7 +120,7 @@ func (mr *msgReader) flateContextTakeover() bool { } func (c *Conn) readRSV1Illegal(h header) bool { - // If compression is disabled, rsv1 is always illegal. + // If compression is disabled, rsv1 is illegal. if !c.flate() { return true } diff --git a/ws_js.go b/ws_js.go index de76afa6..05c4c062 100644 --- a/ws_js.go +++ b/ws_js.go @@ -17,18 +17,6 @@ import ( "nhooyr.io/websocket/internal/xsync" ) -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like protobufs. - MessageBinary -) - // Conn provides a wrapper around the browser WebSocket API. type Conn struct { ws wsjs.WebSocket From 9c5bfabce2d2b82e1b998842012502afabc32511 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Wed, 12 Feb 2020 20:29:43 -0500 Subject: [PATCH 47/55] Simplifications of conn_test.go --- conn_test.go | 443 +++++++++++++++++---------------------- internal/test/cmp/cmp.go | 3 +- 2 files changed, 194 insertions(+), 252 deletions(-) diff --git a/conn_test.go b/conn_test.go index 5abc9f46..b2a35af8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -15,7 +15,6 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/duration" "golang.org/x/xerrors" @@ -37,153 +36,86 @@ func TestConn(t *testing.T) { for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - copts := &websocket.CompressionOptions{ + dialCopts := &websocket.CompressionOptions{ Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), Threshold: xrand.Int(9999), } - c1, c2, err := wstest.Pipe(&websocket.DialOptions{ - CompressionOptions: copts, - }, &websocket.AcceptOptions{ - CompressionOptions: copts, - }) - if err != nil { - t.Fatal(err) + acceptCopts := &websocket.CompressionOptions{ + Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), + Threshold: xrand.Int(9999), } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c2) - return assertCloseStatus(websocket.StatusNormalClosure, err) + c1, c2 := tt.pipe(&websocket.DialOptions{ + CompressionOptions: dialCopts, + }, &websocket.AcceptOptions{ + CompressionOptions: acceptCopts, }) - defer func() { - err := <-echoLoopErr - if err != nil { - t.Errorf("echo loop error: %v", err) - } - }() - defer cancel() + + tt.goEchoLoop(c2) c1.SetReadLimit(131072) for i := 0; i < 5; i++ { - err := wstest.Echo(ctx, c1, 131072) - if err != nil { - t.Fatal(err) - } + err := wstest.Echo(tt.ctx, c1, 131072) + tt.success(err) } - err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(websocket.StatusNormalClosure, "") + tt.success(err) }) } }) t.Run("badClose", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c1.Close(websocket.StatusInternalError, "") - defer c2.Close(websocket.StatusInternalError, "") + c1, _ := tt.pipe(nil, nil) - err = c1.Close(-1, "") - if !cmp.ErrorContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(-1, "") + tt.errContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") }) t.Run("ping", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c1.Close(websocket.StatusInternalError, "") - defer c2.Close(websocket.StatusInternalError, "") + c1, c2 := tt.pipe(nil, nil) - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() - - c2.CloseRead(ctx) - c1.CloseRead(ctx) + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) for i := 0; i < 10; i++ { - err = c1.Ping(ctx) - if err != nil { - t.Fatal(err) - } + err := c1.Ping(tt.ctx) + tt.success(err) } - err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(websocket.StatusNormalClosure, "") + tt.success(err) }) t.Run("badPing", func(t *testing.T) { - t.Parallel() - - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c1.Close(websocket.StatusInternalError, "") - defer c2.Close(websocket.StatusInternalError, "") + tt := newTest(t) + defer tt.done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - c2.CloseRead(ctx) + c2.CloseRead(tt.ctx) - err = c1.Ping(ctx) - if !cmp.ErrorContains(err, "failed to wait for pong") { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Ping(tt.ctx) + tt.errContains(err, "failed to wait for pong") }) t.Run("concurrentWrite", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() - - discardLoopErr := xsync.Go(func() error { - for { - _, _, err := c2.Read(ctx) - if websocket.CloseStatus(err) == websocket.StatusNormalClosure { - return nil - } - if err != nil { - return err - } - } - }) - defer func() { - err := <-discardLoopErr - if err != nil { - t.Errorf("discard loop error: %v", err) - } - }() - defer cancel() + c1, c2 := tt.pipe(nil, nil) + tt.goDiscardLoop(c2) msg := xrand.Bytes(xrand.Int(9999)) const count = 100 @@ -191,74 +123,52 @@ func TestConn(t *testing.T) { for i := 0; i < count; i++ { go func() { - errs <- c1.Write(ctx, websocket.MessageBinary, msg) + errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg) }() } for i := 0; i < count; i++ { err := <-errs - if err != nil { - t.Fatal(err) - } + tt.success(err) } - err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + err := c1.Close(websocket.StatusNormalClosure, "") + tt.success(err) }) t.Run("concurrentWriteError", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") + c1, _ := tt.pipe(nil, nil) - _, err = c1.Writer(context.Background(), websocket.MessageText) - if err != nil { - t.Fatal(err) - } + _, err := c1.Writer(tt.ctx, websocket.MessageText) + tt.success(err) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) - if !xerrors.Is(err, context.DeadlineExceeded) { - t.Fatal(err) - } + tt.eq(context.DeadlineExceeded, err) }) t.Run("netConn", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) - n2 := websocket.NetConn(ctx, c2, websocket.MessageBinary) + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) // Does not give any confidence but at least ensures no crashes. - d, _ := ctx.Deadline() + d, _ := tt.ctx.Deadline() n1.SetDeadline(d) n1.SetDeadline(time.Time{}) - if n1.RemoteAddr() != n1.LocalAddr() { - t.Fatal() - } - if n1.RemoteAddr().String() != "websocket/unknown-addr" || n1.RemoteAddr().Network() != "websocket" { - t.Fatal(n1.RemoteAddr()) - } + tt.eq(n1.RemoteAddr(), n1.LocalAddr()) + tt.eq("websocket/unknown-addr", n1.RemoteAddr().String()) + tt.eq("websocket", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -269,40 +179,25 @@ func TestConn(t *testing.T) { }) b, err := ioutil.ReadAll(n1) - if err != nil { - t.Fatal(err) - } + tt.success(err) _, err = n1.Read(nil) - if err != io.EOF { - t.Fatalf("expected EOF: %v", err) - } + tt.eq(err, io.EOF) err = <-errs - if err != nil { - t.Fatal(err) - } + tt.success(err) - if !cmp.Equal([]byte("hello"), b) { - t.Fatalf("unexpected msg: %v", cmp.Diff([]byte("hello"), b)) - } + tt.eq([]byte("hello"), b) }) t.Run("netConn/BadMsg", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary) - n2 := websocket.NetConn(ctx, c2, websocket.MessageText) + n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) + n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -312,114 +207,60 @@ func TestConn(t *testing.T) { return nil }) - _, err = ioutil.ReadAll(n1) - if !cmp.ErrorContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) { - t.Fatal(err) - } + _, err := ioutil.ReadAll(n1) + tt.errContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) err = <-errs - if err != nil { - t.Fatal(err) - } + tt.success(err) }) t.Run("wsjson", func(t *testing.T) { - t.Parallel() - - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") + tt := newTest(t) + defer tt.done() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c2) - return assertCloseStatus(websocket.StatusNormalClosure, err) - }) - defer func() { - err := <-echoLoopErr - if err != nil { - t.Errorf("echo loop error: %v", err) - } - }() - defer cancel() + tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) exp := xrand.String(xrand.Int(131072)) werr := xsync.Go(func() error { - return wsjson.Write(ctx, c1, exp) + return wsjson.Write(tt.ctx, c1, exp) }) var act interface{} - err = wsjson.Read(ctx, c1, &act) - if err != nil { - t.Fatal(err) - } - if exp != act { - t.Fatal(cmp.Diff(exp, act)) - } + err := wsjson.Read(tt.ctx, c1, &act) + tt.success(err) + tt.eq(exp, act) err = <-werr - if err != nil { - t.Fatal(err) - } + tt.success(err) err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + tt.success(err) }) t.Run("wspb", func(t *testing.T) { - t.Parallel() + tt := newTest(t) + defer tt.done() - c1, c2, err := wstest.Pipe(nil, nil) - if err != nil { - t.Fatal(err) - } - defer c2.Close(websocket.StatusInternalError, "") - defer c1.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) - defer cancel() + c1, c2 := tt.pipe(nil, nil) - echoLoopErr := xsync.Go(func() error { - err := wstest.EchoLoop(ctx, c2) - return assertCloseStatus(websocket.StatusNormalClosure, err) - }) - defer func() { - err := <-echoLoopErr - if err != nil { - t.Errorf("echo loop error: %v", err) - } - }() - defer cancel() + tt.goEchoLoop(c2) exp := ptypes.DurationProto(100) - err = wspb.Write(ctx, c1, exp) - if err != nil { - t.Fatal(err) - } + err := wspb.Write(tt.ctx, c1, exp) + tt.success(err) act := &duration.Duration{} - err = wspb.Read(ctx, c1, act) - if err != nil { - t.Fatal(err) - } - if !proto.Equal(exp, act) { - t.Fatal(cmp.Diff(exp, act)) - } + err = wspb.Read(tt.ctx, c1, act) + tt.success(err) + tt.eq(exp, act) err = c1.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + tt.success(err) }) } @@ -443,7 +284,7 @@ func TestWasm(t *testing.T) { err = wstest.EchoLoop(r.Context(), c) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - t.Errorf("echoLoop: %v", err) + t.Errorf("echoLoop failed: %v", err) } })) defer wg.Wait() @@ -470,3 +311,103 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { } return nil } + +type test struct { + t *testing.T + ctx context.Context + + doneFuncs []func() +} + +func newTest(t *testing.T) *test { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + tt := &test{t: t, ctx: ctx} + tt.appendDone(cancel) + return tt +} + +func (tt *test) appendDone(f func()) { + tt.doneFuncs = append(tt.doneFuncs, f) +} + +func (tt *test) done() { + for i := len(tt.doneFuncs) - 1; i >= 0; i-- { + tt.doneFuncs[i]() + } +} + +func (tt *test) goEchoLoop(c *websocket.Conn) { + ctx, cancel := context.WithCancel(tt.ctx) + + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c) + return assertCloseStatus(websocket.StatusNormalClosure, err) + }) + tt.appendDone(func() { + cancel() + err := <-echoLoopErr + if err != nil { + tt.t.Errorf("echo loop error: %v", err) + } + }) +} + +func (tt *test) goDiscardLoop(c *websocket.Conn) { + ctx, cancel := context.WithCancel(tt.ctx) + + discardLoopErr := xsync.Go(func() error { + for { + _, _, err := c.Read(ctx) + if websocket.CloseStatus(err) == websocket.StatusNormalClosure { + return nil + } + if err != nil { + return err + } + } + }) + tt.appendDone(func() { + cancel() + err := <-discardLoopErr + if err != nil { + tt.t.Errorf("discard loop error: %v", err) + } + }) +} + +func (tt *test) pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (c1, c2 *websocket.Conn) { + tt.t.Helper() + + c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) + if err != nil { + tt.t.Fatal(err) + } + tt.appendDone(func() { + c2.Close(websocket.StatusInternalError, "") + c1.Close(websocket.StatusInternalError, "") + }) + return c1, c2 +} + +func (tt *test) success(err error) { + tt.t.Helper() + if err != nil { + tt.t.Fatal(err) + } +} + +func (tt *test) errContains(err error, sub string) { + tt.t.Helper() + if !cmp.ErrorContains(err, sub) { + tt.t.Fatalf("error does not contain %q: %v", sub, err) + } +} + +func (tt *test) eq(exp, act interface{}) { + tt.t.Helper() + if !cmp.Equal(exp, act) { + tt.t.Fatalf(cmp.Diff(exp, act)) + } +} diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go index cdbadf70..6f3dd706 100644 --- a/internal/test/cmp/cmp.go +++ b/internal/test/cmp/cmp.go @@ -4,6 +4,7 @@ import ( "reflect" "strings" + "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) @@ -12,7 +13,7 @@ import ( func Equal(v1, v2 interface{}) bool { return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { return true - })) + }), cmp.Comparer(proto.Equal)) } // Diff returns a human readable diff between v1 and v2 From 3673c2cf26752428863df479227e1a00b8948ea6 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 13 Feb 2020 01:05:09 -0500 Subject: [PATCH 48/55] Use basic test assertions --- accept_js.go | 19 ++++ accept_test.go | 54 ++++------ autobahn_test.go | 31 ++---- close_test.go | 33 +++--- compress_test.go | 6 +- conn_test.go | 189 +++++++++++++-------------------- dial_test.go | 24 ++--- frame_test.go | 27 ++--- internal/test/assert/assert.go | 46 ++++++++ internal/test/cmp/cmp.go | 18 +--- internal/xsync/go_test.go | 6 +- ws_js_test.go | 22 ++-- 12 files changed, 206 insertions(+), 269 deletions(-) create mode 100644 accept_js.go create mode 100644 internal/test/assert/assert.go diff --git a/accept_js.go b/accept_js.go new file mode 100644 index 00000000..efc92817 --- /dev/null +++ b/accept_js.go @@ -0,0 +1,19 @@ +package websocket + +import ( + "net/http" + + "golang.org/x/xerrors" +) + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + CompressionOptions *CompressionOptions +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, xerrors.New("unimplemented") +} diff --git a/accept_test.go b/accept_test.go index 354e95ec..53338e17 100644 --- a/accept_test.go +++ b/accept_test.go @@ -12,7 +12,7 @@ import ( "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" ) func TestAccept(t *testing.T) { @@ -25,9 +25,7 @@ func TestAccept(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) - if !cmp.ErrorContains(err, "protocol violation") { - t.Fatal(err) - } + assert.Contains(t, err, "protocol violation") }) t.Run("badOrigin", func(t *testing.T) { @@ -42,9 +40,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - if !cmp.ErrorContains(err, `request Origin "harhar.com" is not authorized for Host`) { - t.Fatal(err) - } + assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host`) }) t.Run("badCompression", func(t *testing.T) { @@ -61,9 +57,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") _, err := Accept(w, r, nil) - if !cmp.ErrorContains(err, `unsupported permessage-deflate parameter`) { - t.Fatal(err) - } + assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -77,9 +71,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - if !cmp.ErrorContains(err, `http.ResponseWriter does not implement http.Hijacker`) { - t.Fatal(err) - } + assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) }) t.Run("badHijack", func(t *testing.T) { @@ -99,9 +91,7 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - if !cmp.ErrorContains(err, `failed to hijack connection`) { - t.Fatal(err) - } + assert.Contains(t, err, `failed to hijack connection`) }) } @@ -193,8 +183,10 @@ func Test_verifyClientHandshake(t *testing.T) { } _, err := verifyClientRequest(httptest.NewRecorder(), r) - if tc.success != (err == nil) { - t.Fatalf("unexpected error value: %v", err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } @@ -244,9 +236,7 @@ func Test_selectSubprotocol(t *testing.T) { r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) - if !cmp.Equal(tc.negotiated, negotiated) { - t.Fatalf("unexpected negotiated: %v", cmp.Diff(tc.negotiated, negotiated)) - } + assert.Equal(t, "negotiated", tc.negotiated, negotiated) }) } } @@ -300,8 +290,10 @@ func Test_authenticateOrigin(t *testing.T) { r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r) - if tc.success != (err == nil) { - t.Fatalf("unexpected error value: %v", err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } @@ -373,21 +365,13 @@ func Test_acceptCompression(t *testing.T) { w := httptest.NewRecorder() copts, err := acceptCompression(r, w, tc.mode) if tc.error { - if err == nil { - t.Fatalf("expected error: %v", copts) - } + assert.Error(t, err) return } - if err != nil { - t.Fatal(err) - } - if !cmp.Equal(tc.expCopts, copts) { - t.Fatalf("unexpected compression options: %v", cmp.Diff(tc.expCopts, copts)) - } - if !cmp.Equal(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) { - t.Fatalf("unexpected respHeader: %v", cmp.Diff(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))) - } + assert.Success(t, err) + assert.Equal(t, "compression options", tc.expCopts, copts) + assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) }) } } diff --git a/autobahn_test.go b/autobahn_test.go index 4d0bd1b5..0763bc97 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -19,6 +19,7 @@ import ( "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" ) @@ -45,32 +46,26 @@ func TestAutobahn(t *testing.T) { defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) defer closeFn() err = waitWS(ctx, wstestURL) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) cases, err := wstestCaseCount(ctx, wstestURL) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) t.Run("cases", func(t *testing.T) { for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) }) @@ -78,9 +73,7 @@ func TestAutobahn(t *testing.T) { }) c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) c.Close(websocket.StatusNormalClosure, "") checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") @@ -172,18 +165,14 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) for _, tests := range indexJSON { for test, result := range tests { diff --git a/close_test.go b/close_test.go index 10a35b13..00a48d9e 100644 --- a/close_test.go +++ b/close_test.go @@ -8,7 +8,7 @@ import ( "strings" "testing" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" ) func TestCloseError(t *testing.T) { @@ -51,8 +51,10 @@ func TestCloseError(t *testing.T) { t.Parallel() _, err := tc.ce.bytesErr() - if tc.success != (err == nil) { - t.Fatalf("unexpected error value (wanted err == nil == %v): %v", tc.success, err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } @@ -63,10 +65,7 @@ func TestCloseError(t *testing.T) { Code: StatusInternalError, Reason: "meow", }.Error() - - if (act) != exp { - t.Fatal(cmp.Diff(exp, act)) - } + assert.Equal(t, "CloseError.Error()", exp, act) }) } @@ -114,14 +113,10 @@ func Test_parseClosePayload(t *testing.T) { ce, err := parseClosePayload(tc.p) if tc.success { - if err != nil { - t.Fatal(err) - } - if !cmp.Equal(tc.ce, ce) { - t.Fatalf("expected %v but got %v", tc.ce, ce) - } - } else if err == nil { - t.Errorf("expected error: %v %v", ce, err) + assert.Success(t, err) + assert.Equal(t, "close payload", tc.ce, ce) + } else { + assert.Error(t, err) } }) } @@ -168,9 +163,7 @@ func Test_validWireCloseCode(t *testing.T) { t.Parallel() act := validWireCloseCode(tc.code) - if !cmp.Equal(tc.valid, act) { - t.Fatalf("unexpected valid: %v", cmp.Diff(tc.valid, act)) - } + assert.Equal(t, "wire close code", tc.valid, act) }) } } @@ -208,9 +201,7 @@ func TestCloseStatus(t *testing.T) { t.Parallel() act := CloseStatus(tc.in) - if !cmp.Equal(tc.exp, act) { - t.Fatalf("unexpected closeStatus: %v", cmp.Diff(tc.exp, act)) - } + assert.Equal(t, "close status", tc.exp, act) }) } } diff --git a/compress_test.go b/compress_test.go index 51f658c8..364d542d 100644 --- a/compress_test.go +++ b/compress_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/xrand" ) @@ -23,10 +24,7 @@ func Test_slidingWindow(t *testing.T) { r := newSlidingWindow(windowLength) r.write([]byte(input)) - if cap(r.buf) != windowLength { - t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) - } - + assert.Equal(t, "window length", windowLength, cap(r.buf)) if !strings.HasSuffix(input, string(r.buf)) { t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf) } diff --git a/conn_test.go b/conn_test.go index b2a35af8..f9b52f22 100644 --- a/conn_test.go +++ b/conn_test.go @@ -20,7 +20,7 @@ import ( "golang.org/x/xerrors" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" "nhooyr.io/websocket/internal/xsync" @@ -34,26 +34,21 @@ func TestConn(t *testing.T) { t.Run("fuzzData", func(t *testing.T) { t.Parallel() + copts := func() *websocket.CompressionOptions { + return &websocket.CompressionOptions{ + Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), + Threshold: xrand.Int(9999), + } + } + for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { - tt := newTest(t) - defer tt.done() - - dialCopts := &websocket.CompressionOptions{ - Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), - Threshold: xrand.Int(9999), - } - - acceptCopts := &websocket.CompressionOptions{ - Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), - Threshold: xrand.Int(9999), - } - - c1, c2 := tt.pipe(&websocket.DialOptions{ - CompressionOptions: dialCopts, + tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ + CompressionOptions: copts(), }, &websocket.AcceptOptions{ - CompressionOptions: acceptCopts, + CompressionOptions: copts(), }) + defer tt.done() tt.goEchoLoop(c2) @@ -61,60 +56,53 @@ func TestConn(t *testing.T) { for i := 0; i < 5; i++ { err := wstest.Echo(tt.ctx, c1, 131072) - tt.success(err) + assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") - tt.success(err) + assert.Success(t, err) }) } }) t.Run("badClose", func(t *testing.T) { - tt := newTest(t) + tt, c1, _ := newConnTest(t, nil, nil) defer tt.done() - c1, _ := tt.pipe(nil, nil) - err := c1.Close(-1, "") - tt.errContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") + assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") }) t.Run("ping", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) - c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) for i := 0; i < 10; i++ { err := c1.Ping(tt.ctx) - tt.success(err) + assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") - tt.success(err) + assert.Success(t, err) }) t.Run("badPing", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) - c2.CloseRead(tt.ctx) err := c1.Ping(tt.ctx) - tt.errContains(err, "failed to wait for pong") + assert.Contains(t, err, "failed to wait for pong") }) t.Run("concurrentWrite", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) tt.goDiscardLoop(c2) msg := xrand.Bytes(xrand.Int(9999)) @@ -129,35 +117,31 @@ func TestConn(t *testing.T) { for i := 0; i < count; i++ { err := <-errs - tt.success(err) + assert.Success(t, err) } err := c1.Close(websocket.StatusNormalClosure, "") - tt.success(err) + assert.Success(t, err) }) t.Run("concurrentWriteError", func(t *testing.T) { - tt := newTest(t) + tt, c1, _ := newConnTest(t, nil, nil) defer tt.done() - c1, _ := tt.pipe(nil, nil) - _, err := c1.Writer(tt.ctx, websocket.MessageText) - tt.success(err) + assert.Success(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) defer cancel() err = c1.Write(ctx, websocket.MessageText, []byte("x")) - tt.eq(context.DeadlineExceeded, err) + assert.Equal(t, "write error", context.DeadlineExceeded, err) }) t.Run("netConn", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) - n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) @@ -166,9 +150,9 @@ func TestConn(t *testing.T) { n1.SetDeadline(d) n1.SetDeadline(time.Time{}) - tt.eq(n1.RemoteAddr(), n1.LocalAddr()) - tt.eq("websocket/unknown-addr", n1.RemoteAddr().String()) - tt.eq("websocket", n1.RemoteAddr().Network()) + assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr()) + assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String()) + assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network()) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) @@ -179,23 +163,21 @@ func TestConn(t *testing.T) { }) b, err := ioutil.ReadAll(n1) - tt.success(err) + assert.Success(t, err) _, err = n1.Read(nil) - tt.eq(err, io.EOF) + assert.Equal(t, "read error", err, io.EOF) err = <-errs - tt.success(err) + assert.Success(t, err) - tt.eq([]byte("hello"), b) + assert.Equal(t, "read msg", []byte("hello"), b) }) t.Run("netConn/BadMsg", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) - n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) @@ -208,18 +190,16 @@ func TestConn(t *testing.T) { }) _, err := ioutil.ReadAll(n1) - tt.errContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) + assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`) err = <-errs - tt.success(err) + assert.Success(t, err) }) t.Run("wsjson", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) - tt.goEchoLoop(c2) c1.SetReadLimit(1 << 30) @@ -232,35 +212,33 @@ func TestConn(t *testing.T) { var act interface{} err := wsjson.Read(tt.ctx, c1, &act) - tt.success(err) - tt.eq(exp, act) + assert.Success(t, err) + assert.Equal(t, "read msg", exp, act) err = <-werr - tt.success(err) + assert.Success(t, err) err = c1.Close(websocket.StatusNormalClosure, "") - tt.success(err) + assert.Success(t, err) }) t.Run("wspb", func(t *testing.T) { - tt := newTest(t) + tt, c1, c2 := newConnTest(t, nil, nil) defer tt.done() - c1, c2 := tt.pipe(nil, nil) - tt.goEchoLoop(c2) exp := ptypes.DurationProto(100) err := wspb.Write(tt.ctx, c1, exp) - tt.success(err) + assert.Success(t, err) act := &duration.Duration{} err = wspb.Read(tt.ctx, c1, act) - tt.success(err) - tt.eq(exp, act) + assert.Success(t, err) + assert.Equal(t, "read msg", exp, act) err = c1.Close(websocket.StatusNormalClosure, "") - tt.success(err) + assert.Success(t, err) }) } @@ -277,14 +255,17 @@ func TestWasm(t *testing.T) { InsecureSkipVerify: true, }) if err != nil { - t.Error(err) + t.Errorf("echo server failed: %v", err) return } defer c.Close(websocket.StatusInternalError, "") err = wstest.EchoLoop(r.Context(), c) - if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - t.Errorf("echoLoop failed: %v", err) + + err = assertCloseStatus(websocket.StatusNormalClosure, err) + if err != nil { + t.Errorf("echo server failed: %v", err) + return } })) defer wg.Wait() @@ -307,38 +288,47 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) } if websocket.CloseStatus(err) != exp { - return xerrors.Errorf("unexpected close status (%v):%v", exp, err) + return xerrors.Errorf("expected close status %v but got ", exp, err) } return nil } -type test struct { +type connTest struct { t *testing.T ctx context.Context doneFuncs []func() } -func newTest(t *testing.T) *test { +func newConnTest(t *testing.T, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { t.Parallel() + t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - tt := &test{t: t, ctx: ctx} + tt = &connTest{t: t, ctx: ctx} tt.appendDone(cancel) - return tt + + c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) + assert.Success(tt.t, err) + tt.appendDone(func() { + c2.Close(websocket.StatusInternalError, "") + c1.Close(websocket.StatusInternalError, "") + }) + + return tt, c1, c2 } -func (tt *test) appendDone(f func()) { +func (tt *connTest) appendDone(f func()) { tt.doneFuncs = append(tt.doneFuncs, f) } -func (tt *test) done() { +func (tt *connTest) done() { for i := len(tt.doneFuncs) - 1; i >= 0; i-- { tt.doneFuncs[i]() } } -func (tt *test) goEchoLoop(c *websocket.Conn) { +func (tt *connTest) goEchoLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) echoLoopErr := xsync.Go(func() error { @@ -354,7 +344,7 @@ func (tt *test) goEchoLoop(c *websocket.Conn) { }) } -func (tt *test) goDiscardLoop(c *websocket.Conn) { +func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { @@ -376,38 +366,3 @@ func (tt *test) goDiscardLoop(c *websocket.Conn) { } }) } - -func (tt *test) pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (c1, c2 *websocket.Conn) { - tt.t.Helper() - - c1, c2, err := wstest.Pipe(dialOpts, acceptOpts) - if err != nil { - tt.t.Fatal(err) - } - tt.appendDone(func() { - c2.Close(websocket.StatusInternalError, "") - c1.Close(websocket.StatusInternalError, "") - }) - return c1, c2 -} - -func (tt *test) success(err error) { - tt.t.Helper() - if err != nil { - tt.t.Fatal(err) - } -} - -func (tt *test) errContains(err error, sub string) { - tt.t.Helper() - if !cmp.ErrorContains(err, sub) { - tt.t.Fatalf("error does not contain %q: %v", sub, err) - } -} - -func (tt *test) eq(exp, act interface{}) { - tt.t.Helper() - if !cmp.Equal(exp, act) { - tt.t.Fatalf(cmp.Diff(exp, act)) - } -} diff --git a/dial_test.go b/dial_test.go index c4657415..06084cc5 100644 --- a/dial_test.go +++ b/dial_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" ) func TestBadDials(t *testing.T) { @@ -70,9 +70,7 @@ func TestBadDials(t *testing.T) { } _, _, err := dial(ctx, tc.url, tc.opts, tc.rand) - if err == nil { - t.Fatalf("expected error") - } + assert.Error(t, err) }) } }) @@ -90,9 +88,7 @@ func TestBadDials(t *testing.T) { }, nil }), }) - if !cmp.ErrorContains(err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") { - t.Fatal(err) - } + assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0") }) t.Run("badBody", func(t *testing.T) { @@ -117,9 +113,7 @@ func TestBadDials(t *testing.T) { _, _, err := Dial(ctx, "ws://example.com", &DialOptions{ HTTPClient: mockHTTPClient(rt), }) - if !cmp.ErrorContains(err, "response body is not a io.ReadWriteCloser") { - t.Fatal(err) - } + assert.Contains(t, err, "response body is not a io.ReadWriteCloser") }) } @@ -217,9 +211,7 @@ func Test_verifyServerHandshake(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) key, err := secWebSocketKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) r.Header.Set("Sec-WebSocket-Key", key) if resp.Header.Get("Sec-WebSocket-Accept") == "" { @@ -230,8 +222,10 @@ func Test_verifyServerHandshake(t *testing.T) { Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","), } _, err = verifyServerResponse(opts, key, resp) - if (err == nil) != tc.success { - t.Fatalf("unexpected error: %v", err) + if tc.success { + assert.Success(t, err) + } else { + assert.Error(t, err) } }) } diff --git a/frame_test.go b/frame_test.go index 0b770a4c..8745da0b 100644 --- a/frame_test.go +++ b/frame_test.go @@ -16,7 +16,7 @@ import ( "github.com/gobwas/ws" _ "github.com/gorilla/websocket" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" ) func TestHeader(t *testing.T) { @@ -81,22 +81,15 @@ func testHeader(t *testing.T, h header) { r := bufio.NewReader(b) err := writeFrameHeader(h, w) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) + err = w.Flush() - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) h2, err := readFrameHeader(r) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) - if !cmp.Equal(h, h2) { - t.Fatal(cmp.Diff(h, h2)) - } + assert.Equal(t, "read header", h, h2) } func Test_mask(t *testing.T) { @@ -108,14 +101,10 @@ func Test_mask(t *testing.T) { gotKey32 := mask(key32, p) expP := []byte{0, 0, 0, 0x0d, 0x6} - if !cmp.Equal(expP, p) { - t.Fatal(cmp.Diff(expP, p)) - } + assert.Equal(t, "p", expP, p) expKey32 := bits.RotateLeft32(key32, -8) - if !cmp.Equal(expKey32, gotKey32) { - t.Fatal(cmp.Diff(expKey32, gotKey32)) - } + assert.Equal(t, "key32", expKey32, gotKey32) } func basicMask(maskKey [4]byte, pos int, b []byte) int { diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go new file mode 100644 index 00000000..2bc01dba --- /dev/null +++ b/internal/test/assert/assert.go @@ -0,0 +1,46 @@ +package assert + +import ( + "fmt" + "strings" + "testing" + + "nhooyr.io/websocket/internal/test/cmp" +) + +// Equal asserts exp == act. +func Equal(t testing.TB, name string, exp, act interface{}) { + t.Helper() + + if diff := cmp.Diff(exp, act); diff != "" { + t.Fatalf("unexpected %v: %v", name, diff) + } +} + +// Success asserts err == nil. +func Success(t testing.TB, err error) { + t.Helper() + + if err != nil { + t.Fatal(err) + } +} + +// Error asserts err != nil. +func Error(t testing.TB, err error) { + t.Helper() + + if err == nil { + t.Fatal("expected error") + } +} + +// Contains asserts the fmt.Sprint(v) contains sub. +func Contains(t testing.TB, v interface{}, sub string) { + t.Helper() + + vstr := fmt.Sprint(v) + if !strings.Contains(vstr, sub) { + t.Fatalf("expected %q to contain %q", vstr, sub) + } +} diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go index 6f3dd706..eadcb5d9 100644 --- a/internal/test/cmp/cmp.go +++ b/internal/test/cmp/cmp.go @@ -2,31 +2,15 @@ package cmp import ( "reflect" - "strings" "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" ) -// Equal checks if v1 and v2 are equal with go-cmp. -func Equal(v1, v2 interface{}) bool { - return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { - return true - }), cmp.Comparer(proto.Equal)) -} - // Diff returns a human readable diff between v1 and v2 func Diff(v1, v2 interface{}) string { return cmp.Diff(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool { return true - })) -} - -// ErrorContains returns whether err.Error() contains sub. -func ErrorContains(err error, sub string) bool { - if err == nil { - return false - } - return strings.Contains(err.Error(), sub) + }), cmp.Comparer(proto.Equal)) } diff --git a/internal/xsync/go_test.go b/internal/xsync/go_test.go index c0613e64..dabea8a5 100644 --- a/internal/xsync/go_test.go +++ b/internal/xsync/go_test.go @@ -3,7 +3,7 @@ package xsync import ( "testing" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" ) func TestGoRecover(t *testing.T) { @@ -14,7 +14,5 @@ func TestGoRecover(t *testing.T) { }) err := <-errs - if !cmp.ErrorContains(err, "anmol") { - t.Fatalf("unexpected err: %v", err) - } + assert.Contains(t, err, "anmol") } diff --git a/ws_js_test.go b/ws_js_test.go index 8671dd21..e6be6181 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -8,7 +8,7 @@ import ( "time" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/assert" "nhooyr.io/websocket/internal/test/wstest" ) @@ -21,28 +21,18 @@ func TestWasm(t *testing.T) { c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") - if !cmp.Equal("echo", c.Subprotocol()) { - t.Fatalf("unexpected subprotocol: %v", cmp.Diff("echo", c.Subprotocol())) - } - if !cmp.Equal(http.StatusSwitchingProtocols, resp.StatusCode) { - t.Fatalf("unexpected status code: %v", cmp.Diff(http.StatusSwitchingProtocols, resp.StatusCode)) - } + assert.Equal(t, "subprotocol", "echo", c.Subprotocol()) + assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode) c.SetReadLimit(65536) for i := 0; i < 10; i++ { err = wstest.Echo(ctx, c, 65536) - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) } err = c.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatal(err) - } + assert.Success(t, err) } From c5b0a009c19d7240f17ae3857741e8904deadeb6 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 13 Feb 2020 01:34:32 -0500 Subject: [PATCH 49/55] Fix badPing test duration --- conn_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/conn_test.go b/conn_test.go index f9b52f22..e1e6c35c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -95,7 +95,10 @@ func TestConn(t *testing.T) { c2.CloseRead(tt.ctx) - err := c1.Ping(tt.ctx) + ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100) + defer cancel() + + err := c1.Ping(ctx) assert.Contains(t, err, "failed to wait for pong") }) From 1c7c14ea4a79ee48a621e419aacff447dec3c8bf Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 13 Feb 2020 01:46:17 -0500 Subject: [PATCH 50/55] Pool sliding windows --- compress_notjs.go | 17 ++++++++++++++++- read.go | 4 ++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/compress_notjs.go b/compress_notjs.go index 6ab6e284..3f0d8b9f 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -115,16 +115,31 @@ func putFlateWriter(w *flate.Writer) { } type slidingWindow struct { - r io.Reader buf []byte } +var swPool = map[int]*sync.Pool{} + func newSlidingWindow(n int) *slidingWindow { + p, ok := swPool[n] + if !ok { + p = &sync.Pool{} + swPool[n] = p + } + sw, ok := p.Get().(*slidingWindow) + if ok { + return sw + } return &slidingWindow{ buf: make([]byte, 0, n), } } +func returnSlidingWindow(sw *slidingWindow) { + sw.buf = sw.buf[:0] + swPool[cap(sw.buf)].Put(sw) +} + func (w *slidingWindow) write(p []byte) { if len(p) >= cap(w.buf) { w.buf = w.buf[:cap(w.buf)] diff --git a/read.go b/read.go index a9c291d1..49c03b40 100644 --- a/read.go +++ b/read.go @@ -110,6 +110,10 @@ func (mr *msgReader) returnFlateReader() { func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) mr.returnFlateReader() + + if mr.dict != nil { + returnSlidingWindow(mr.dict) + } } func (mr *msgReader) flateContextTakeover() bool { From 503b4696fcbad5c2c18e364fcc31540a7c5e43e9 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 13 Feb 2020 01:57:19 -0500 Subject: [PATCH 51/55] Simplify sliding window API --- compress_notjs.go | 43 +++++++++++++++++++++++++------------------ compress_test.go | 11 ++++++----- conn_test.go | 7 +++---- read.go | 16 +++++----------- 4 files changed, 39 insertions(+), 38 deletions(-) diff --git a/compress_notjs.go b/compress_notjs.go index 3f0d8b9f..20761362 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -120,41 +120,48 @@ type slidingWindow struct { var swPool = map[int]*sync.Pool{} -func newSlidingWindow(n int) *slidingWindow { +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + p, ok := swPool[n] if !ok { p = &sync.Pool{} swPool[n] = p } - sw, ok := p.Get().(*slidingWindow) + buf, ok := p.Get().([]byte) if ok { - return sw - } - return &slidingWindow{ - buf: make([]byte, 0, n), + sw.buf = buf[:0] + } else { + sw.buf = make([]byte, 0, n) } } -func returnSlidingWindow(sw *slidingWindow) { - sw.buf = sw.buf[:0] - swPool[cap(sw.buf)].Put(sw) +func (sw *slidingWindow) close() { + if sw.buf == nil { + return + } + + swPool[cap(sw.buf)].Put(sw.buf) + sw.buf = nil } -func (w *slidingWindow) write(p []byte) { - if len(p) >= cap(w.buf) { - w.buf = w.buf[:cap(w.buf)] - p = p[len(p)-cap(w.buf):] - copy(w.buf, p) +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) return } - left := cap(w.buf) - len(w.buf) + left := cap(sw.buf) - len(sw.buf) if left < len(p) { // We need to shift spaceNeeded bytes from the end to make room for p at the end. spaceNeeded := len(p) - left - copy(w.buf, w.buf[spaceNeeded:]) - w.buf = w.buf[:len(w.buf)-spaceNeeded] + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] } - w.buf = append(w.buf, p...) + sw.buf = append(sw.buf, p...) } diff --git a/compress_test.go b/compress_test.go index 364d542d..2c4c896c 100644 --- a/compress_test.go +++ b/compress_test.go @@ -21,12 +21,13 @@ func Test_slidingWindow(t *testing.T) { input := xrand.String(maxWindow) windowLength := xrand.Int(maxWindow) - r := newSlidingWindow(windowLength) - r.write([]byte(input)) + var sw slidingWindow + sw.init(windowLength) + sw.write([]byte(input)) - assert.Equal(t, "window length", windowLength, cap(r.buf)) - if !strings.HasSuffix(input, string(r.buf)) { - t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf) + assert.Equal(t, "window length", windowLength, cap(sw.buf)) + if !strings.HasSuffix(input, string(sw.buf)) { + t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf) } }) } diff --git a/conn_test.go b/conn_test.go index e1e6c35c..25b0809d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -351,13 +351,12 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) discardLoopErr := xsync.Go(func() error { + defer c.Close(websocket.StatusInternalError, "") + for { _, _, err := c.Read(ctx) - if websocket.CloseStatus(err) == websocket.StatusNormalClosure { - return nil - } if err != nil { - return err + return assertCloseStatus(websocket.StatusNormalClosure, err) } } }) diff --git a/read.go b/read.go index 49c03b40..dd73ac92 100644 --- a/read.go +++ b/read.go @@ -87,15 +87,11 @@ func newMsgReader(c *Conn) *msgReader { } func (mr *msgReader) resetFlate() { - if mr.flateContextTakeover() && mr.dict == nil { - mr.dict = newSlidingWindow(32768) - } - if mr.flateContextTakeover() { - mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) - } else { - mr.flateReader = getFlateReader(readerFunc(mr.read), nil) + mr.dict.init(32768) } + + mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } @@ -111,9 +107,7 @@ func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) mr.returnFlateReader() - if mr.dict != nil { - returnSlidingWindow(mr.dict) - } + mr.dict.close() } func (mr *msgReader) flateContextTakeover() bool { @@ -325,7 +319,7 @@ type msgReader struct { flateReader io.Reader flateTail strings.Reader limitReader *limitReader - dict *slidingWindow + dict slidingWindow fin bool payloadLength int64 From dff4af3cbf8ae30e4961ff2a6b32e46344c1b424 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Thu, 13 Feb 2020 02:54:18 -0500 Subject: [PATCH 52/55] Add conn benchmark --- .gitignore | 1 + autobahn_test.go | 2 - compress_notjs.go | 7 +++ conn_notjs.go | 1 + conn_test.go | 124 ++++++++++++++++++++++++++++++++++++++----- frame.go | 13 +++-- frame_test.go | 3 +- internal/xsync/go.go | 2 +- read.go | 6 ++- 9 files changed, 132 insertions(+), 27 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..6961e5c8 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +websocket.test diff --git a/autobahn_test.go b/autobahn_test.go index 0763bc97..fb24a06b 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -59,8 +59,6 @@ func TestAutobahn(t *testing.T) { for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { - t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() diff --git a/compress_notjs.go b/compress_notjs.go index 20761362..7c6b2fc0 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -118,6 +118,7 @@ type slidingWindow struct { buf []byte } +var swPoolMu sync.Mutex var swPool = map[int]*sync.Pool{} func (sw *slidingWindow) init(n int) { @@ -125,6 +126,9 @@ func (sw *slidingWindow) init(n int) { return } + swPoolMu.Lock() + defer swPoolMu.Unlock() + p, ok := swPool[n] if !ok { p = &sync.Pool{} @@ -143,6 +147,9 @@ func (sw *slidingWindow) close() { return } + swPoolMu.Lock() + defer swPoolMu.Unlock() + swPool[cap(sw.buf)].Put(sw.buf) sw.buf = nil } diff --git a/conn_notjs.go b/conn_notjs.go index 4d8762bf..178fcad0 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -39,6 +39,7 @@ type Conn struct { // Read state. readMu *mu + readHeader header readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error diff --git a/conn_test.go b/conn_test.go index 25b0809d..265156e9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,9 @@ package websocket_test import ( + "bytes" "context" + "crypto/rand" "fmt" "io" "io/ioutil" @@ -48,7 +50,7 @@ func TestConn(t *testing.T) { }, &websocket.AcceptOptions{ CompressionOptions: copts(), }) - defer tt.done() + defer tt.cleanup() tt.goEchoLoop(c2) @@ -67,7 +69,7 @@ func TestConn(t *testing.T) { t.Run("badClose", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") @@ -75,7 +77,7 @@ func TestConn(t *testing.T) { t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) @@ -91,7 +93,7 @@ func TestConn(t *testing.T) { t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() c2.CloseRead(tt.ctx) @@ -104,7 +106,7 @@ func TestConn(t *testing.T) { t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() tt.goDiscardLoop(c2) @@ -129,7 +131,7 @@ func TestConn(t *testing.T) { t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) @@ -143,7 +145,7 @@ func TestConn(t *testing.T) { t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) @@ -179,7 +181,7 @@ func TestConn(t *testing.T) { t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) @@ -201,7 +203,7 @@ func TestConn(t *testing.T) { t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() tt.goEchoLoop(c2) @@ -227,7 +229,7 @@ func TestConn(t *testing.T) { t.Run("wspb", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.done() + defer tt.cleanup() tt.goEchoLoop(c2) @@ -297,14 +299,16 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { } type connTest struct { - t *testing.T + t testing.TB ctx context.Context doneFuncs []func() } -func newConnTest(t *testing.T, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { - t.Parallel() +func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { + if t, ok := t.(*testing.T); ok { + t.Parallel() + } t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) @@ -325,7 +329,7 @@ func (tt *connTest) appendDone(f func()) { tt.doneFuncs = append(tt.doneFuncs, f) } -func (tt *connTest) done() { +func (tt *connTest) cleanup() { for i := len(tt.doneFuncs) - 1; i >= 0; i-- { tt.doneFuncs[i]() } @@ -368,3 +372,95 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } }) } + +func BenchmarkConn(b *testing.B) { + var benchCases = []struct { + name string + mode websocket.CompressionMode + }{ + { + name: "compressionDisabled", + mode: websocket.CompressionDisabled, + }, + { + name: "compression", + mode: websocket.CompressionContextTakeover, + }, + { + name: "noContextCompression", + mode: websocket.CompressionNoContextTakeover, + }, + } + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ + CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, + }, nil) + defer bb.cleanup() + + bb.goEchoLoop(c2) + + const n = 32768 + writeBuf := make([]byte, n) + readBuf := make([]byte, n) + writes := make(chan websocket.MessageType) + defer close(writes) + werrs := make(chan error) + + go func() { + for typ := range writes { + werrs <- c1.Write(bb.ctx, typ, writeBuf) + } + }() + b.SetBytes(n) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rand.Reader.Read(writeBuf) + if err != nil { + b.Fatal(err) + } + + expType := websocket.MessageBinary + if writeBuf[0]%2 == 1 { + expType = websocket.MessageText + } + writes <- expType + + typ, r, err := c1.Reader(bb.ctx) + if err != nil { + b.Fatal(err) + } + if expType != typ { + assert.Equal(b, "data type", expType, typ) + } + + _, err = io.ReadFull(r, readBuf) + if err != nil { + b.Fatal(err) + } + + n2, err := r.Read(readBuf) + if err != io.EOF { + assert.Equal(b, "read err", io.EOF, err) + } + if n2 != 0 { + assert.Equal(b, "n2", 0, n2) + } + + if !bytes.Equal(writeBuf, readBuf) { + assert.Equal(b, "msg", writeBuf, readBuf) + } + + err = <-werrs + if err != nil { + b.Fatal(err) + } + } + b.StopTimer() + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(b, err) + }) + } +} diff --git a/frame.go b/frame.go index 0257835e..491ae75c 100644 --- a/frame.go +++ b/frame.go @@ -46,15 +46,14 @@ type header struct { // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. -func readFrameHeader(r *bufio.Reader) (_ header, err error) { +func readFrameHeader(h *header, r *bufio.Reader) (err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { - return header{}, err + return err } - var h header h.fin = b&(1<<7) != 0 h.rsv1 = b&(1<<6) != 0 h.rsv2 = b&(1<<5) != 0 @@ -64,7 +63,7 @@ func readFrameHeader(r *bufio.Reader) (_ header, err error) { b, err = r.ReadByte() if err != nil { - return header{}, err + return err } h.masked = b&(1<<7) != 0 @@ -81,17 +80,17 @@ func readFrameHeader(r *bufio.Reader) (_ header, err error) { err = binary.Read(r, binary.BigEndian, &h.payloadLength) } if err != nil { - return header{}, err + return err } if h.masked { err = binary.Read(r, binary.LittleEndian, &h.maskKey) if err != nil { - return header{}, err + return err } } - return h, nil + return nil } // maxControlPayload is the maximum length of a control frame payload. diff --git a/frame_test.go b/frame_test.go index 8745da0b..38f1599a 100644 --- a/frame_test.go +++ b/frame_test.go @@ -86,7 +86,8 @@ func testHeader(t *testing.T, h header) { err = w.Flush() assert.Success(t, err) - h2, err := readFrameHeader(r) + var h2 header + err = readFrameHeader(&h2, r) assert.Success(t, err) assert.Equal(t, "read header", h, h2) diff --git a/internal/xsync/go.go b/internal/xsync/go.go index 96cf8103..d88ac622 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -6,7 +6,7 @@ import ( // Go allows running a function in another goroutine // and waiting for its error. -func Go(fn func() error) chan error { +func Go(fn func() error) <- chan error { errs := make(chan error, 1) go func() { defer func() { diff --git a/read.go b/read.go index dd73ac92..bf7fa6d9 100644 --- a/read.go +++ b/read.go @@ -173,7 +173,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- ctx: } - h, err := readFrameHeader(c.br) + err := readFrameHeader(&c.readHeader, c.br) if err != nil { select { case <-c.closed: @@ -192,7 +192,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- context.Background(): } - return h, nil + return c.readHeader, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { @@ -390,6 +390,8 @@ func (mr *msgReader) read(p []byte) (int, error) { return 0, err } mr.setFrame(h) + + return mr.read(p) } if int64(len(p)) > mr.payloadLength { From 2377cca1760dfd3ee74cd945b775aef44b98ebb9 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 15 Feb 2020 15:54:25 -0500 Subject: [PATCH 53/55] Switch to klauspost/compress --- compress_notjs.go | 44 +++++++++++++----- conn_notjs.go | 17 +++---- conn_test.go | 49 +++++++++----------- frame.go | 41 ++++++++++------ frame_test.go | 5 +- go.mod | 1 + go.sum | 2 + internal/xsync/go.go | 2 +- read.go | 93 +++++++++++++++++++++---------------- write.go | 108 ++++++++++++++++++++++++++----------------- 10 files changed, 215 insertions(+), 147 deletions(-) diff --git a/compress_notjs.go b/compress_notjs.go index 7c6b2fc0..a61b7ba4 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -3,10 +3,11 @@ package websocket import ( - "compress/flate" "io" "net/http" "sync" + + "github.com/klauspost/compress/flate" ) func (m CompressionMode) opts() *compressionOptions { @@ -45,10 +46,16 @@ type trimLastFourBytesWriter struct { } func (tw *trimLastFourBytesWriter) reset() { - tw.tail = tw.tail[:0] + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } } func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + extra := len(tw.tail) + len(p) - 4 if extra <= 0 { @@ -65,7 +72,10 @@ func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { if err != nil { return 0, err } - tw.tail = tw.tail[extra:] + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] } // If p is less than or equal to 4 bytes, @@ -118,22 +128,32 @@ type slidingWindow struct { buf []byte } -var swPoolMu sync.Mutex +var swPoolMu sync.RWMutex var swPool = map[int]*sync.Pool{} -func (sw *slidingWindow) init(n int) { - if sw.buf != nil { - return +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p } + p = &sync.Pool{} + swPoolMu.Lock() - defer swPoolMu.Unlock() + swPool[n] = p + swPoolMu.Unlock() - p, ok := swPool[n] - if !ok { - p = &sync.Pool{} - swPool[n] = p + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return } + + p := slidingWindowPool(n) buf, ok := p.Get().([]byte) if ok { sw.buf = buf[:0] diff --git a/conn_notjs.go b/conn_notjs.go index 178fcad0..e6ff7df3 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -39,16 +39,17 @@ type Conn struct { // Read state. readMu *mu - readHeader header + readHeaderBuf [8]byte readControlBuf [maxControlPayload]byte msgReader *msgReader readCloseFrameErr error // Write state. - msgWriter *msgWriter - writeFrameMu *mu - writeBuf []byte - writeHeader header + msgWriterState *msgWriterState + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header closed chan struct{} closeMu sync.Mutex @@ -94,14 +95,14 @@ func newConn(cfg connConfig) *Conn { c.msgReader = newMsgReader(c) - c.msgWriter = newMsgWriter(c) + c.msgWriterState = newMsgWriterState(c) if c.client { c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) } if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 256 - if !c.msgWriter.flateContextTakeover() { + if !c.msgWriterState.flateContextTakeover() { c.flateThreshold = 512 } } @@ -142,7 +143,7 @@ func (c *Conn) close(err error) { c.writeFrameMu.Lock(context.Background()) putBufioWriter(c.bw) } - c.msgWriter.close() + c.msgWriterState.close() c.msgReader.close() if c.client { diff --git a/conn_test.go b/conn_test.go index 265156e9..398ffd51 100644 --- a/conn_test.go +++ b/conn_test.go @@ -5,7 +5,6 @@ package websocket_test import ( "bytes" "context" - "crypto/rand" "fmt" "io" "io/ioutil" @@ -13,6 +12,7 @@ import ( "net/http/httptest" "os" "os/exec" + "strings" "sync" "testing" "time" @@ -379,15 +379,15 @@ func BenchmarkConn(b *testing.B) { mode websocket.CompressionMode }{ { - name: "compressionDisabled", + name: "disabledCompress", mode: websocket.CompressionDisabled, }, { - name: "compression", + name: "compress", mode: websocket.CompressionContextTakeover, }, { - name: "noContextCompression", + name: "compressNoContext", mode: websocket.CompressionNoContextTakeover, }, } @@ -395,44 +395,36 @@ func BenchmarkConn(b *testing.B) { b.Run(bc.name, func(b *testing.B) { bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, - }, nil) + }, &websocket.AcceptOptions{ + CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, + }) defer bb.cleanup() bb.goEchoLoop(c2) - const n = 32768 - writeBuf := make([]byte, n) - readBuf := make([]byte, n) - writes := make(chan websocket.MessageType) + msg := []byte(strings.Repeat("1234", 128)) + readBuf := make([]byte, len(msg)) + writes := make(chan struct{}) defer close(writes) werrs := make(chan error) go func() { - for typ := range writes { - werrs <- c1.Write(bb.ctx, typ, writeBuf) + for range writes { + werrs <- c1.Write(bb.ctx, websocket.MessageText, msg) } }() - b.SetBytes(n) + b.SetBytes(int64(len(msg))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := rand.Reader.Read(writeBuf) - if err != nil { - b.Fatal(err) - } - - expType := websocket.MessageBinary - if writeBuf[0]%2 == 1 { - expType = websocket.MessageText - } - writes <- expType + writes <- struct{}{} typ, r, err := c1.Reader(bb.ctx) if err != nil { b.Fatal(err) } - if expType != typ { - assert.Equal(b, "data type", expType, typ) + if websocket.MessageText != typ { + assert.Equal(b, "data type", websocket.MessageText, typ) } _, err = io.ReadFull(r, readBuf) @@ -448,8 +440,8 @@ func BenchmarkConn(b *testing.B) { assert.Equal(b, "n2", 0, n2) } - if !bytes.Equal(writeBuf, readBuf) { - assert.Equal(b, "msg", writeBuf, readBuf) + if !bytes.Equal(msg, readBuf) { + assert.Equal(b, "msg", msg, readBuf) } err = <-werrs @@ -464,3 +456,8 @@ func BenchmarkConn(b *testing.B) { }) } } + +func TestCompression(t *testing.T) { + t.Parallel() + +} diff --git a/frame.go b/frame.go index 491ae75c..4acaecf4 100644 --- a/frame.go +++ b/frame.go @@ -3,9 +3,12 @@ package websocket import ( "bufio" "encoding/binary" + "io" "math" "math/bits" + "golang.org/x/xerrors" + "nhooyr.io/websocket/internal/errd" ) @@ -46,12 +49,12 @@ type header struct { // readFrameHeader reads a header from the reader. // See https://tools.ietf.org/html/rfc6455#section-5.2. -func readFrameHeader(h *header, r *bufio.Reader) (err error) { +func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { defer errd.Wrap(&err, "failed to read frame header") b, err := r.ReadByte() if err != nil { - return err + return header{}, err } h.fin = b&(1<<7) != 0 @@ -63,7 +66,7 @@ func readFrameHeader(h *header, r *bufio.Reader) (err error) { b, err = r.ReadByte() if err != nil { - return err + return header{}, err } h.masked = b&(1<<7) != 0 @@ -73,24 +76,29 @@ func readFrameHeader(h *header, r *bufio.Reader) (err error) { case payloadLength < 126: h.payloadLength = int64(payloadLength) case payloadLength == 126: - var pl uint16 - err = binary.Read(r, binary.BigEndian, &pl) - h.payloadLength = int64(pl) + _, err = io.ReadFull(r, readBuf[:2]) + h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) case payloadLength == 127: - err = binary.Read(r, binary.BigEndian, &h.payloadLength) + _, err = io.ReadFull(r, readBuf) + h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) } if err != nil { - return err + return header{}, err + } + + if h.payloadLength < 0 { + return header{}, xerrors.Errorf("received negative payload length: %v", h.payloadLength) } if h.masked { - err = binary.Read(r, binary.LittleEndian, &h.maskKey) + _, err = io.ReadFull(r, readBuf[:4]) if err != nil { - return err + return header{}, err } + h.maskKey = binary.LittleEndian.Uint32(readBuf) } - return nil + return h, nil } // maxControlPayload is the maximum length of a control frame payload. @@ -99,7 +107,7 @@ const maxControlPayload = 125 // writeFrameHeader writes the bytes of the header to w. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func writeFrameHeader(h header, w *bufio.Writer) (err error) { +func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { defer errd.Wrap(&err, "failed to write frame header") var b byte @@ -143,16 +151,19 @@ func writeFrameHeader(h header, w *bufio.Writer) (err error) { switch { case h.payloadLength > math.MaxUint16: - err = binary.Write(w, binary.BigEndian, h.payloadLength) + binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) + _, err = w.Write(buf) case h.payloadLength > 125: - err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength)) + binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) + _, err = w.Write(buf[:2]) } if err != nil { return err } if h.masked { - err = binary.Write(w, binary.LittleEndian, h.maskKey) + binary.LittleEndian.PutUint32(buf, h.maskKey) + _, err = w.Write(buf[:4]) if err != nil { return err } diff --git a/frame_test.go b/frame_test.go index 38f1599a..76826248 100644 --- a/frame_test.go +++ b/frame_test.go @@ -80,14 +80,13 @@ func testHeader(t *testing.T, h header) { w := bufio.NewWriter(b) r := bufio.NewReader(b) - err := writeFrameHeader(h, w) + err := writeFrameHeader(h, w, make([]byte, 8)) assert.Success(t, err) err = w.Flush() assert.Success(t, err) - var h2 header - err = readFrameHeader(&h2, r) + h2, err := readFrameHeader(r, make([]byte, 8)) assert.Success(t, err) assert.Equal(t, "read header", h, h2) diff --git a/go.mod b/go.mod index cb372391..a10c7b1e 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/golang/protobuf v1.3.3 github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 + github.com/klauspost/compress v1.10.0 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 ) diff --git a/go.sum b/go.sum index 8cbc66ce..e4bbd62d 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/klauspost/compress v1.10.0 h1:92XGj1AcYzA6UrVdd4qIIBrT8OroryvRvdmg/IfmC7Y= +github.com/klauspost/compress v1.10.0/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/internal/xsync/go.go b/internal/xsync/go.go index d88ac622..712739aa 100644 --- a/internal/xsync/go.go +++ b/internal/xsync/go.go @@ -6,7 +6,7 @@ import ( // Go allows running a function in another goroutine // and waiting for its error. -func Go(fn func() error) <- chan error { +func Go(fn func() error) <-chan error { errs := make(chan error, 1) go func() { defer func() { diff --git a/read.go b/read.go index bf7fa6d9..bbad30d1 100644 --- a/read.go +++ b/read.go @@ -3,6 +3,7 @@ package websocket import ( + "bufio" "context" "io" "io/ioutil" @@ -81,8 +82,9 @@ func newMsgReader(c *Conn) *msgReader { c: c, fin: true, } + mr.readFunc = mr.read - mr.limitReader = newLimitReader(c, readerFunc(mr.read), defaultReadLimit+1) + mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) return mr } @@ -90,13 +92,16 @@ func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() { mr.dict.init(32768) } + if mr.flateBufio == nil { + mr.flateBufio = getBufioReader(mr.readFunc) + } - mr.flateReader = getFlateReader(readerFunc(mr.read), mr.dict.buf) + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) mr.limitReader.r = mr.flateReader mr.flateTail.Reset(deflateMessageTail) } -func (mr *msgReader) returnFlateReader() { +func (mr *msgReader) putFlateReader() { if mr.flateReader != nil { putFlateReader(mr.flateReader) mr.flateReader = nil @@ -105,9 +110,11 @@ func (mr *msgReader) returnFlateReader() { func (mr *msgReader) close() { mr.c.readMu.Lock(context.Background()) - mr.returnFlateReader() - + mr.putFlateReader() mr.dict.close() + if mr.flateBufio != nil { + putBufioReader(mr.flateBufio) + } } func (mr *msgReader) flateContextTakeover() bool { @@ -173,7 +180,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- ctx: } - err := readFrameHeader(&c.readHeader, c.br) + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) if err != nil { select { case <-c.closed: @@ -192,7 +199,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.readTimeout <- context.Background(): } - return c.readHeader, nil + return h, nil } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { @@ -317,6 +324,7 @@ type msgReader struct { ctx context.Context flate bool flateReader io.Reader + flateBufio *bufio.Reader flateTail strings.Reader limitReader *limitReader dict slidingWindow @@ -324,12 +332,15 @@ type msgReader struct { fin bool payloadLength int64 maskKey uint32 + + // readerFunc(mr.Read) to avoid continuous allocations. + readFunc readerFunc } func (mr *msgReader) reset(ctx context.Context, h header) { mr.ctx = ctx mr.flate = h.rsv1 - mr.limitReader.reset(readerFunc(mr.read)) + mr.limitReader.reset(mr.readFunc) if mr.flate { mr.resetFlate() @@ -346,15 +357,15 @@ func (mr *msgReader) setFrame(h header) { func (mr *msgReader) Read(p []byte) (n int, err error) { defer func() { - errd.Wrap(&err, "failed to read") if xerrors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { err = io.EOF } if xerrors.Is(err, io.EOF) { err = io.EOF - - mr.returnFlateReader() + mr.putFlateReader() + return } + errd.Wrap(&err, "failed to read") }() err = mr.c.readMu.Lock(mr.ctx) @@ -372,44 +383,46 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { } func (mr *msgReader) read(p []byte) (int, error) { - if mr.payloadLength == 0 { - if mr.fin { - if mr.flate { - return mr.flateTail.Read(p) + for { + if mr.payloadLength == 0 { + if mr.fin { + if mr.flate { + return mr.flateTail.Read(p) + } + return 0, io.EOF } - return 0, io.EOF - } - h, err := mr.c.readLoop(mr.ctx) - if err != nil { - return 0, err - } - if h.opcode != opContinuation { - err := xerrors.New("received new data message without finishing the previous message") - mr.c.writeError(StatusProtocolError, err) - return 0, err + h, err := mr.c.readLoop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := xerrors.New("received new data message without finishing the previous message") + mr.c.writeError(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + + continue } - mr.setFrame(h) - return mr.read(p) - } + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } - if int64(len(p)) > mr.payloadLength { - p = p[:mr.payloadLength] - } + n, err := mr.c.readFramePayload(mr.ctx, p) + if err != nil { + return n, err + } - n, err := mr.c.readFramePayload(mr.ctx, p) - if err != nil { - return n, err - } + mr.payloadLength -= int64(n) - mr.payloadLength -= int64(n) + if !mr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } - if !mr.c.client { - mr.maskKey = mask(mr.maskKey, p) + return n, nil } - - return n, nil } type limitReader struct { diff --git a/write.go b/write.go index 9d4b670f..ec3b7d05 100644 --- a/write.go +++ b/write.go @@ -4,7 +4,6 @@ package websocket import ( "bufio" - "compress/flate" "context" "crypto/rand" "encoding/binary" @@ -12,6 +11,8 @@ import ( "sync" "time" + "github.com/klauspost/compress/flate" + kflate "github.com/klauspost/compress/flate" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" @@ -24,8 +25,6 @@ import ( // // Only one writer can be open at a time, multiple calls will block until the previous writer // is closed. -// -// Never close the returned writer twice. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { w, err := c.writer(ctx, typ) if err != nil { @@ -49,6 +48,26 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { } type msgWriter struct { + mw *msgWriterState + closed bool +} + +func (mw *msgWriter) Write(p []byte) (int, error) { + if mw.closed { + return 0, xerrors.New("cannot use closed writer") + } + return mw.mw.Write(p) +} + +func (mw *msgWriter) Close() error { + if mw.closed { + return xerrors.New("cannot use closed writer") + } + mw.closed = true + return mw.mw.Close() +} + +type msgWriterState struct { c *Conn mu *mu @@ -56,36 +75,42 @@ type msgWriter struct { ctx context.Context opcode opcode - closed bool flate bool trimWriter *trimLastFourBytesWriter flateWriter *flate.Writer + dict slidingWindow } -func newMsgWriter(c *Conn) *msgWriter { - mw := &msgWriter{ +func newMsgWriterState(c *Conn) *msgWriterState { + mw := &msgWriterState{ c: c, mu: newMu(c), } return mw } -func (mw *msgWriter) ensureFlate() { +const stateless = true + +func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ w: writerFunc(mw.write), } } - if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) + if stateless { + mw.dict.init(8192) + } else { + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } } mw.flate = true } -func (mw *msgWriter) flateContextTakeover() bool { +func (mw *msgWriterState) flateContextTakeover() bool { if mw.c.client { return !mw.c.copts.clientNoContextTakeover } @@ -93,11 +118,14 @@ func (mw *msgWriter) flateContextTakeover() bool { } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.msgWriter.reset(ctx, typ) + err := c.msgWriterState.reset(ctx, typ) if err != nil { return nil, err } - return c.msgWriter, nil + return &msgWriter{ + mw: c.msgWriterState, + closed: false, + }, nil } func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { @@ -107,8 +135,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error } if !c.flate() { - defer c.msgWriter.mu.Unlock() - return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p) + defer c.msgWriterState.mu.Unlock() + return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) } n, err := mw.Write(p) @@ -120,25 +148,22 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error return n, err } -func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { err := mw.mu.Lock(ctx) if err != nil { return err } - mw.closed = false mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false - if mw.trimWriter != nil { - mw.trimWriter.reset() - } + mw.trimWriter.reset() return nil } -func (mw *msgWriter) returnFlateWriter() { +func (mw *msgWriterState) putFlateWriter() { if mw.flateWriter != nil { putFlateWriter(mw.flateWriter) mw.flateWriter = nil @@ -146,16 +171,12 @@ func (mw *msgWriter) returnFlateWriter() { } // Write writes the given bytes to the WebSocket connection. -func (mw *msgWriter) Write(p []byte) (_ int, err error) { +func (mw *msgWriterState) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") mw.writeMu.Lock() defer mw.writeMu.Unlock() - if mw.closed { - return 0, xerrors.New("cannot use closed writer") - } - if mw.c.flate() { // Only enables flate if the length crosses the // threshold on the first frame @@ -165,13 +186,21 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { } if mw.flate { + if stateless { + err = kflate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err + } + mw.dict.write(p) + return len(p), nil + } return mw.flateWriter.Write(p) } return mw.write(p) } -func (mw *msgWriter) write(p []byte) (int, error) { +func (mw *msgWriterState) write(p []byte) (int, error) { n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) if err != nil { return n, xerrors.Errorf("failed to write data frame: %w", err) @@ -181,42 +210,36 @@ func (mw *msgWriter) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -func (mw *msgWriter) Close() (err error) { +func (mw *msgWriterState) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") mw.writeMu.Lock() defer mw.writeMu.Unlock() - if mw.closed { - return xerrors.New("cannot use closed writer") - } - - if mw.flate { + if mw.flate && !stateless { err = mw.flateWriter.Flush() if err != nil { - return xerrors.Errorf("failed to flush flate writer: %w", err) + return xerrors.Errorf("failed to flush flate: %w", err) } } - // We set closed after flushing the flate writer to ensure Write - // can succeed. - mw.closed = true - _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { - mw.returnFlateWriter() + mw.dict.close() + mw.putFlateWriter() } mw.mu.Unlock() return nil } -func (mw *msgWriter) close() { +func (mw *msgWriterState) close() { mw.writeMu.Lock() - mw.returnFlateWriter() + mw.putFlateWriter() + mw.dict.close() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { @@ -250,10 +273,11 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco if c.client { c.writeHeader.masked = true - err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) + _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) if err != nil { return 0, xerrors.Errorf("failed to generate masking key: %w", err) } + c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) } c.writeHeader.rsv1 = false @@ -261,7 +285,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco c.writeHeader.rsv1 = true } - err = writeFrameHeader(c.writeHeader, c.bw) + err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) if err != nil { return 0, err } From d57b25304679bcde2d1fa519f9bf569917a40762 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 15 Feb 2020 21:15:43 -0500 Subject: [PATCH 54/55] Report how efficient compression is in BenchmarkConn --- ci/test.mk | 2 +- compress_notjs.go | 16 ----------- conn_test.go | 6 +++++ export_test.go | 22 +++++++++++++++ internal/test/assert/assert.go | 6 ++--- write.go | 49 +++++++--------------------------- 6 files changed, 41 insertions(+), 60 deletions(-) create mode 100644 export_test.go diff --git a/ci/test.mk b/ci/test.mk index 3fc34bbf..3d1f0ed1 100644 --- a/ci/test.mk +++ b/ci/test.mk @@ -1,4 +1,4 @@ -test: gotest ci/out/coverage.html +test: ci/out/coverage.html ifdef CI test: coveralls endif diff --git a/compress_notjs.go b/compress_notjs.go index a61b7ba4..a6911056 100644 --- a/compress_notjs.go +++ b/compress_notjs.go @@ -108,22 +108,6 @@ func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } -var flateWriterPool sync.Pool - -func getFlateWriter(w io.Writer) *flate.Writer { - fw, ok := flateWriterPool.Get().(*flate.Writer) - if !ok { - fw, _ = flate.NewWriter(w, flate.BestSpeed) - return fw - } - fw.Reset(w) - return fw -} - -func putFlateWriter(w *flate.Writer) { - flateWriterPool.Put(w) -} - type slidingWindow struct { buf []byte } diff --git a/conn_test.go b/conn_test.go index 398ffd51..3b7fcdb5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -402,6 +402,9 @@ func BenchmarkConn(b *testing.B) { bb.goEchoLoop(c2) + bytesWritten := c1.RecordBytesWritten() + bytesRead := c1.RecordBytesRead() + msg := []byte(strings.Repeat("1234", 128)) readBuf := make([]byte, len(msg)) writes := make(chan struct{}) @@ -451,6 +454,9 @@ func BenchmarkConn(b *testing.B) { } b.StopTimer() + b.ReportMetric(float64(*bytesWritten/b.N), "written/op") + b.ReportMetric(float64(*bytesRead/b.N), "read/op") + err := c1.Close(websocket.StatusNormalClosure, "") assert.Success(b, err) }) diff --git a/export_test.go b/export_test.go new file mode 100644 index 00000000..88b82c9f --- /dev/null +++ b/export_test.go @@ -0,0 +1,22 @@ +// +build !js + +package websocket + +func (c *Conn) RecordBytesWritten() *int { + var bytesWritten int + c.bw.Reset(writerFunc(func(p []byte) (int, error) { + bytesWritten += len(p) + return c.rwc.Write(p) + })) + return &bytesWritten +} + +func (c *Conn) RecordBytesRead() *int { + var bytesRead int + c.br.Reset(readerFunc(func(p []byte) (int, error) { + n, err := c.rwc.Read(p) + bytesRead += n + return n, err + })) + return &bytesRead +} diff --git a/internal/test/assert/assert.go b/internal/test/assert/assert.go index 2bc01dba..602b887e 100644 --- a/internal/test/assert/assert.go +++ b/internal/test/assert/assert.go @@ -39,8 +39,8 @@ func Error(t testing.TB, err error) { func Contains(t testing.TB, v interface{}, sub string) { t.Helper() - vstr := fmt.Sprint(v) - if !strings.Contains(vstr, sub) { - t.Fatalf("expected %q to contain %q", vstr, sub) + s := fmt.Sprint(v) + if !strings.Contains(s, sub) { + t.Fatalf("expected %q to contain %q", s, sub) } } diff --git a/write.go b/write.go index ec3b7d05..b560b44c 100644 --- a/write.go +++ b/write.go @@ -12,7 +12,6 @@ import ( "time" "github.com/klauspost/compress/flate" - kflate "github.com/klauspost/compress/flate" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" @@ -77,9 +76,8 @@ type msgWriterState struct { opcode opcode flate bool - trimWriter *trimLastFourBytesWriter - flateWriter *flate.Writer - dict slidingWindow + trimWriter *trimLastFourBytesWriter + dict slidingWindow } func newMsgWriterState(c *Conn) *msgWriterState { @@ -90,8 +88,6 @@ func newMsgWriterState(c *Conn) *msgWriterState { return mw } -const stateless = true - func (mw *msgWriterState) ensureFlate() { if mw.trimWriter == nil { mw.trimWriter = &trimLastFourBytesWriter{ @@ -99,14 +95,7 @@ func (mw *msgWriterState) ensureFlate() { } } - if stateless { - mw.dict.init(8192) - } else { - if mw.flateWriter == nil { - mw.flateWriter = getFlateWriter(mw.trimWriter) - } - } - + mw.dict.init(8192) mw.flate = true } @@ -163,13 +152,6 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { return nil } -func (mw *msgWriterState) putFlateWriter() { - if mw.flateWriter != nil { - putFlateWriter(mw.flateWriter) - mw.flateWriter = nil - } -} - // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") @@ -186,15 +168,12 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { } if mw.flate { - if stateless { - err = kflate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) - if err != nil { - return 0, err - } - mw.dict.write(p) - return len(p), nil + err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err } - return mw.flateWriter.Write(p) + mw.dict.write(p) + return len(p), nil } return mw.write(p) @@ -216,13 +195,6 @@ func (mw *msgWriterState) Close() (err error) { mw.writeMu.Lock() defer mw.writeMu.Unlock() - if mw.flate && !stateless { - err = mw.flateWriter.Flush() - if err != nil { - return xerrors.Errorf("failed to flush flate: %w", err) - } - } - _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) @@ -230,7 +202,6 @@ func (mw *msgWriterState) Close() (err error) { if mw.flate && !mw.flateContextTakeover() { mw.dict.close() - mw.putFlateWriter() } mw.mu.Unlock() return nil @@ -238,7 +209,6 @@ func (mw *msgWriterState) Close() (err error) { func (mw *msgWriterState) close() { mw.writeMu.Lock() - mw.putFlateWriter() mw.dict.close() } @@ -311,14 +281,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco return n, nil } -func (c *Conn) writeFramePayload(p []byte) (_ int, err error) { +func (c *Conn) writeFramePayload(p []byte) (n int, err error) { defer errd.Wrap(&err, "failed to write frame payload") if !c.writeHeader.masked { return c.bw.Write(p) } - var n int maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush. From 1bc100d26f19edced3ad5c6d2853c1241211a766 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Sat, 15 Feb 2020 21:54:57 -0500 Subject: [PATCH 55/55] Update docs and random little issues --- .github/workflows/ci.yml | 6 +++--- README.md | 8 +++++--- accept.go | 22 +++++++++++++--------- accept_js.go | 7 ++++--- close_notjs.go | 6 ++++-- compress.go | 17 +---------------- conn_notjs.go | 2 +- conn_test.go | 17 ++++++++--------- dial.go | 23 ++++++++++++++--------- doc.go | 5 +++-- ws_js.go | 5 +++++ 11 files changed, 61 insertions(+), 57 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 074e5246..4534425f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,7 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: make fmt + - name: Run make fmt uses: ./ci/image with: args: make fmt @@ -27,7 +27,7 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: make lint + - name: Run make lint uses: ./ci/image with: args: make lint @@ -42,7 +42,7 @@ jobs: key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go- - - name: make test + - name: Run make test uses: ./ci/image with: args: make test diff --git a/README.md b/README.md index 2569383a..631a14c9 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support -- Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- [Zero dependencies](https://godoc.org/nhooyr.io/websocket?imports) +- Thorough tests, fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) +- [Minimal dependencies](https://godoc.org/nhooyr.io/websocket?imports) - JSON and protobuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes @@ -34,7 +34,7 @@ go get nhooyr.io/websocket ## Examples -For a production quality example that demonstrates the full API, see the [echo example](https://godoc.org/nhooyr.io/websocket#example-package--Echo). +For a production quality example that demonstrates the complete API, see the [echo example](https://godoc.org/nhooyr.io/websocket#example-package--Echo). ### Server @@ -111,6 +111,8 @@ Advantages of nhooyr.io/websocket: - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode + - Uses [klauspost/compress](https://github.com/klauspost/compress) for optimized compression + - See [gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203) - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) diff --git a/accept.go b/accept.go index cc9babb0..75d6d643 100644 --- a/accept.go +++ b/accept.go @@ -37,9 +37,17 @@ type AcceptOptions struct { // If used incorrectly your WebSocket server will be open to CSRF attacks. InsecureSkipVerify bool - // CompressionOptions controls the compression options. - // See docs on the CompressionOptions type. - CompressionOptions *CompressionOptions + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int } // Accept accepts a WebSocket handshake from a client and upgrades the @@ -61,10 +69,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con } opts = &*opts - if opts.CompressionOptions == nil { - opts.CompressionOptions = &CompressionOptions{} - } - errCode, err := verifyClientRequest(w, r) if err != nil { http.Error(w, err.Error(), errCode) @@ -97,7 +101,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con w.Header().Set("Sec-WebSocket-Protocol", subproto) } - copts, err := acceptCompression(r, w, opts.CompressionOptions.Mode) + copts, err := acceptCompression(r, w, opts.CompressionMode) if err != nil { return nil, err } @@ -120,7 +124,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con rwc: netConn, client: false, copts: copts, - flateThreshold: opts.CompressionOptions.Threshold, + flateThreshold: opts.CompressionThreshold, br: brw.Reader, bw: brw.Writer, diff --git a/accept_js.go b/accept_js.go index efc92817..5db12d7b 100644 --- a/accept_js.go +++ b/accept_js.go @@ -8,9 +8,10 @@ import ( // AcceptOptions represents Accept's options. type AcceptOptions struct { - Subprotocols []string - InsecureSkipVerify bool - CompressionOptions *CompressionOptions + Subprotocols []string + InsecureSkipVerify bool + CompressionMode CompressionMode + CompressionThreshold int } // Accept is stubbed out for Wasm. diff --git a/close_notjs.go b/close_notjs.go index 160a1237..3367ea01 100644 --- a/close_notjs.go +++ b/close_notjs.go @@ -35,7 +35,7 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") err = c.writeClose(code, reason) - if err != nil && CloseStatus(err) == -1 { + if err != nil && CloseStatus(err) == -1 && err != errAlreadyWroteClose { return err } @@ -46,13 +46,15 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { return nil } +var errAlreadyWroteClose = xerrors.New("already wrote close") + func (c *Conn) writeClose(code StatusCode, reason string) error { c.closeMu.Lock() closing := c.wroteClose c.wroteClose = true c.closeMu.Unlock() if closing { - return xerrors.New("already wrote close") + return errAlreadyWroteClose } ce := CloseError{ diff --git a/compress.go b/compress.go index 918b3b49..57446d01 100644 --- a/compress.go +++ b/compress.go @@ -1,20 +1,5 @@ package websocket -// CompressionOptions represents the available deflate extension options. -// See https://tools.ietf.org/html/rfc7692 -type CompressionOptions struct { - // Mode controls the compression mode. - // - // See docs on CompressionMode. - Mode CompressionMode - - // Threshold controls the minimum size of a message before compression is applied. - // - // Defaults to 512 bytes for CompressionNoContextTakeover and 256 bytes - // for CompressionContextTakeover. - Threshold int -} - // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // @@ -38,7 +23,7 @@ const ( // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. // This enables reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. - // It carries an overhead of 64 kB for every connection compared to CompressionNoContextTakeover. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. diff --git a/conn_notjs.go b/conn_notjs.go index e6ff7df3..8598ded3 100644 --- a/conn_notjs.go +++ b/conn_notjs.go @@ -101,7 +101,7 @@ func newConn(cfg connConfig) *Conn { } if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 256 + c.flateThreshold = 128 if !c.msgWriterState.flateContextTakeover() { c.flateThreshold = 512 } diff --git a/conn_test.go b/conn_test.go index 3b7fcdb5..7755048c 100644 --- a/conn_test.go +++ b/conn_test.go @@ -36,19 +36,18 @@ func TestConn(t *testing.T) { t.Run("fuzzData", func(t *testing.T) { t.Parallel() - copts := func() *websocket.CompressionOptions { - return &websocket.CompressionOptions{ - Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), - Threshold: xrand.Int(9999), - } + compressionMode := func() websocket.CompressionMode { + return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) } for i := 0; i < 5; i++ { t.Run("", func(t *testing.T) { tt, c1, c2 := newConnTest(t, &websocket.DialOptions{ - CompressionOptions: copts(), + CompressionMode: compressionMode(), + CompressionThreshold: xrand.Int(9999), }, &websocket.AcceptOptions{ - CompressionOptions: copts(), + CompressionMode: compressionMode(), + CompressionThreshold: xrand.Int(9999), }) defer tt.cleanup() @@ -394,9 +393,9 @@ func BenchmarkConn(b *testing.B) { for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { bb, c1, c2 := newConnTest(b, &websocket.DialOptions{ - CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, + CompressionMode: bc.mode, }, &websocket.AcceptOptions{ - CompressionOptions: &websocket.CompressionOptions{Mode: bc.mode}, + CompressionMode: bc.mode, }) defer bb.cleanup() diff --git a/dial.go b/dial.go index 3e2042e5..09546ac6 100644 --- a/dial.go +++ b/dial.go @@ -33,9 +33,17 @@ type DialOptions struct { // Subprotocols lists the WebSocket subprotocols to negotiate with the server. Subprotocols []string - // CompressionOptions controls the compression options. - // See docs on the CompressionOptions type. - CompressionOptions *CompressionOptions + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int } // Dial performs a WebSocket handshake on url. @@ -67,9 +75,6 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } - if opts.CompressionOptions == nil { - opts.CompressionOptions = &CompressionOptions{} - } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { @@ -107,7 +112,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( rwc: rwc, client: true, copts: copts, - flateThreshold: opts.CompressionOptions.Threshold, + flateThreshold: opts.CompressionThreshold, br: getBufioReader(rwc), bw: getBufioWriter(rwc), }), resp, nil @@ -141,8 +146,8 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - if opts.CompressionOptions.Mode != CompressionDisabled { - copts := opts.CompressionOptions.Mode.opts() + if opts.CompressionMode != CompressionDisabled { + copts := opts.CompressionMode.opts() copts.setHeader(req.Header) } diff --git a/doc.go b/doc.go index c8f5550b..efa920e3 100644 --- a/doc.go +++ b/doc.go @@ -6,7 +6,7 @@ // // Use Dial to dial a WebSocket server. // -// Accept to accept a WebSocket client. +// Use Accept to accept a WebSocket client. // // Conn represents the resulting WebSocket connection. // @@ -25,7 +25,8 @@ // // Some important caveats to be aware of: // +// - Accept always errors out // - Conn.Ping is no-op // - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op -// - *http.Response from Dial is &http.Response{} on success +// - *http.Response from Dial is &http.Response{} with a 101 status code on success package websocket // import "nhooyr.io/websocket" diff --git a/ws_js.go b/ws_js.go index 05c4c062..ecf3d78c 100644 --- a/ws_js.go +++ b/ws_js.go @@ -152,6 +152,11 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { } } +// Ping is mocked out for Wasm. +func (c *Conn) Ping(ctx context.Context) error { + return nil +} + // Write writes a message of the given type to the connection. // Always non blocking. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {