diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index 1fe4eed3f71234..9a80bb9502540f 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -1770,3 +1770,28 @@ func testEarlyHintsRequest(t *testing.T, mode testMode) { t.Errorf("Read body %q; want Hello", body) } } + +// Issue 53808 +func TestServerReadAfterHandlerDone100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerDone100Continue) +} +func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) { + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + res.Body.Close() + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} diff --git a/src/net/http/server.go b/src/net/http/server.go index a50b20b7da905b..a43959a6be937d 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -425,19 +425,20 @@ type response struct { reqBody io.ReadCloser cancelCtx context.CancelFunc // when ServeHTTP exits wroteHeader bool // a non-1xx header has been (logically) written - wroteContinue bool // 100 Continue response was written wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" // canWriteContinue is an atomic boolean that says whether or // not a 100 Continue header can be written to the // connection. - // writeContinueMu must be held while writing the header. + // write1xxMu must be held while writing a 1xx response + // or setting canWriteContinue. // These two fields together synchronize the body reader (the // expectContinueReader, which wants to write 100 Continue) // against the main writer. canWriteContinue atomic.Bool - writeContinueMu sync.Mutex + write1xxMu sync.Mutex + wrote1xx bool // 1xx response was written w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -916,17 +917,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { if ecr.closed.Load() { return 0, ErrBodyReadAfterClose } - w := ecr.resp - if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() { - w.wroteContinue = true - w.writeContinueMu.Lock() - if w.canWriteContinue.Load() { - w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") - w.conn.bufw.Flush() - w.canWriteContinue.Store(false) - } - w.writeContinueMu.Unlock() - } + ecr.resp.writeContinueOnce() n, err = ecr.readCloser.Read(p) if err == io.EOF { ecr.sawEOF.Store(true) @@ -1164,11 +1155,11 @@ func (w *response) WriteHeader(code int) { // We shouldn't send any further headers after 101 Switching Protocols, // so it takes the non-informational path. if code >= 100 && code <= 199 && code != StatusSwitchingProtocols { + w.write1xxMu.Lock() + // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() if code == 100 && w.canWriteContinue.Load() { - w.writeContinueMu.Lock() w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() } writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) @@ -1178,6 +1169,8 @@ func (w *response) WriteHeader(code int) { w.conn.bufw.Write(crlf) w.conn.bufw.Flush() + w.wrote1xx = true + w.write1xxMu.Unlock() return } @@ -1381,11 +1374,13 @@ func (cw *chunkWriter) writeHeader(p []byte) { if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex { var discard, tooBig bool + w.write1xxMu.Lock() + if w.wrote1xx { + discard = true + } + w.write1xxMu.Unlock() + switch bdy := w.req.Body.(type) { - case *expectContinueReader: - if bdy.resp.wroteContinue { - discard = true - } case *body: bdy.mu.Lock() switch { @@ -1625,15 +1620,7 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er return 0, ErrHijacked } - if w.canWriteContinue.Load() { - // Body reader wants to write 100 Continue but hasn't yet. - // Tell it not to. The store must be done while holding the lock - // because the lock makes sure that there is not an active write - // this very moment. - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() - } + w.disableContinue() if !w.wroteHeader { w.WriteHeader(StatusOK) @@ -1679,6 +1666,31 @@ func (w *response) finishRequest() { } } +// disableContinue disables writes of the 100 Continue status. +// If an existing write is in progress, it waits for it to complete +// before returning. +func (w *response) disableContinue() { + if w.canWriteContinue.Load() { + w.write1xxMu.Lock() + w.canWriteContinue.Store(false) + w.write1xxMu.Unlock() + } +} + +// writeContinueOnce writes 100 Continue status if allowed. +func (w *response) writeContinueOnce() { + if w.canWriteContinue.Load() && !w.conn.hijacked() { + w.write1xxMu.Lock() + if w.canWriteContinue.Load() { + w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + w.conn.bufw.Flush() + w.canWriteContinue.Store(false) + w.wrote1xx = true + } + w.write1xxMu.Unlock() + } +} + // shouldReuseConnection reports whether the underlying TCP connection can be reused. // It must only be called after the handler is done executing. func (w *response) shouldReuseConnection() bool { @@ -1905,6 +1917,7 @@ func (c *conn) serve(ctx context.Context) { if inFlightResponse != nil { inFlightResponse.conn.r.abortPendingRead() inFlightResponse.reqBody.Close() + inFlightResponse.disableContinue() } c.close() c.setState(c.rwc, StateClosed, runHooks) @@ -2046,6 +2059,7 @@ func (c *conn) serve(ctx context.Context) { return } w.finishRequest() + w.disableContinue() c.rwc.SetWriteDeadline(time.Time{}) if !w.shouldReuseConnection() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index fa147e164ed843..1526a66890f902 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -6918,19 +6918,31 @@ func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) { panic(ErrAbortHandler) })).ts + newRequest := func() *Request { + const reqLen = 6 * 1024 * 1024 + req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) + req.ContentLength = reqLen + return req + } + var wg sync.WaitGroup for i := 0; i < 2; i++ { wg.Add(1) go func() { defer wg.Done() for j := 0; j < 10; j++ { - const reqLen = 6 * 1024 * 1024 - req, _ := NewRequest("POST", ts.URL, &io.LimitedReader{R: neverEnding('x'), N: reqLen}) - req.ContentLength = reqLen + req := newRequest() resp, _ := ts.Client().Transport.RoundTrip(req) if resp != nil { resp.Body.Close() } + + req = newRequest() + req.Header.Set("Expect", "100-continue") + resp, _ = ts.Client().Transport.RoundTrip(req) + if resp != nil { + resp.Body.Close() + } } }() }