Skip to content

net/http: unfurl persistConnWriter's underlying writer #30390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,17 @@ func (w persistConnWriter) Write(p []byte) (n int, err error) {
return
}

// ReadFrom exposes persistConnWriter's underlying Conn to io.Copy and if
// the Conn implements io.ReaderFrom, it can take advantage of optimizations
// such as sendfile.
func (w persistConnWriter) ReadFrom(r io.Reader) (n int64, err error) {
n, err = io.Copy(w.pc.conn, r)
w.pc.nwrite += n
return
}

var _ io.ReaderFrom = (*persistConnWriter)(nil)

// connectMethod is the map key (in its String form) for keeping persistent
// TCP connections alive for subsequent HTTP requests.
//
Expand Down
143 changes: 143 additions & 0 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5059,3 +5059,146 @@ func TestTransportRequestReplayable(t *testing.T) {
})
}
}

// testMockTCPConn is a mock TCP connection used to test that
// ReadFrom is called when sending the request body.
type testMockTCPConn struct {
*net.TCPConn

ReadFromCalled bool
}

func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) {
c.ReadFromCalled = true
return c.TCPConn.ReadFrom(r)
}

func TestTransportRequestWriteRoundTrip(t *testing.T) {
nBytes := int64(1 << 10)
newFileFunc := func() (r io.Reader, done func(), err error) {
f, err := ioutil.TempFile("", "net-http-newfilefunc")
if err != nil {
return nil, nil, err
}

// Write some bytes to the file to enable reading.
if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
}
if _, err := f.Seek(0, 0); err != nil {
return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
}

done = func() {
f.Close()
os.Remove(f.Name())
}

return f, done, nil
}

newBufferFunc := func() (io.Reader, func(), error) {
return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
}

cases := []struct {
name string
readerFunc func() (io.Reader, func(), error)
contentLength int64
expectedReadFrom bool
}{
{
name: "file, length",
readerFunc: newFileFunc,
contentLength: nBytes,
expectedReadFrom: true,
},
{
name: "file, no length",
readerFunc: newFileFunc,
},
{
name: "file, negative length",
readerFunc: newFileFunc,
contentLength: -1,
},
{
name: "buffer",
contentLength: nBytes,
readerFunc: newBufferFunc,
},
{
name: "buffer, no length",
readerFunc: newBufferFunc,
},
{
name: "buffer, length -1",
contentLength: -1,
readerFunc: newBufferFunc,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
r, cleanup, err := tc.readerFunc()
if err != nil {
t.Fatal(err)
}
defer cleanup()

tConn := &testMockTCPConn{}
trFunc := func(tr *Transport) {
tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
var d net.Dialer
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}

tcpConn, ok := conn.(*net.TCPConn)
if !ok {
return nil, fmt.Errorf("%s/%s does not provide a *net.TCPConn", network, addr)
}

tConn.TCPConn = tcpConn
return tConn, nil
}
}

cst := newClientServerTest(
t,
h1Mode,
HandlerFunc(func(w ResponseWriter, r *Request) {
io.Copy(ioutil.Discard, r.Body)
r.Body.Close()
w.WriteHeader(200)
}),
trFunc,
)
defer cst.close()

req, err := NewRequest("PUT", cst.ts.URL, r)
if err != nil {
t.Fatal(err)
}
req.ContentLength = tc.contentLength
req.Header.Set("Content-Type", "application/octet-stream")
resp, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("status code = %d; want 200", resp.StatusCode)
}

if !tConn.ReadFromCalled && tc.expectedReadFrom {
t.Fatalf("did not call ReadFrom")
}

if tConn.ReadFromCalled && !tc.expectedReadFrom {
t.Fatalf("ReadFrom was unexpectedly invoked")
}
})
}
}