@@ -3393,63 +3393,54 @@ func TestTransportCloseAfterLostPing(t *testing.T) {
3393
3393
3394
3394
func TestTransportPingWhenReading (t * testing.T ) {
3395
3395
testCases := []struct {
3396
- name string
3397
- readIdleTimeout time.Duration
3398
- serverResponseInterval time.Duration
3399
- expectedPingCount int
3396
+ name string
3397
+ readIdleTimeout time.Duration
3398
+ deadline time.Duration
3399
+ expectedPingCount int
3400
3400
}{
3401
3401
{
3402
- name : "two pings in each serverResponseInterval " ,
3403
- readIdleTimeout : 400 * time .Millisecond ,
3404
- serverResponseInterval : 1000 * time .Millisecond ,
3405
- expectedPingCount : 4 ,
3402
+ name : "two pings" ,
3403
+ readIdleTimeout : 100 * time .Millisecond ,
3404
+ deadline : time .Second ,
3405
+ expectedPingCount : 2 ,
3406
3406
},
3407
3407
{
3408
- name : "one ping in each serverResponseInterval " ,
3409
- readIdleTimeout : 700 * time .Millisecond ,
3410
- serverResponseInterval : 1000 * time .Millisecond ,
3411
- expectedPingCount : 2 ,
3408
+ name : "zero ping" ,
3409
+ readIdleTimeout : time .Second ,
3410
+ deadline : 200 * time .Millisecond ,
3411
+ expectedPingCount : 0 ,
3412
3412
},
3413
3413
{
3414
- name : "zero ping in each serverResponseInterval" ,
3415
- readIdleTimeout : 1000 * time .Millisecond ,
3416
- serverResponseInterval : 500 * time .Millisecond ,
3417
- expectedPingCount : 0 ,
3418
- },
3419
- {
3420
- name : "0 readIdleTimeout means no ping" ,
3421
- readIdleTimeout : 0 * time .Millisecond ,
3422
- serverResponseInterval : 500 * time .Millisecond ,
3423
- expectedPingCount : 0 ,
3414
+ name : "0 readIdleTimeout means no ping" ,
3415
+ readIdleTimeout : 0 * time .Millisecond ,
3416
+ deadline : 500 * time .Millisecond ,
3417
+ expectedPingCount : 0 ,
3424
3418
},
3425
3419
}
3426
3420
3427
3421
for _ , tc := range testCases {
3428
3422
tc := tc // capture range variable
3429
3423
t .Run (tc .name , func (t * testing.T ) {
3430
- t .Parallel ()
3431
- testTransportPingWhenReading (t , tc .readIdleTimeout , tc .serverResponseInterval , tc .expectedPingCount )
3424
+ testTransportPingWhenReading (t , tc .readIdleTimeout , tc .deadline , tc .expectedPingCount )
3432
3425
})
3433
3426
}
3434
3427
}
3435
3428
3436
- func testTransportPingWhenReading (t * testing.T , readIdleTimeout , serverResponseInterval time.Duration , expectedPingCount int ) {
3429
+ func testTransportPingWhenReading (t * testing.T , readIdleTimeout , deadline time.Duration , expectedPingCount int ) {
3437
3430
var pingCount int
3438
- clientDone := make (chan struct {})
3439
3431
ct := newClientTester (t )
3440
3432
ct .tr .PingTimeout = 10 * time .Millisecond
3441
3433
ct .tr .ReadIdleTimeout = readIdleTimeout
3442
- // guards the ct.fr.Write
3443
- var wmu sync.Mutex
3444
3434
3435
+ ctx , cancel := context .WithTimeout (context .Background (), deadline )
3436
+ defer cancel ()
3445
3437
ct .client = func () error {
3446
3438
defer ct .cc .(* net.TCPConn ).CloseWrite ()
3447
3439
if runtime .GOOS == "plan9" {
3448
3440
// CloseWrite not supported on Plan 9; Issue 17906
3449
3441
defer ct .cc .(* net.TCPConn ).Close ()
3450
3442
}
3451
- defer close (clientDone )
3452
- req , _ := http .NewRequest ("GET" , "https://dummy.tld/" , nil )
3443
+ req , _ := http .NewRequestWithContext (ctx , "GET" , "https://dummy.tld/" , nil )
3453
3444
res , err := ct .tr .RoundTrip (req )
3454
3445
if err != nil {
3455
3446
return fmt .Errorf ("RoundTrip: %v" , err )
@@ -3459,20 +3450,24 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseI
3459
3450
return fmt .Errorf ("status code = %v; want %v" , res .StatusCode , 200 )
3460
3451
}
3461
3452
_ , err = ioutil .ReadAll (res .Body )
3453
+ if expectedPingCount == 0 && errors .Is (ctx .Err (), context .DeadlineExceeded ) {
3454
+ return nil
3455
+ }
3456
+
3457
+ cancel ()
3462
3458
return err
3463
3459
}
3464
3460
3465
3461
ct .server = func () error {
3466
3462
ct .greet ()
3467
3463
var buf bytes.Buffer
3468
3464
enc := hpack .NewEncoder (& buf )
3469
- var wg sync.WaitGroup
3470
- defer wg .Wait ()
3465
+ var streamID uint32
3471
3466
for {
3472
3467
f , err := ct .fr .ReadFrame ()
3473
3468
if err != nil {
3474
3469
select {
3475
- case <- clientDone :
3470
+ case <- ctx . Done () :
3476
3471
// If the client's done, it
3477
3472
// will have reported any
3478
3473
// errors on its side.
@@ -3494,46 +3489,24 @@ func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseI
3494
3489
EndStream : false ,
3495
3490
BlockFragment : buf .Bytes (),
3496
3491
})
3497
-
3498
- wg .Add (1 )
3499
- go func () {
3500
- defer wg .Done ()
3501
- for i := 0 ; i < 2 ; i ++ {
3502
- wmu .Lock ()
3503
- if err := ct .fr .WriteData (f .StreamID , false , []byte (fmt .Sprintf ("hello, this is server data frame %d" , i ))); err != nil {
3504
- wmu .Unlock ()
3505
- t .Error (err )
3506
- return
3507
- }
3508
- wmu .Unlock ()
3509
- time .Sleep (serverResponseInterval )
3510
- }
3511
- wmu .Lock ()
3512
- if err := ct .fr .WriteData (f .StreamID , true , []byte ("hello, this is last server data frame" )); err != nil {
3513
- wmu .Unlock ()
3514
- t .Error (err )
3515
- return
3516
- }
3517
- wmu .Unlock ()
3518
- }()
3492
+ streamID = f .StreamID
3519
3493
case * PingFrame :
3520
3494
pingCount ++
3521
- wmu .Lock ()
3495
+ if pingCount == expectedPingCount {
3496
+ if err := ct .fr .WriteData (streamID , true , []byte ("hello, this is last server data frame" )); err != nil {
3497
+ return err
3498
+ }
3499
+ }
3522
3500
if err := ct .fr .WritePing (true , f .Data ); err != nil {
3523
- wmu .Unlock ()
3524
3501
return err
3525
3502
}
3526
- wmu . Unlock ()
3503
+ case * RSTStreamFrame :
3527
3504
default :
3528
3505
return fmt .Errorf ("Unexpected client frame %v" , f )
3529
3506
}
3530
3507
}
3531
3508
}
3532
3509
ct .run ()
3533
- if e , a := expectedPingCount , pingCount ; e != a {
3534
- t .Errorf ("expected receiving %d pings, got %d pings" , e , a )
3535
-
3536
- }
3537
3510
}
3538
3511
3539
3512
func TestTransportRetryAfterGOAWAY (t * testing.T ) {
0 commit comments