diff --git a/http2/server.go b/http2/server.go index 9862dad03..e9614d975 100644 --- a/http2/server.go +++ b/http2/server.go @@ -51,10 +51,11 @@ import ( ) const ( - prefaceTimeout = 10 * time.Second - firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway - handlerChunkWriteSize = 4 << 10 - defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? + prefaceTimeout = 10 * time.Second + firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + handlerChunkWriteSize = 4 << 10 + defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? + maxQueuedControlFrames = 10000 ) var ( @@ -162,6 +163,15 @@ func (s *Server) maxConcurrentStreams() uint32 { return defaultMaxStreams } +// maxQueuedControlFrames is the maximum number of control frames like +// SETTINGS, PING and RST_STREAM that will be queued for writing before +// the connection is closed to prevent memory exhaustion attacks. +func (s *Server) maxQueuedControlFrames() int { + // TODO: if anybody asks, add a Server field, and remember to define the + // behavior of negative values. + return maxQueuedControlFrames +} + type serverInternalState struct { mu sync.Mutex activeConns map[*serverConn]struct{} @@ -470,6 +480,7 @@ type serverConn struct { sawFirstSettings bool // got the initial SETTINGS frame after the preface needToSendSettingsAck bool unackedSettings int // how many SETTINGS have we sent without ACKs? + queuedControlFrames int // control frames in the writeSched queue clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client curClientStreams uint32 // number of open streams initiated by the client @@ -857,6 +868,14 @@ func (sc *serverConn) serve() { } } + // If the peer is causing us to generate a lot of control frames, + // but not reading them from us, assume they are trying to make us + // run out of memory. + if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() { + sc.vlogf("http2: too many control frames in send queue, closing connection") + return + } + // Start the shutdown timer after sending a GOAWAY. When sending GOAWAY // with no error code (graceful shutdown), don't start the timer until // all open streams have been completed. @@ -1056,6 +1075,14 @@ func (sc *serverConn) writeFrame(wr FrameWriteRequest) { } if !ignoreWrite { + if wr.isControl() { + sc.queuedControlFrames++ + // For extra safety, detect wraparounds, which should not happen, + // and pull the plug. + if sc.queuedControlFrames < 0 { + sc.conn.Close() + } + } sc.writeSched.Push(wr) } sc.scheduleFrameWrite() @@ -1173,10 +1200,8 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) { // If a frame is already being written, nothing happens. This will be called again // when the frame is done being written. // -// If a frame isn't being written we need to send one, the best frame -// to send is selected, preferring first things that aren't -// stream-specific (e.g. ACKing settings), and then finding the -// highest priority stream. +// If a frame isn't being written and we need to send one, the best frame +// to send is selected by writeSched. // // If a frame isn't being written and there's nothing else to send, we // flush the write buffer. @@ -1204,6 +1229,9 @@ func (sc *serverConn) scheduleFrameWrite() { } if !sc.inGoAway || sc.goAwayCode == ErrCodeNo { if wr, ok := sc.writeSched.Pop(); ok { + if wr.isControl() { + sc.queuedControlFrames-- + } sc.startFrameWrite(wr) continue } @@ -1496,6 +1524,8 @@ func (sc *serverConn) processSettings(f *SettingsFrame) error { if err := f.ForeachSetting(sc.processSetting); err != nil { return err } + // TODO: judging by RFC 7540, Section 6.5.3 each SETTINGS frame should be + // acknowledged individually, even if multiple are received before the ACK. sc.needToSendSettingsAck = true sc.scheduleFrameWrite() return nil diff --git a/http2/server_test.go b/http2/server_test.go index bb19c9668..10cde575f 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -1159,6 +1159,32 @@ func TestServer_Ping(t *testing.T) { } } +func TestServer_MaxQueuedControlFrames(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + st := newServerTester(t, nil) + defer st.Close() + st.greet() + + const extraPings = 500000 // enough to fill the TCP buffers + + for i := 0; i < maxQueuedControlFrames+extraPings; i++ { + pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + if err := st.fr.WritePing(false, pingData); err != nil { + if i == 0 { + t.Fatal(err) + } + // We expect the connection to get closed by the server when the TCP + // buffer fills up and the write queue reaches MaxQueuedControlFrames. + t.Logf("sent %d PING frames", i) + return + } + } + t.Errorf("unexpected success sending all PING frames") +} + func TestServer_RejectsLargeFrames(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("see golang.org/issue/13434") diff --git a/http2/transport.go b/http2/transport.go index 9d1f2fadd..ef356d6d9 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1060,6 +1060,7 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf default: } if err != nil { + cc.forgetStreamID(cs.ID) return nil, cs.getStartedWrite(), err } bodyWritten = true @@ -1181,6 +1182,7 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( sawEOF = true err = nil } else if err != nil { + cc.writeStreamReset(cs.ID, ErrCodeCancel, err) return err } diff --git a/http2/transport_test.go b/http2/transport_test.go index 5b5c0768f..2c0f53e5c 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -4183,3 +4183,99 @@ func TestNoDialH2RoundTripperType(t *testing.T) { t.Fatalf("wrong kind %T; want *Transport", v.Interface()) } } + +type errReader struct { + body []byte + err error +} + +func (r *errReader) Read(p []byte) (int, error) { + if len(r.body) > 0 { + n := copy(p, r.body) + r.body = r.body[n:] + return n, nil + } + return 0, r.err +} + +func testTransportBodyReadError(t *testing.T, body []byte) { + clientDone := make(chan struct{}) + ct := newClientTester(t) + ct.client = func() error { + defer ct.cc.(*net.TCPConn).CloseWrite() + defer close(clientDone) + + checkNoStreams := func() error { + cp, ok := ct.tr.connPool().(*clientConnPool) + if !ok { + return fmt.Errorf("conn pool is %T; want *clientConnPool", ct.tr.connPool()) + } + cp.mu.Lock() + defer cp.mu.Unlock() + conns, ok := cp.conns["dummy.tld:443"] + if !ok { + return fmt.Errorf("missing connection") + } + if len(conns) != 1 { + return fmt.Errorf("conn pool size: %v; expect 1", len(conns)) + } + if activeStreams(conns[0]) != 0 { + return fmt.Errorf("active streams count: %v; want 0", activeStreams(conns[0])) + } + return nil + } + bodyReadError := errors.New("body read error") + body := &errReader{body, bodyReadError} + req, err := http.NewRequest("PUT", "https://dummy.tld/", body) + if err != nil { + return err + } + _, err = ct.tr.RoundTrip(req) + if err != bodyReadError { + return fmt.Errorf("err = %v; want %v", err, bodyReadError) + } + if err = checkNoStreams(); err != nil { + return err + } + return nil + } + ct.server = func() error { + ct.greet() + var receivedBody []byte + var resetCount int + for { + f, err := ct.fr.ReadFrame() + if err != nil { + select { + case <-clientDone: + // If the client's done, it + // will have reported any + // errors on its side. + if bytes.Compare(receivedBody, body) != 0 { + return fmt.Errorf("body: %v; expected %v", receivedBody, body) + } + if resetCount != 1 { + return fmt.Errorf("stream reset count: %v; expected: 1", resetCount) + } + return nil + default: + return err + } + } + switch f := f.(type) { + case *WindowUpdateFrame, *SettingsFrame: + case *HeadersFrame: + case *DataFrame: + receivedBody = append(receivedBody, f.Data()...) + case *RSTStreamFrame: + resetCount++ + default: + return fmt.Errorf("Unexpected client frame %v", f) + } + } + } + ct.run() +} + +func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) } +func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) } diff --git a/http2/writesched.go b/http2/writesched.go index 4fe307307..f24d2b1e7 100644 --- a/http2/writesched.go +++ b/http2/writesched.go @@ -32,7 +32,7 @@ type WriteScheduler interface { // Pop dequeues the next frame to write. Returns false if no frames can // be written. Frames with a given wr.StreamID() are Pop'd in the same - // order they are Push'd. + // order they are Push'd. No frames should be discarded except by CloseStream. Pop() (wr FrameWriteRequest, ok bool) } @@ -76,6 +76,12 @@ func (wr FrameWriteRequest) StreamID() uint32 { return wr.stream.id } +// isControl reports whether wr is a control frame for MaxQueuedControlFrames +// purposes. That includes non-stream frames and RST_STREAM frames. +func (wr FrameWriteRequest) isControl() bool { + return wr.stream == nil +} + // DataSize returns the number of flow control bytes that must be consumed // to write this entire frame. This is 0 for non-DATA frames. func (wr FrameWriteRequest) DataSize() int { diff --git a/route/message_freebsd_test.go b/route/message_freebsd_test.go index db4b56752..c6d8a5f54 100644 --- a/route/message_freebsd_test.go +++ b/route/message_freebsd_test.go @@ -4,10 +4,7 @@ package route -import ( - "testing" - "unsafe" -) +import "testing" func TestFetchAndParseRIBOnFreeBSD(t *testing.T) { for _, typ := range []RIBType{sysNET_RT_IFMALIST} { @@ -40,8 +37,7 @@ func TestFetchAndParseRIBOnFreeBSD10AndAbove(t *testing.T) { if _, err := FetchRIB(sysAF_UNSPEC, sysNET_RT_IFLISTL, 0); err != nil { t.Skip("NET_RT_IFLISTL not supported") } - var p uintptr - if kernelAlign != int(unsafe.Sizeof(p)) { + if compatFreeBSD32 { t.Skip("NET_RT_IFLIST vs. NET_RT_IFLISTL doesn't work for 386 emulation on amd64") } diff --git a/route/sys_freebsd.go b/route/sys_freebsd.go index 89ba1c4e2..fe91be124 100644 --- a/route/sys_freebsd.go +++ b/route/sys_freebsd.go @@ -54,10 +54,12 @@ func (m *InterfaceMessage) Sys() []Sys { } } +var compatFreeBSD32 bool // 386 emulation on amd64 + func probeRoutingStack() (int, map[int]*wireFormat) { var p uintptr wordSize := int(unsafe.Sizeof(p)) - align := int(unsafe.Sizeof(p)) + align := wordSize // In the case of kern.supported_archs="amd64 i386", we need // to know the underlying kernel's architecture because the // alignment for routing facilities are set at the build time @@ -83,8 +85,11 @@ func probeRoutingStack() (int, map[int]*wireFormat) { break } } + if align != wordSize { + compatFreeBSD32 = true // 386 emulation on amd64 + } var rtm, ifm, ifam, ifmam, ifanm *wireFormat - if align != wordSize { // 386 emulation on amd64 + if compatFreeBSD32 { rtm = &wireFormat{extOff: sizeofRtMsghdrFreeBSD10Emu - sizeofRtMetricsFreeBSD10Emu, bodyOff: sizeofRtMsghdrFreeBSD10Emu} ifm = &wireFormat{extOff: 16} ifam = &wireFormat{extOff: sizeofIfaMsghdrFreeBSD10Emu, bodyOff: sizeofIfaMsghdrFreeBSD10Emu} @@ -100,35 +105,38 @@ func probeRoutingStack() (int, map[int]*wireFormat) { rel, _ := syscall.SysctlUint32("kern.osreldate") switch { case rel < 800000: - if align != wordSize { // 386 emulation on amd64 + if compatFreeBSD32 { ifm.bodyOff = sizeofIfMsghdrFreeBSD7Emu } else { ifm.bodyOff = sizeofIfMsghdrFreeBSD7 } case 800000 <= rel && rel < 900000: - if align != wordSize { // 386 emulation on amd64 + if compatFreeBSD32 { ifm.bodyOff = sizeofIfMsghdrFreeBSD8Emu } else { ifm.bodyOff = sizeofIfMsghdrFreeBSD8 } case 900000 <= rel && rel < 1000000: - if align != wordSize { // 386 emulation on amd64 + if compatFreeBSD32 { ifm.bodyOff = sizeofIfMsghdrFreeBSD9Emu } else { ifm.bodyOff = sizeofIfMsghdrFreeBSD9 } case 1000000 <= rel && rel < 1100000: - if align != wordSize { // 386 emulation on amd64 + if compatFreeBSD32 { ifm.bodyOff = sizeofIfMsghdrFreeBSD10Emu } else { ifm.bodyOff = sizeofIfMsghdrFreeBSD10 } default: - if align != wordSize { // 386 emulation on amd64 + if compatFreeBSD32 { ifm.bodyOff = sizeofIfMsghdrFreeBSD11Emu } else { ifm.bodyOff = sizeofIfMsghdrFreeBSD11 } + if rel >= 1102000 { // see https://github.com/freebsd/freebsd/commit/027c7f4d66ff8d8c4a46c3665a5ee7d6d8462034#diff-ad4e5b7f1449ea3fc87bc97280de145b + align = wordSize + } } rtm.parse = rtm.parseRouteMessage ifm.parse = ifm.parseInterfaceMessage