Skip to content

Commit 0b5f2f0

Browse files
net/http: if context is canceled, return its error
This permits the error message to distinguish between a context that was canceled and a context that timed out. Updates #16381. Change-Id: I3994b98e32952abcd7ddb5fee08fa1535999be6d Reviewed-on: https://go-review.googlesource.com/24978 Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 643b9ec commit 0b5f2f0

File tree

3 files changed

+57
-33
lines changed

3 files changed

+57
-33
lines changed

src/net/http/client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,8 @@ func TestClientRedirectContext(t *testing.T) {
313313
if !ok {
314314
t.Fatalf("got error %T; want *url.Error", err)
315315
}
316-
if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn {
317-
t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err)
316+
if ue.Err != context.Canceled {
317+
t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
318318
}
319319
}
320320

src/net/http/transport.go

Lines changed: 44 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ type Transport struct {
7676
idleLRU connLRU
7777

7878
reqMu sync.Mutex
79-
reqCanceler map[*Request]func()
79+
reqCanceler map[*Request]func(error)
8080

8181
altMu sync.RWMutex
8282
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
@@ -498,12 +498,17 @@ func (t *Transport) CloseIdleConnections() {
498498
// cancelable context instead. CancelRequest cannot cancel HTTP/2
499499
// requests.
500500
func (t *Transport) CancelRequest(req *Request) {
501+
t.cancelRequest(req, errRequestCanceled)
502+
}
503+
504+
// Cancel an in-flight request, recording the error value.
505+
func (t *Transport) cancelRequest(req *Request, err error) {
501506
t.reqMu.Lock()
502507
cancel := t.reqCanceler[req]
503508
delete(t.reqCanceler, req)
504509
t.reqMu.Unlock()
505510
if cancel != nil {
506-
cancel()
511+
cancel(err)
507512
}
508513
}
509514

@@ -783,11 +788,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) {
783788
}
784789
}
785790

