@@ -588,6 +588,107 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
588
588
<- reqComplete
589
589
}
590
590
591
+ func TestTransportMaxConnsPerHost (t * testing.T ) {
592
+ defer afterTest (t )
593
+ if runtime .GOOS == "js" {
594
+ t .Skipf ("skipping test on js/wasm" )
595
+ }
596
+ h := HandlerFunc (func (w ResponseWriter , r * Request ) {
597
+ _ , err := w .Write ([]byte ("foo" ))
598
+ if err != nil {
599
+ t .Fatalf ("Write: %v" , err )
600
+ }
601
+ })
602
+
603
+ testMaxConns := func (scheme string , ts * httptest.Server ) {
604
+ defer ts .Close ()
605
+
606
+ c := ts .Client ()
607
+ tr := c .Transport .(* Transport )
608
+ tr .MaxConnsPerHost = 1
609
+ if err := ExportHttp2ConfigureTransport (tr ); err != nil {
610
+ t .Fatalf ("ExportHttp2ConfigureTransport: %v" , err )
611
+ }
612
+
613
+ connCh := make (chan net.Conn , 1 )
614
+ var dialCnt , gotConnCnt , tlsHandshakeCnt int32
615
+ tr .Dial = func (network , addr string ) (net.Conn , error ) {
616
+ atomic .AddInt32 (& dialCnt , 1 )
617
+ c , err := net .Dial (network , addr )
618
+ connCh <- c
619
+ return c , err
620
+ }
621
+
622
+ doReq := func () {
623
+ trace := & httptrace.ClientTrace {
624
+ GotConn : func (connInfo httptrace.GotConnInfo ) {
625
+ if ! connInfo .Reused {
626
+ atomic .AddInt32 (& gotConnCnt , 1 )
627
+ }
628
+ },
629
+ TLSHandshakeStart : func () {
630
+ atomic .AddInt32 (& tlsHandshakeCnt , 1 )
631
+ },
632
+ }
633
+ req , _ := NewRequest ("GET" , ts .URL , nil )
634
+ req = req .WithContext (httptrace .WithClientTrace (req .Context (), trace ))
635
+
636
+ resp , err := c .Do (req )
637
+ if err != nil {
638
+ t .Fatalf ("request failed: %v" , err )
639
+ }
640
+ defer resp .Body .Close ()
641
+ _ , err = ioutil .ReadAll (resp .Body )
642
+ if err != nil {
643
+ t .Fatalf ("read body failed: %v" , err )
644
+ }
645
+ }
646
+
647
+ wg := sync.WaitGroup {}
648
+ for i := 0 ; i < 10 ; i ++ {
649
+ wg .Add (1 )
650
+ go func () {
651
+ defer wg .Done ()
652
+ doReq ()
653
+ }()
654
+ }
655
+ wg .Wait ()
656
+
657
+ expected := int32 (tr .MaxConnsPerHost )
658
+ if dialCnt != expected {
659
+ t .Errorf ("Too many dials (%s): %d" , scheme , dialCnt )
660
+ }
661
+ if gotConnCnt != expected {
662
+ t .Errorf ("Too many get connections (%s): %d" , scheme , gotConnCnt )
663
+ }
664
+ if ts .TLS != nil && tlsHandshakeCnt != expected {
665
+ t .Errorf ("Too many tls handshakes (%s): %d" , scheme , tlsHandshakeCnt )
666
+ }
667
+
668
+ (<- connCh ).Close ()
669
+
670
+ doReq ()
671
+ expected ++
672
+ if dialCnt != expected {
673
+ t .Errorf ("Too many dials (%s): %d" , scheme , dialCnt )
674
+ }
675
+ if gotConnCnt != expected {
676
+ t .Errorf ("Too many get connections (%s): %d" , scheme , gotConnCnt )
677
+ }
678
+ if ts .TLS != nil && tlsHandshakeCnt != expected {
679
+ t .Errorf ("Too many tls handshakes (%s): %d" , scheme , tlsHandshakeCnt )
680
+ }
681
+ }
682
+
683
+ testMaxConns ("http" , httptest .NewServer (h ))
684
+ testMaxConns ("https" , httptest .NewTLSServer (h ))
685
+
686
+ ts := httptest .NewUnstartedServer (h )
687
+ ts .TLS = & tls.Config {NextProtos : []string {"h2" }}
688
+ ts .StartTLS ()
689
+ testMaxConns ("http2" , ts )
690
+ }
691
+
591
692
func TestTransportRemovesDeadIdleConnections (t * testing.T ) {
592
693
setParallel (t )
593
694
defer afterTest (t )
0 commit comments