@@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
59
59
return err
60
60
}
61
61
62
- if ! headerValuesContainsToken (r .Header , "Connection" , "Upgrade" ) {
62
+ if ! headerContainsToken (r .Header , "Connection" , "Upgrade" ) {
63
63
err := fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
64
64
http .Error (w , err .Error (), http .StatusBadRequest )
65
65
return err
66
66
}
67
67
68
- if ! headerValuesContainsToken (r .Header , "Upgrade" , "WebSocket" ) {
68
+ if ! headerContainsToken (r .Header , "Upgrade" , "WebSocket" ) {
69
69
err := fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
70
70
http .Error (w , err .Error (), http .StatusBadRequest )
71
71
return err
@@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
144
144
w .Header ().Set ("Sec-WebSocket-Protocol" , subproto )
145
145
}
146
146
147
+ var copts * CompressionOptions
148
+ if opts .Compression != nil {
149
+ copts , err = negotiateCompression (r .Header , opts .Compression )
150
+ if err != nil {
151
+ http .Error (w , err .Error (), http .StatusBadRequest )
152
+ return nil , err
153
+ }
154
+ if copts != nil {
155
+ copts .setHeader (w .Header ())
156
+ }
157
+ }
158
+
147
159
w .WriteHeader (http .StatusSwitchingProtocols )
148
160
149
161
netConn , brw , err := hj .Hijack ()
@@ -162,40 +174,65 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
162
174
br : brw .Reader ,
163
175
bw : brw .Writer ,
164
176
closer : netConn ,
177
+ copts : copts ,
165
178
}
166
179
c .init ()
167
180
168
181
return c , nil
169
182
}
170
183
171
- func headerValuesContainsToken (h http.Header , key , token string ) bool {
184
+ func headerContainsToken (h http.Header , key , token string ) bool {
172
185
key = textproto .CanonicalMIMEHeaderKey (key )
173
186
174
- for _ , val2 := range h [key ] {
175
- if headerValueContainsToken (val2 , token ) {
187
+ token = strings .ToLower (token )
188
+ match := func (t string ) bool {
189
+ return t == token
190
+ }
191
+
192
+ for _ , v := range h [key ] {
193
+ if searchHeaderTokens (v , match ) != "" {
176
194
return true
177
195
}
178
196
}
179
197
180
198
return false
181
199
}
182
200
183
- func headerValueContainsToken ( val2 , token string ) bool {
184
- val2 = strings . TrimSpace ( val2 )
201
+ func headerTokenHasPrefix ( h http. Header , key , prefix string ) string {
202
+ key = textproto . CanonicalMIMEHeaderKey ( key )
185
203
186
- for _ , val2 := range strings .Split (val2 , "," ) {
187
- val2 = strings .TrimSpace (val2 )
188
- if strings .EqualFold (val2 , token ) {
189
- return true
204
+ prefix = strings .ToLower (prefix )
205
+ match := func (t string ) bool {
206
+ return strings .HasPrefix (t , prefix )
207
+ }
208
+
209
+ for _ , v := range h [key ] {
210
+ found := searchHeaderTokens (v , match )
211
+ if found != "" {
212
+ return found
190
213
}
191
214
}
192
215
193
- return false
216
+ return ""
217
+ }
218
+
219
+ func searchHeaderTokens (v string , match func (val string ) bool ) string {
220
+ v = strings .TrimSpace (v )
221
+
222
+ for _ , v2 := range strings .Split (v , "," ) {
223
+ v2 = strings .TrimSpace (v2 )
224
+ v2 = strings .ToLower (v2 )
225
+ if match (v2 ) {
226
+ return v2
227
+ }
228
+ }
229
+
230
+ return ""
194
231
}
195
232
196
233
func selectSubprotocol (r * http.Request , subprotocols []string ) string {
197
234
for _ , sp := range subprotocols {
198
- if headerValuesContainsToken (r .Header , "Sec-WebSocket-Protocol" , sp ) {
235
+ if headerContainsToken (r .Header , "Sec-WebSocket-Protocol" , sp ) {
199
236
return sp
200
237
}
201
238
}
@@ -268,36 +305,32 @@ type DialOptions struct {
268
305
//
269
306
// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression.
270
307
//
271
- // Enabling compression will increase memory and CPU usage.
272
- // Thus it is not ideal for every use case and disabled by default .
308
+ // Enabling compression will increase memory and CPU usage and should
309
+ // be profiled before enabling in production .
273
310
// See https://github.com/gorilla/websocket/issues/203
274
- // Profile before enabling in production.
275
311
//
276
312
// This API is experimental and subject to change.
277
313
type CompressionOptions struct {
278
- // ServerNoContextTakeover controls whether the server should use context takeover.
279
- // See docs on CompressionOptions for discussion regarding context takeover.
280
- //
281
- // If set by the client, will guarantee that the server does not use context takeover.
282
- ServerNoContextTakeover bool
283
-
284
314
// ClientNoContextTakeover controls whether the client should use context takeover.
285
315
// See docs on CompressionOptions for discussion regarding context takeover.
286
316
//
287
317
// If set by the server, will guarantee that the client does not use context takeover.
288
318
ClientNoContextTakeover bool
289
319
320
+ // ServerNoContextTakeover controls whether the server should use context takeover.
321
+ // See docs on CompressionOptions for discussion regarding context takeover.
322
+ //
323
+ // If set by the client, will guarantee that the server does not use context takeover.
324
+ ServerNoContextTakeover bool
325
+
290
326
// Level controls the compression level used.
291
327
// Defaults to flate.BestSpeed.
292
328
Level int
293
329
294
330
// Threshold controls the minimum message size in bytes before compression is used.
295
- // In the case of ContextTakeover == false, a flate.Writer will not be grabbed
296
- // from the pool until the message exceeds this threshold.
297
- //
298
331
// Must not be greater than 4096 as that is the write buffer's size.
299
332
//
300
- // Defaults to 512 .
333
+ // Defaults to 256 .
301
334
Threshold int
302
335
}
303
336
@@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
319
352
return c , r , nil
320
353
}
321
354
322
- func dial ( ctx context. Context , u string , opts * DialOptions ) ( _ * Conn , _ * http. Response , err error ) {
355
+ func ( opts * DialOptions ) ensure () ( * DialOptions , error ) {
323
356
if opts == nil {
324
357
opts = & DialOptions {}
358
+ } else {
359
+ opts = & * opts
325
360
}
326
361
327
- // Shallow copy to ensure defaults do not affect user passed options.
328
- opts2 := * opts
329
- opts = & opts2
330
-
331
362
if opts .HTTPClient == nil {
332
363
opts .HTTPClient = http .DefaultClient
333
364
}
334
365
if opts .HTTPClient .Timeout > 0 {
335
- return nil , nil , fmt .Errorf ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
366
+ return nil , fmt .Errorf ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
336
367
}
337
368
if opts .HTTPHeader == nil {
338
369
opts .HTTPHeader = http.Header {}
339
370
}
340
371
372
+ return opts , nil
373
+ }
374
+
375
+ func dial (ctx context.Context , u string , opts * DialOptions ) (_ * Conn , _ * http.Response , err error ) {
376
+ opts , err = opts .ensure ()
377
+ if err != nil {
378
+ return nil , nil , err
379
+ }
380
+
341
381
parsedURL , err := url .Parse (u )
342
382
if err != nil {
343
383
return nil , nil , fmt .Errorf ("failed to parse url: %w" , err )
@@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
367
407
req .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (opts .Subprotocols , "," ))
368
408
}
369
409
if opts .Compression != nil {
370
- req . Header . Set ( "Sec-WebSocket-Extensions" , "permessage-deflate; server_no_context_takeover; client_no_context_takeover" )
410
+ opts . Compression . setHeader ( req . Header )
371
411
}
372
412
373
413
resp , err := opts .HTTPClient .Do (req )
@@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
384
424
}
385
425
}()
386
426
387
- err = verifyServerResponse (req , resp )
427
+ copts , err : = verifyServerResponse (req , resp , opts )
388
428
if err != nil {
389
429
return nil , resp , err
390
430
}
@@ -400,38 +440,48 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
400
440
bw : getBufioWriter (rwc ),
401
441
closer : rwc ,
402
442
client : true ,
443
+ copts : copts ,
403
444
}
404
445
c .extractBufioWriterBuf (rwc )
405
446
c .init ()
406
447
407
448
return c , resp , nil
408
449
}
409
450
410
- func verifyServerResponse (r * http.Request , resp * http.Response ) error {
451
+ func verifyServerResponse (r * http.Request , resp * http.Response , opts * DialOptions ) ( * CompressionOptions , error ) {
411
452
if resp .StatusCode != http .StatusSwitchingProtocols {
412
- return fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
453
+ return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
413
454
}
414
455
415
- if ! headerValuesContainsToken (resp .Header , "Connection" , "Upgrade" ) {
416
- return fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
456
+ if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ) {
457
+ return nil , fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
417
458
}
418
459
419
- if ! headerValuesContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
420
- return fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
460
+ if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
461
+ return nil , fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
421
462
}
422
463
423
464
if resp .Header .Get ("Sec-WebSocket-Accept" ) != secWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )) {
424
- return fmt .Errorf ("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
465
+ return nil , fmt .Errorf ("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
425
466
resp .Header .Get ("Sec-WebSocket-Accept" ),
426
467
r .Header .Get ("Sec-WebSocket-Key" ),
427
468
)
428
469
}
429
470
430
- if proto := resp .Header .Get ("Sec-WebSocket-Protocol" ); proto != "" && ! headerValuesContainsToken (r .Header , "Sec-WebSocket-Protocol" , proto ) {
431
- return fmt .Errorf ("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
471
+ if proto := resp .Header .Get ("Sec-WebSocket-Protocol" ); proto != "" && ! headerContainsToken (r .Header , "Sec-WebSocket-Protocol" , proto ) {
472
+ return nil , fmt .Errorf ("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
432
473
}
433
474
434
- return nil
475
+ var copts * CompressionOptions
476
+ if opts .Compression != nil {
477
+ var err error
478
+ copts , err = negotiateCompression (resp .Header , opts .Compression )
479
+ if err != nil {
480
+ return nil , err
481
+ }
482
+ }
483
+
484
+ return copts , nil
435
485
}
436
486
437
487
// The below pools can only be used by the client because http.Hijacker will always
@@ -477,3 +527,46 @@ func makeSecWebSocketKey() (string, error) {
477
527
}
478
528
return base64 .StdEncoding .EncodeToString (b ), nil
479
529
}
530
+
531
+ func negotiateCompression (h http.Header , copts * CompressionOptions ) (* CompressionOptions , error ) {
532
+ deflate := headerTokenHasPrefix (h , "Sec-WebSocket-Extensions" , "permessage-deflate" )
533
+ if deflate == "" {
534
+ return nil , nil
535
+ }
536
+
537
+ // Ensures our changes do not modify the real compression options.
538
+ copts = & * copts
539
+
540
+ params := strings .Split (deflate , ";" )
541
+ for i := range params {
542
+ params [i ] = strings .TrimSpace (params [i ])
543
+ }
544
+
545
+ if params [0 ] != "permessage-deflate" {
546
+ return nil , fmt .Errorf ("unexpected header format for permessage-deflate extension: %q" , deflate )
547
+ }
548
+
549
+ for _ , p := range params [1 :] {
550
+ switch p {
551
+ case "client_no_context_takeover" :
552
+ copts .ClientNoContextTakeover = true
553
+ case "server_no_context_takeover" :
554
+ copts .ServerNoContextTakeover = true
555
+ default :
556
+ return nil , fmt .Errorf ("unexpected permessage-deflate parameter %q in header: %q" , p , deflate )
557
+ }
558
+ }
559
+
560
+ return copts , nil
561
+ }
562
+
563
+ func (copts * CompressionOptions ) setHeader (h http.Header ) {
564
+ s := "permessage-deflate"
565
+ if copts .ClientNoContextTakeover {
566
+ s += "; client_no_context_takeover"
567
+ }
568
+ if copts .ServerNoContextTakeover {
569
+ s += "; server_no_context_takeover"
570
+ }
571
+ h .Set ("Sec-WebSocket-Extensions" , s )
572
+ }
0 commit comments