diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/server/grpc_service.go b/server/grpc_service.go index c5cba4f0085..4e7ea71b7f2 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -113,6 +113,9 @@ func (s *tsoServer) Send(m *pdpb.TsoResponse) error { case err := <-done: if err != nil { atomic.StoreInt32(&s.closed, 1) + if err == io.EOF { + return io.EOF + } } return errors.WithStack(err) case <-timer.C: @@ -142,6 +145,9 @@ func (s *tsoServer) recv(timeout time.Duration) (*pdpb.TsoRequest, error) { case req := <-requestCh: if req.err != nil { atomic.StoreInt32(&s.closed, 1) + if req.err == io.EOF { + return nil, io.EOF + } return nil, errors.WithStack(req.err) } return req.request, nil @@ -174,6 +180,9 @@ func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error { case err := <-done: if err != nil { atomic.StoreInt32(&s.closed, 1) + if err == io.EOF { + return io.EOF + } } return errors.WithStack(err) case <-timer.C: @@ -190,6 +199,9 @@ func (s *heartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) { req, err := s.stream.Recv() if err != nil { atomic.StoreInt32(&s.closed, 1) + if err == io.EOF { + return nil, err + } return nil, errors.WithStack(err) } return req, nil @@ -1030,7 +1042,10 @@ func (b *bucketHeartbeatServer) send(bucket *pdpb.ReportBucketsResponse) error { if err != nil { atomic.StoreInt32(&b.closed, 1) } - return err + if err == io.EOF { + return io.EOF + } + return errors.WithStack(err) case <-timer.C: atomic.StoreInt32(&b.closed, 1) return errs.ErrSendHeartbeatTimeout @@ -1044,6 +1059,9 @@ func (b *bucketHeartbeatServer) recv() (*pdpb.ReportBucketsRequest, error) { req, err := b.stream.Recv() if err != nil { atomic.StoreInt32(&b.closed, 1) + if err == io.EOF { + return nil, io.EOF + } return nil, errors.WithStack(err) } return req, nil @@ -1228,6 +1246,7 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error { } } + // RegionHeartbeat implements gRPC PDServer. func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error { var ( diff --git a/server/grpc_service_test.go b/server/grpc_service_test.go new file mode 100644 index 00000000000..034ca4a021d --- /dev/null +++ b/server/grpc_service_test.go @@ -0,0 +1,164 @@ +package server + +import ( + "context" + "io" + "sync/atomic" + "testing" + "time" + + "github.com/pingcap/kvproto/pkg/pdpb" + "google.golang.org/grpc" +) + +type mockReportBucketsServer struct { + grpc.ServerStream + recvFunc func() (*pdpb.ReportBucketsRequest, error) +} + +func (m *mockReportBucketsServer) SendAndClose(*pdpb.ReportBucketsResponse) error { return nil } +func (m *mockReportBucketsServer) Recv() (*pdpb.ReportBucketsRequest, error) { + if m.recvFunc != nil { + return m.recvFunc() + } + return nil, io.EOF +} +func (m *mockReportBucketsServer) Context() context.Context { return context.Background() } + +func TestBucketHeartbeatServerRecvEOF(t *testing.T) { + stream := &mockReportBucketsServer{ + recvFunc: func() (*pdpb.ReportBucketsRequest, error) { + return nil, io.EOF + }, + } + s := &bucketHeartbeatServer{stream: stream} + _, err := s.recv() + + t.Logf("Error type: %T, value: %v", err, err) + if err == io.EOF { + t.Log("Error is strictly equal to io.EOF") + } else { + t.Log("Error is NOT strictly equal to io.EOF") + } + + // We expect io.EOF exactly, but due to the bug it is wrapped. + // This test should FAIL if the bug is present. + if err != io.EOF { + t.Fatalf("recv() returned error %v (type %T), want exactly io.EOF", err, err) + } + if atomic.LoadInt32(&s.closed) != 1 { + t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed)) + } + if atomic.LoadInt32(&s.closed) != 1 { + t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed)) + } +} + +type mockHeartbeatServer struct { + grpc.ServerStream + sendFunc func(*pdpb.RegionHeartbeatResponse) error + recvFunc func() (*pdpb.RegionHeartbeatRequest, error) +} + +func (m *mockHeartbeatServer) Send(resp *pdpb.RegionHeartbeatResponse) error { + if m.sendFunc != nil { + return m.sendFunc(resp) + } + return nil +} + +func (m *mockHeartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) { + if m.recvFunc != nil { + return m.recvFunc() + } + return nil, nil +} + +func TestHeartbeatServerSendEOF(t *testing.T) { + stream := &mockHeartbeatServer{ + sendFunc: func(_ *pdpb.RegionHeartbeatResponse) error { + return io.EOF + }, + } + s := &heartbeatServer{stream: stream} + err := s.Send(&pdpb.RegionHeartbeatResponse{}) + + t.Logf("Error type: %T, value: %v", err, err) + if err != io.EOF { + t.Fatalf("Send() returned error %v (type %T), want exactly io.EOF", err, err) + } +} + +func TestHeartbeatServerRecvEOF(t *testing.T) { + stream := &mockHeartbeatServer{ + recvFunc: func() (*pdpb.RegionHeartbeatRequest, error) { + return nil, io.EOF + }, + } + s := &heartbeatServer{stream: stream} + _, err := s.Recv() + + t.Logf("Error type: %T, value: %v", err, err) + if err != io.EOF { + t.Fatalf("Recv() returned error %v (type %T), want exactly io.EOF", err, err) + } + if atomic.LoadInt32(&s.closed) != 1 { + t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed)) + } + if atomic.LoadInt32(&s.closed) != 1 { + t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed)) + } +} + +type mockTsoServer struct { + grpc.ServerStream + sendFunc func(*pdpb.TsoResponse) error + recvFunc func() (*pdpb.TsoRequest, error) +} + +func (m *mockTsoServer) Send(resp *pdpb.TsoResponse) error { + if m.sendFunc != nil { + return m.sendFunc(resp) + } + return nil +} + +func (m *mockTsoServer) Recv() (*pdpb.TsoRequest, error) { + if m.recvFunc != nil { + return m.recvFunc() + } + return nil, nil +} + +func TestTsoServerSendEOF(t *testing.T) { + stream := &mockTsoServer{ + sendFunc: func(_ *pdpb.TsoResponse) error { + return io.EOF + }, + } + s := &tsoServer{stream: stream} + err := s.Send(&pdpb.TsoResponse{}) + + t.Logf("Error type: %T, value: %v", err, err) + if err != io.EOF { + t.Fatalf("Send() returned error %v (type %T), want exactly io.EOF", err, err) + } +} + +func TestTsoServerRecvEOF(t *testing.T) { + stream := &mockTsoServer{ + recvFunc: func() (*pdpb.TsoRequest, error) { + return nil, io.EOF + }, + } + s := &tsoServer{stream: stream} + _, err := s.recv(time.Second) + + t.Logf("Error type: %T, value: %v", err, err) + if err != io.EOF { + t.Fatalf("recv() returned error %v (type %T), want exactly io.EOF", err, err) + } + if atomic.LoadInt32(&s.closed) != 1 { + t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed)) + } +} diff --git a/tests/server/grpc_service_regr_test.go b/tests/server/grpc_service_regr_test.go new file mode 100644 index 00000000000..3c8ea39517e --- /dev/null +++ b/tests/server/grpc_service_regr_test.go @@ -0,0 +1,92 @@ +package server_test + +import ( + "context" + "io" + "testing" + + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/utils/testutil" + "github.com/tikv/pd/tests" +) + +func TestRegionHeartbeatEOF(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + defer cluster.Destroy() + re.NoError(err) + + err = cluster.RunInitialServers() + re.NoError(err) + + leader := cluster.WaitLeader() + re.NotEmpty(leader) + leaderServer := cluster.GetServer(leader) + addr := leaderServer.GetAddr() + grpcPDClient, conn := testutil.MustNewGrpcClient(re, addr) + defer conn.Close() + + stream, err := grpcPDClient.RegionHeartbeat(ctx) + re.NoError(err) + + // Send one heartbeat to establish connection + err = stream.Send(&pdpb.RegionHeartbeatRequest{ + Header: &pdpb.RequestHeader{ClusterId: leaderServer.GetClusterID()}, + }) + re.NoError(err) + + // Close the send direction. This sends io.EOF to the server. + // The server should handle this as a clean shutdown and return nil. + err = stream.CloseSend() + re.NoError(err) + + // The server should close the stream cleanly, resulting in io.EOF on the client side. + // We might receive responses before EOF if the server sent any. + for { + _, err = stream.Recv() + if err != nil { + re.Equal(io.EOF, err) + break + } + } +} + +func TestReportBucketsEOF(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cluster, err := tests.NewTestCluster(ctx, 1) + defer cluster.Destroy() + re.NoError(err) + + err = cluster.RunInitialServers() + re.NoError(err) + + leader := cluster.WaitLeader() + re.NotEmpty(leader) + leaderServer := cluster.GetServer(leader) + addr := leaderServer.GetAddr() + grpcPDClient, conn := testutil.MustNewGrpcClient(re, addr) + defer conn.Close() + + stream, err := grpcPDClient.ReportBuckets(ctx) + re.NoError(err) + + // Send one bucket report to establish connection + err = stream.Send(&pdpb.ReportBucketsRequest{ + Header: &pdpb.RequestHeader{ClusterId: leaderServer.GetClusterID()}, + }) + re.NoError(err) + + // Close the send direction. + _, err = stream.CloseAndRecv() + // If the server handles io.EOF correctly, it should process the close and return response (or EOF if no response). + // ReportBuckets returns a response on CloseAndRecv usually. + // If it fails with wrapped error, we might get an error here. + if err != nil && err != io.EOF { + re.NoError(err) + } +}