diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 143874d70ad42b..3580bc9d4bbc2b 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -973,6 +973,47 @@ func TestOnlyWriteTimeout(t *testing.T) { } } +func TestErrorAfterWriteTimeout(t *testing.T) { + setParallel(t) + defer afterTest(t) + writeTimeout := 200 * time.Millisecond + var afterTimeoutErrc = make(chan error, 1) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + time.Sleep(2 * writeTimeout) + + _, err := w.Write([]byte("test")) + afterTimeoutErrc <- err + })) + ts.Config.WriteTimeout = writeTimeout + ts.Start() + defer ts.Close() + + c := ts.Client() + + errc := make(chan error, 1) + go func() { + res, err := c.Get(ts.URL) + if err != nil { + errc <- err + return + } + _, err = io.Copy(io.Discard, res.Body) + res.Body.Close() + errc <- err + }() + select { + case err := <-errc: + if err == nil { + t.Errorf("expected an error from Get request") + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for Get error") + } + if err := <-afterTimeoutErrc; err == nil { + t.Error("expected write error after timeout") + } +} + // trackLastConnListener tracks the last net.Conn that was accepted. type trackLastConnListener struct { net.Listener diff --git a/src/net/http/server.go b/src/net/http/server.go index 3d427e5ae4bab6..755bb6e3381f23 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -395,11 +395,11 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { return } -func (cw *chunkWriter) flush() { +func (cw *chunkWriter) flush() error { if !cw.wroteHeader { cw.writeHeader(nil) } - cw.res.conn.bufw.Flush() + return cw.res.conn.bufw.Flush() } func (cw *chunkWriter) close() { @@ -443,6 +443,14 @@ type response struct { w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter + // writeTimeoutTimer is set when the server has a WriteTimeout configured + // and triggers when a write timed out + // writeDeadline is used to enable direct flushing of writes after the + // timeout so writers receive an error and can handle it + writeTimeoutTimer *time.Timer + writeDeadline bool + writeDeadlineMu sync.Mutex + // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. // handlerHeader is copied into cw.header at WriteHeader @@ -1045,6 +1053,9 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { if isH2Upgrade { w.closeAfterReply = true } + if d := c.server.WriteTimeout; d > 0 { + w.setWriteTimeout(d) + } w.cw.res = w w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return w, nil @@ -1590,6 +1601,16 @@ func (w *response) WriteString(data string) (n int, err error) { return w.write(len(data), nil, data) } +// setWriteTimeout lets the response know if the write was supposed to be +// timed out, timed out requests will force be flushed on every write +func (w *response) setWriteTimeout(d time.Duration) { + w.writeTimeoutTimer = time.AfterFunc(d, func() { + w.writeDeadlineMu.Lock() + w.writeDeadline = true + w.writeDeadlineMu.Unlock() + }) +} + // either dataB or dataS is non-zero. func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { @@ -1625,10 +1646,22 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er return 0, ErrContentLength } if dataB != nil { - return w.w.Write(dataB) + n, err = w.w.Write(dataB) } else { - return w.w.WriteString(dataS) + n, err = w.w.WriteString(dataS) } + if err == nil { + w.writeDeadlineMu.Lock() + wd := w.writeDeadline + w.writeDeadlineMu.Unlock() + + if wd { + // r.Flush returns no errors, flush manually + w.w.Flush() + err = w.cw.flush() + } + } + return } func (w *response) finishRequest() { @@ -1643,6 +1676,9 @@ func (w *response) finishRequest() { w.cw.close() w.conn.bufw.Flush() + if w.writeTimeoutTimer != nil { + w.writeTimeoutTimer.Stop() + } w.conn.r.abortPendingRead() // Close the body (regardless of w.closeAfterReply) so we can