786-
func (t *Transport) setReqCanceler(r *Request, fn func()) {
791+
func (t *Transport) setReqCanceler(r *Request, fn func(error)) {
787792
t.reqMu.Lock()
788793
defer t.reqMu.Unlock()
789794
if t.reqCanceler == nil {
790-
t.reqCanceler = make(map[*Request]func())
795+
t.reqCanceler = make(map[*Request]func(error))
791796
}
792797
if fn != nil {
793798
t.reqCanceler[r] = fn
@@ -800,7 +805,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
800805
// for the request, we don't set the function and return false.
801806
// Since CancelRequest will clear the canceler, we can use the return value to detect if
802807
// the request was canceled since the last setReqCancel call.
803-
func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
808+
func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool {
804809
t.reqMu.Lock()
805810
defer t.reqMu.Unlock()
806811
_, ok := t.reqCanceler[r]
@@ -849,7 +854,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
849854
// set request canceler to some non-nil function so we
850855
// can detect whether it was cleared between now and when
851856
// we enter roundTrip
852-
t.setReqCanceler(req, func() {})
857+
t.setReqCanceler(req, func(error) {})
853858
return pc, nil
854859
}
855860

@@ -874,8 +879,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
874879
}()
875880
}
876881

877-
cancelc := make(chan struct{})
878-
t.setReqCanceler(req, func() { close(cancelc) })
882+
cancelc := make(chan error, 1)
883+
t.setReqCanceler(req, func(err error) { cancelc <- err })
879884

880885
go func() {
881886
pc, err := t.dialConn(ctx, cm)
@@ -897,7 +902,12 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
897902
select {
898903
case <-req.Cancel:
899904
case <-req.Context().Done():
900-
case <-cancelc:
905+
return nil, req.Context().Err()
906+
case err := <-cancelc:
907+
if err == errRequestCanceled {
908+
err = errRequestCanceledConn
909+
}
910+
return nil, err
901911
default:
902912
// It wasn't an error due to cancelation, so
903913
// return the original error message:
@@ -922,10 +932,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
922932
return nil, errRequestCanceledConn
923933
case <-req.Context().Done():
924934
handlePendingDial()
925-
return nil, errRequestCanceledConn
926-
case <-cancelc:
935+
return nil, req.Context().Err()
936+
case err := <-cancelc:
927937
handlePendingDial()
928-
return nil, errRequestCanceledConn
938+
if err == errRequestCanceled {
939+
err = errRequestCanceledConn
940+
}
941+
return nil, err
929942
}
930943
}
931944

@@ -1231,8 +1244,8 @@ type persistConn struct {
12311244
mu sync.Mutex // guards following fields
12321245
numExpectedResponses int
12331246
closed error // set non-nil when conn is closed, before closech is closed
1247+
canceledErr error // set non-nil if conn is canceled
12341248
broken bool // an error has happened on this connection; marked broken so it's not reused.
1235-
canceled bool // whether this conn was broken due a CancelRequest
12361249
reused bool // whether conn has had successful request/response and is being reused.
12371250
// mutateHeaderFunc is an optional func to modify extra
12381251
// headers on each outbound request before it's written. (the
@@ -1270,11 +1283,12 @@ func (pc *persistConn) isBroken() bool {
12701283
return b
12711284
}
12721285

1273-
// isCanceled reports whether this connection was closed due to CancelRequest.
1274-
func (pc *persistConn) isCanceled() bool {
1286+
// canceled returns non-nil if the connection was closed due to
1287+
// CancelRequest or due to context cancelation.
1288+
func (pc *persistConn) canceled() error {
12751289
pc.mu.Lock()
12761290
defer pc.mu.Unlock()
1277-
return pc.canceled
1291+
return pc.canceledErr
12781292
}
12791293

12801294
// isReused reports whether this connection is in a known broken state.
@@ -1297,10 +1311,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn
12971311
return
12981312
}
12991313

1300-
func (pc *persistConn) cancelRequest() {
1314+
func (pc *persistConn) cancelRequest(err error) {
13011315
pc.mu.Lock()
13021316
defer pc.mu.Unlock()
1303-
pc.canceled = true
1317+
pc.canceledErr = err
13041318
pc.closeLocked(errRequestCanceled)
13051319
}
13061320

@@ -1328,8 +1342,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
13281342
if err == nil {
13291343
return nil
13301344
}
1331-
if pc.isCanceled() {
1332-
return errRequestCanceled
1345+
if err := pc.canceled(); err != nil {
1346+
return err
13331347
}
13341348
if err == errServerClosedIdle {
13351349
return err
@@ -1351,8 +1365,8 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er
13511365
// its pc.closech channel close, indicating the persistConn is dead.
13521366
// (after closech is closed, pc.closed is valid).
13531367
func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error {
1354-
if pc.isCanceled() {
1355-
return errRequestCanceled
1368+
if err := pc.canceled(); err != nil {
1369+
return err
13561370
}
13571371
err := pc.closed
13581372
if err == errServerClosedIdle {
@@ -1509,8 +1523,10 @@ func (pc *persistConn) readLoop() {
15091523
waitForBodyRead <- isEOF
15101524
if isEOF {
15111525
<-eofc // see comment above eofc declaration
1512-
} else if err != nil && pc.isCanceled() {
1513-
return errRequestCanceled
1526+
} else if err != nil {
1527+
if cerr := pc.canceled(); cerr != nil {
1528+
return cerr
1529+
}
15141530
}
15151531
return err
15161532
},
@@ -1550,7 +1566,7 @@ func (pc *persistConn) readLoop() {
15501566
pc.t.CancelRequest(rc.req)
15511567
case <-rc.req.Context().Done():
15521568
alive = false
1553-
pc.t.CancelRequest(rc.req)
1569+
pc.t.cancelRequest(rc.req, rc.req.Context().Err())
15541570
case <-pc.closech:
15551571
alive = false
15561572
}
@@ -1836,8 +1852,8 @@ WaitResponse:
18361852
select {
18371853
case err := <-writeErrCh:
18381854
if err != nil {
1839-
if pc.isCanceled() {
1840-
err = errRequestCanceled
1855+
if cerr := pc.canceled(); cerr != nil {
1856+
err = cerr
18411857
}
18421858
re = responseAndError{err: err}
18431859
pc.close(fmt.Errorf("write error: %v", err))
@@ -1861,9 +1877,8 @@ WaitResponse:
18611877
case <-cancelChan:
18621878
pc.t.CancelRequest(req.Request)
18631879
cancelChan = nil
1864-
ctxDoneChan = nil
18651880
case <-ctxDoneChan:
1866-
pc.t.CancelRequest(req.Request)
1881+
pc.t.cancelRequest(req.Request, req.Context().Err())
18671882
cancelChan = nil
18681883
ctxDoneChan = nil
18691884
}

src/net/http/transport_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,8 +1718,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) {
17181718
}
17191719

17201720
_, err := c.Do(req)
1721-
if err == nil || !strings.Contains(err.Error(), "canceled") {
1722-
t.Errorf("Do error = %v; want cancelation", err)
1721+
if ue, ok := err.(*url.Error); ok {
1722+
err = ue.Err
1723+
}
1724+
if withCtx {
1725+
if err != context.Canceled {
1726+
t.Errorf("Do error = %v; want %v", err, context.Canceled)
1727+
}
1728+
} else {
1729+
if err == nil || !strings.Contains(err.Error(), "canceled") {
1730+
t.Errorf("Do error = %v; want cancelation", err)
1731+
}
17231732
}
17241733
}
17251734

0 commit comments

Comments
 (0)