Skip to content

Commit 14baab4

Browse files
committed
Implement compression extension negotiation
1 parent d95cda8 commit 14baab4

File tree

3 files changed

+139
-45
lines changed

3 files changed

+139
-45
lines changed

conn.go

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type Conn struct {
4242
writeBuf []byte
4343
closer io.Closer
4444
client bool
45+
copts *CompressionOptions
4546

4647
closeOnce sync.Once
4748
closeErrOnce sync.Once

handshake.go

+137-44
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
5959
return err
6060
}
6161

62-
if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") {
62+
if !headerContainsToken(r.Header, "Connection", "Upgrade") {
6363
err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
6464
http.Error(w, err.Error(), http.StatusBadRequest)
6565
return err
6666
}
6767

68-
if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") {
68+
if !headerContainsToken(r.Header, "Upgrade", "WebSocket") {
6969
err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
7070
http.Error(w, err.Error(), http.StatusBadRequest)
7171
return err
@@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
144144
w.Header().Set("Sec-WebSocket-Protocol", subproto)
145145
}
146146

147+
var copts *CompressionOptions
148+
if opts.Compression != nil {
149+
copts, err = negotiateCompression(r.Header, opts.Compression)
150+
if err != nil {
151+
http.Error(w, err.Error(), http.StatusBadRequest)
152+
return nil, err
153+
}
154+
if copts != nil {
155+
copts.setHeader(w.Header())
156+
}
157+
}
158+
147159
w.WriteHeader(http.StatusSwitchingProtocols)
148160

149161
netConn, brw, err := hj.Hijack()
@@ -162,40 +174,65 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
162174
br: brw.Reader,
163175
bw: brw.Writer,
164176
closer: netConn,
177+
copts: copts,
165178
}
166179
c.init()
167180

168181
return c, nil
169182
}
170183

171-
func headerValuesContainsToken(h http.Header, key, token string) bool {
184+
func headerContainsToken(h http.Header, key, token string) bool {
172185
key = textproto.CanonicalMIMEHeaderKey(key)
173186

174-
for _, val2 := range h[key] {
175-
if headerValueContainsToken(val2, token) {
187+
token = strings.ToLower(token)
188+
match := func(t string) bool {
189+
return t == token
190+
}
191+
192+
for _, v := range h[key] {
193+
if searchHeaderTokens(v, match) != "" {
176194
return true
177195
}
178196
}
179197

180198
return false
181199
}
182200

183-
func headerValueContainsToken(val2, token string) bool {
184-
val2 = strings.TrimSpace(val2)
201+
func headerTokenHasPrefix(h http.Header, key, prefix string) string {
202+
key = textproto.CanonicalMIMEHeaderKey(key)
185203

186-
for _, val2 := range strings.Split(val2, ",") {
187-
val2 = strings.TrimSpace(val2)
188-
if strings.EqualFold(val2, token) {
189-
return true
204+
prefix = strings.ToLower(prefix)
205+
match := func(t string) bool {
206+
return strings.HasPrefix(t, prefix)
207+
}
208+
209+
for _, v := range h[key] {
210+
found := searchHeaderTokens(v, match)
211+
if found != "" {
212+
return found
190213
}
191214
}
192215

193-
return false
216+
return ""
217+
}
218+
219+
func searchHeaderTokens(v string, match func(val string) bool) string {
220+
v = strings.TrimSpace(v)
221+
222+
for _, v2 := range strings.Split(v, ",") {
223+
v2 = strings.TrimSpace(v2)
224+
v2 = strings.ToLower(v2)
225+
if match(v2) {
226+
return v2
227+
}
228+
}
229+
230+
return ""
194231
}
195232

196233
func selectSubprotocol(r *http.Request, subprotocols []string) string {
197234
for _, sp := range subprotocols {
198-
if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) {
235+
if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) {
199236
return sp
200237
}
201238
}
@@ -268,36 +305,32 @@ type DialOptions struct {
268305
//
269306
// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression.
270307
//
271-
// Enabling compression will increase memory and CPU usage.
272-
// Thus it is not ideal for every use case and disabled by default.
308+
// Enabling compression will increase memory and CPU usage and should
309+
// be profiled before enabling in production.
273310
// See https://github.com/gorilla/websocket/issues/203
274-
// Profile before enabling in production.
275311
//
276312
// This API is experimental and subject to change.
277313
type CompressionOptions struct {
278-
// ServerNoContextTakeover controls whether the server should use context takeover.
279-
// See docs on CompressionOptions for discussion regarding context takeover.
280-
//
281-
// If set by the client, will guarantee that the server does not use context takeover.
282-
ServerNoContextTakeover bool
283-
284314
// ClientNoContextTakeover controls whether the client should use context takeover.
285315
// See docs on CompressionOptions for discussion regarding context takeover.
286316
//
287317
// If set by the server, will guarantee that the client does not use context takeover.
288318
ClientNoContextTakeover bool
289319

320+
// ServerNoContextTakeover controls whether the server should use context takeover.
321+
// See docs on CompressionOptions for discussion regarding context takeover.
322+
//
323+
// If set by the client, will guarantee that the server does not use context takeover.
324+
ServerNoContextTakeover bool
325+
290326
// Level controls the compression level used.
291327
// Defaults to flate.BestSpeed.
292328
Level int
293329

294330
// Threshold controls the minimum message size in bytes before compression is used.
295-
// In the case of ContextTakeover == false, a flate.Writer will not be grabbed
296-
// from the pool until the message exceeds this threshold.
297-
//
298331
// Must not be greater than 4096 as that is the write buffer's size.
299332
//
300-
// Defaults to 512.
333+
// Defaults to 256.
301334
Threshold int
302335
}
303336

@@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
319352
return c, r, nil
320353
}
321354

322-
func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) {
355+
func (opts *DialOptions) ensure() (*DialOptions, error) {
323356
if opts == nil {
324357
opts = &DialOptions{}
358+
} else {
359+
opts = &*opts
325360
}
326361

327-
// Shallow copy to ensure defaults do not affect user passed options.
328-
opts2 := *opts
329-
opts = &opts2
330-
331362
if opts.HTTPClient == nil {
332363
opts.HTTPClient = http.DefaultClient
333364
}
334365
if opts.HTTPClient.Timeout > 0 {
335-
return nil, nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
366+
return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
336367
}
337368
if opts.HTTPHeader == nil {
338369
opts.HTTPHeader = http.Header{}
339370
}
340371

372+
return opts, nil
373+
}
374+
375+
func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) {
376+
opts, err = opts.ensure()
377+
if err != nil {
378+
return nil, nil, err
379+
}
380+
341381
parsedURL, err := url.Parse(u)
342382
if err != nil {
343383
return nil, nil, fmt.Errorf("failed to parse url: %w", err)
@@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
367407
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
368408
}
369409
if opts.Compression != nil {
370-
req.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
410+
opts.Compression.setHeader(req.Header)
371411
}
372412

373413
resp, err := opts.HTTPClient.Do(req)
@@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
384424
}
385425
}()
386426

387-
err = verifyServerResponse(req, resp)
427+
copts, err := verifyServerResponse(req, resp, opts)
388428
if err != nil {
389429
return nil, resp, err
390430
}
@@ -400,38 +440,48 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
400440
bw: getBufioWriter(rwc),
401441
closer: rwc,
402442
client: true,
443+
copts: copts,
403444
}
404445
c.extractBufioWriterBuf(rwc)
405446
c.init()
406447

407448
return c, resp, nil
408449
}
409450

410-
func verifyServerResponse(r *http.Request, resp *http.Response) error {
451+
func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*CompressionOptions, error) {
411452
if resp.StatusCode != http.StatusSwitchingProtocols {
412-
return fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
453+
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
413454
}
414455

415-
if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") {
416-
return fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
456+
if !headerContainsToken(resp.Header, "Connection", "Upgrade") {
457+
return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
417458
}
418459

419-
if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") {
420-
return fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
460+
if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") {
461+
return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
421462
}
422463

423464
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
424-
return fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
465+
return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
425466
resp.Header.Get("Sec-WebSocket-Accept"),
426467
r.Header.Get("Sec-WebSocket-Key"),
427468
)
428469
}
429470

