@@ -76,7 +76,7 @@ type Transport struct {
76
76
idleLRU connLRU
77
77
78
78
reqMu sync.Mutex
79
- reqCanceler map [* Request ]func ()
79
+ reqCanceler map [* Request ]func (error )
80
80
81
81
altMu sync.RWMutex
82
82
altProto map [string ]RoundTripper // nil or map of URI scheme => RoundTripper
@@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() {
498
498
// cancelable context instead. CancelRequest cannot cancel HTTP/2
499
499
// requests.
500
500
func (t * Transport ) CancelRequest (req * Request ) {
501
+ t .cancelRequest (req , errRequestCanceled )
502
+ }
503
+
504
+ // Cancel an in-flight request, recording the error value.
505
+ func (t * Transport ) cancelRequest (req * Request , err error ) {
501
506
t .reqMu .Lock ()
502
507
cancel := t .reqCanceler [req ]
503
508
delete (t .reqCanceler , req )
504
509
t .reqMu .Unlock ()
505
510
if cancel != nil {
506
- cancel ()
511
+ cancel (err )
507
512
}
508
513
}
509
514
@@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
783
788
}
784
789
}
785
790
786
- func (t * Transport ) setReqCanceler (r * Request , fn func ()) {
791
+ func (t * Transport ) setReqCanceler (r * Request , fn func (error )) {
787
792
t .reqMu .Lock ()
788
793
defer t .reqMu .Unlock ()
789
794
if t .reqCanceler == nil {
790
- t .reqCanceler = make (map [* Request ]func ())
795
+ t .reqCanceler = make (map [* Request ]func (error ))
791
796
}
792
797
if fn != nil {
793
798
t .reqCanceler [r ] = fn
@@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
800
805
// for the request, we don't set the function and return false.
801
806
// Since CancelRequest will clear the canceler, we can use the return value to detect if
802
807
// the request was canceled since the last setReqCancel call.
803
- func (t * Transport ) replaceReqCanceler (r * Request , fn func ()) bool {
808
+ func (t * Transport ) replaceReqCanceler (r * Request , fn func (error )) bool {
804
809
t .reqMu .Lock ()
805
810
defer t .reqMu .Unlock ()
806
811
_ , ok := t .reqCanceler [r ]
@@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
849
854
// set request canceler to some non-nil function so we
850
855
// can detect whether it was cleared between now and when
851
856
// we enter roundTrip
852
- t .setReqCanceler (req , func () {})
857
+ t .setReqCanceler (req , func (error ) {})
853
858
return pc , nil
854
859
}
855
860
@@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
874
879
}()
875
880
}
876
881
877
- cancelc := make (chan struct {} )
878
- t .setReqCanceler (req , func () { close ( cancelc ) })
882
+ cancelc := make (chan error , 1 )
883
+ t .setReqCanceler (req , func (err error ) { cancelc <- err })
879
884
880
885
go func () {
881
886
pc , err := t .dialConn (ctx , cm )
@@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
897
902
select {
898
903
case <- req .Cancel :
899
904
case <- req .Context ().Done ():
900
- case <- cancelc :
905
+ return nil , req .Context ().Err ()
906
+ case err := <- cancelc :
907
+ if err == errRequestCanceled {
908
+ err = errRequestCanceledConn
909
+ }
910
+ return nil , err
901
911
default :
902
912
// It wasn't an error due to cancelation, so
903
913
// return the original error message:
@@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
922
932
return nil , errRequestCanceledConn
923
933
case <- req .Context ().Done ():
924
934
handlePendingDial ()
925
- return nil , errRequestCanceledConn
926
- case <- cancelc :
935
+ return nil , req . Context (). Err ()
936
+ case err := <- cancelc :
927
937
handlePendingDial ()
928
- return nil , errRequestCanceledConn
938
+ if err == errRequestCanceled {
939
+ err = errRequestCanceledConn
940
+ }
941
+ return nil , err
929
942
}
930
943
}
931
944
@@ -1231,8 +1244,8 @@ type persistConn struct {
1231
1244
mu sync.Mutex // guards following fields
1232
1245
numExpectedResponses int
1233
1246
closed error // set non-nil when conn is closed, before closech is closed
1247
+ canceledErr error // set non-nil if conn is canceled
1234
1248
broken bool // an error has happened on this connection; marked broken so it's not reused.
1235
- canceled bool // whether this conn was broken due a CancelRequest
1236
1249
reused bool // whether conn has had successful request/response and is being reused.
1237
1250
// mutateHeaderFunc is an optional func to modify extra
1238
1251
// headers on each outbound request before it's written. (the
@@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool {
1270
1283
return b
1271
1284
}
1272
1285
1273
- // isCanceled reports whether this connection was closed due to CancelRequest.
1274
- func (pc * persistConn ) isCanceled () bool {
1286
+ // canceled returns non-nil if the connection was closed due to
1287
+ // CancelRequest or due to context cancelation.
1288
+ func (pc * persistConn ) canceled () error {
1275
1289
pc .mu .Lock ()
1276
1290
defer pc .mu .Unlock ()
1277
- return pc .canceled
1291
+ return pc .canceledErr
1278
1292
}
1279
1293
1280
1294
// isReused reports whether this connection is in a known broken state.
@@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn
1297
1311
return
1298
1312
}
1299
1313
1300
- func (pc * persistConn ) cancelRequest () {
1314
+ func (pc * persistConn ) cancelRequest (err error ) {
1301
1315
pc .mu .Lock ()
1302
1316
defer pc .mu .Unlock ()
1303
- pc .canceled = true
1317
+ pc .canceledErr = err
1304
1318
pc .closeLocked (errRequestCanceled )
1305
1319
}
1306
1320
@@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
1328
1342
if err == nil {
1329
1343
return nil
1330
1344
}
1331
- if pc .isCanceled () {
1332
- return errRequestCanceled
1345
+ if err := pc .canceled (); err != nil {
1346
+ return err
1333
1347
}
1334
1348
if err == errServerClosedIdle {
1335
1349
return err
@@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
1351
1365
// its pc.closech channel close, indicating the persistConn is dead.
1352
1366
// (after closech is closed, pc.closed is valid).
1353
1367
func (pc * persistConn ) mapRoundTripErrorAfterClosed (startBytesWritten int64 ) error {
1354
- if pc .isCanceled () {
1355
- return errRequestCanceled
1368
+ if err := pc .canceled (); err != nil {
1369
+ return err
1356
1370
}
1357
1371
err := pc .closed
1358
1372
if err == errServerClosedIdle {
@@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() {
1509
1523
waitForBodyRead <- isEOF
1510
1524
if isEOF {
1511
1525
<- eofc // see comment above eofc declaration
1512
- } else if err != nil && pc .isCanceled () {
1513
- return errRequestCanceled
1526
+ } else if err != nil {
1527
+ if cerr := pc .canceled (); cerr != nil {
1528
+ return cerr
1529
+ }
1514
1530
}
1515
1531
return err
1516
1532
},
@@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() {
1550
1566
pc .t .CancelRequest (rc .req )
1551
1567
case <- rc .req .Context ().Done ():
1552
1568
alive = false
1553
- pc .t .CancelRequest (rc .req )
1569
+ pc .t .cancelRequest (rc .req , rc . req . Context (). Err () )
1554
1570
case <- pc .closech :
1555
1571
alive = false
1556
1572
}
@@ -1836,8 +1852,8 @@ WaitResponse:
1836
1852
select {
1837
1853
case err := <- writeErrCh :
1838
1854
if err != nil {
1839
- if pc .isCanceled () {
1840
- err = errRequestCanceled
1855
+ if cerr := pc .canceled (); cerr != nil {
1856
+ err = cerr
1841
1857
}
1842
1858
re = responseAndError {err : err }
1843
1859
pc .close (fmt .Errorf ("write error: %v" , err ))
@@ -1861,9 +1877,8 @@ WaitResponse:
1861
1877
case <- cancelChan :
1862
1878
pc .t .CancelRequest (req .Request )
1863
1879
cancelChan = nil
1864
- ctxDoneChan = nil
1865
1880
case <- ctxDoneChan :
1866
- pc .t .CancelRequest (req .Request )
1881
+ pc .t .cancelRequest (req .Request , req . Context (). Err () )
1867
1882
cancelChan = nil
1868
1883
ctxDoneChan = nil
1869
1884
}
0 commit comments