430-
if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) {
431-
return fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
471+
if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) {
472+
return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
432473
}
433474

434-
return nil
475+
var copts *CompressionOptions
476+
if opts.Compression != nil {
477+
var err error
478+
copts, err = negotiateCompression(resp.Header, opts.Compression)
479+
if err != nil {
480+
return nil, err
481+
}
482+
}
483+
484+
return copts, nil
435485
}
436486

437487
// The below pools can only be used by the client because http.Hijacker will always
@@ -477,3 +527,46 @@ func makeSecWebSocketKey() (string, error) {
477527
}
478528
return base64.StdEncoding.EncodeToString(b), nil
479529
}
530+
531+
func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) {
532+
deflate := headerTokenHasPrefix(h, "Sec-WebSocket-Extensions", "permessage-deflate")
533+
if deflate == "" {
534+
return nil, nil
535+
}
536+
537+
// Ensures our changes do not modify the real compression options.
538+
copts = &*copts
539+
540+
params := strings.Split(deflate, ";")
541+
for i := range params {
542+
params[i] = strings.TrimSpace(params[i])
543+
}
544+
545+
if params[0] != "permessage-deflate" {
546+
return nil, fmt.Errorf("unexpected header format for permessage-deflate extension: %q", deflate)
547+
}
548+
549+
for _, p := range params[1:] {
550+
switch p {
551+
case "client_no_context_takeover":
552+
copts.ClientNoContextTakeover = true
553+
case "server_no_context_takeover":
554+
copts.ServerNoContextTakeover = true
555+
default:
556+
return nil, fmt.Errorf("unexpected permessage-deflate parameter %q in header: %q", p, deflate)
557+
}
558+
}
559+
560+
return copts, nil
561+
}
562+
563+
func (copts *CompressionOptions) setHeader(h http.Header) {
564+
s := "permessage-deflate"
565+
if copts.ClientNoContextTakeover {
566+
s += "; client_no_context_takeover"
567+
}
568+
if copts.ServerNoContextTakeover {
569+
s += "; server_no_context_takeover"
570+
}
571+
h.Set("Sec-WebSocket-Extensions", s)
572+
}

handshake_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ func Test_verifyServerHandshake(t *testing.T) {
377377
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
378378
}
379379

380-
err = verifyServerResponse(r, resp)
380+
_, err = verifyServerResponse(r, resp, &DialOptions{})
381381
if (err == nil) != tc.success {
382382
t.Fatalf("unexpected error: %+v", err)
383383
}

0 commit comments

Comments
 (0)