Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added PR_DESCRIPTION.md
Empty file.
21 changes: 20 additions & 1 deletion server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ func (s *tsoServer) Send(m *pdpb.TsoResponse) error {
case err := <-done:
if err != nil {
atomic.StoreInt32(&s.closed, 1)
if err == io.EOF {
return io.EOF
}
}
return errors.WithStack(err)
case <-timer.C:
Expand Down Expand Up @@ -142,6 +145,9 @@ func (s *tsoServer) recv(timeout time.Duration) (*pdpb.TsoRequest, error) {
case req := <-requestCh:
if req.err != nil {
atomic.StoreInt32(&s.closed, 1)
if req.err == io.EOF {
return nil, io.EOF
}
return nil, errors.WithStack(req.err)
}
return req.request, nil
Expand Down Expand Up @@ -174,6 +180,9 @@ func (s *heartbeatServer) Send(m core.RegionHeartbeatResponse) error {
case err := <-done:
if err != nil {
atomic.StoreInt32(&s.closed, 1)
if err == io.EOF {
return io.EOF
}
}
return errors.WithStack(err)
case <-timer.C:
Expand All @@ -190,6 +199,9 @@ func (s *heartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) {
req, err := s.stream.Recv()
if err != nil {
atomic.StoreInt32(&s.closed, 1)
if err == io.EOF {
return nil, err
}
return nil, errors.WithStack(err)
}
return req, nil
Expand Down Expand Up @@ -1030,7 +1042,10 @@ func (b *bucketHeartbeatServer) send(bucket *pdpb.ReportBucketsResponse) error {
if err != nil {
atomic.StoreInt32(&b.closed, 1)
}
return err
if err == io.EOF {
return io.EOF
}
return errors.WithStack(err)
case <-timer.C:
atomic.StoreInt32(&b.closed, 1)
return errs.ErrSendHeartbeatTimeout
Expand All @@ -1044,6 +1059,9 @@ func (b *bucketHeartbeatServer) recv() (*pdpb.ReportBucketsRequest, error) {
req, err := b.stream.Recv()
if err != nil {
atomic.StoreInt32(&b.closed, 1)
if err == io.EOF {
return nil, io.EOF
}
return nil, errors.WithStack(err)
}
return req, nil
Expand Down Expand Up @@ -1228,6 +1246,7 @@ func (s *GrpcServer) ReportBuckets(stream pdpb.PD_ReportBucketsServer) error {
}
}


// RegionHeartbeat implements gRPC PDServer.
func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error {
var (
Expand Down
164 changes: 164 additions & 0 deletions server/grpc_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package server

import (
"context"
"io"
"sync/atomic"
"testing"
"time"

"github.com/pingcap/kvproto/pkg/pdpb"
"google.golang.org/grpc"
)

type mockReportBucketsServer struct {
grpc.ServerStream
recvFunc func() (*pdpb.ReportBucketsRequest, error)
}

func (m *mockReportBucketsServer) SendAndClose(*pdpb.ReportBucketsResponse) error { return nil }
func (m *mockReportBucketsServer) Recv() (*pdpb.ReportBucketsRequest, error) {
if m.recvFunc != nil {
return m.recvFunc()
}
return nil, io.EOF
}
func (m *mockReportBucketsServer) Context() context.Context { return context.Background() }

func TestBucketHeartbeatServerRecvEOF(t *testing.T) {
stream := &mockReportBucketsServer{
recvFunc: func() (*pdpb.ReportBucketsRequest, error) {
return nil, io.EOF
},
}
s := &bucketHeartbeatServer{stream: stream}
_, err := s.recv()

t.Logf("Error type: %T, value: %v", err, err)
if err == io.EOF {
t.Log("Error is strictly equal to io.EOF")
} else {
t.Log("Error is NOT strictly equal to io.EOF")
}

// We expect io.EOF exactly, but due to the bug it is wrapped.
// This test should FAIL if the bug is present.
if err != io.EOF {
t.Fatalf("recv() returned error %v (type %T), want exactly io.EOF", err, err)
}
if atomic.LoadInt32(&s.closed) != 1 {
t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed))
}
if atomic.LoadInt32(&s.closed) != 1 {
t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed))
}
}

type mockHeartbeatServer struct {
grpc.ServerStream
sendFunc func(*pdpb.RegionHeartbeatResponse) error
recvFunc func() (*pdpb.RegionHeartbeatRequest, error)
}

func (m *mockHeartbeatServer) Send(resp *pdpb.RegionHeartbeatResponse) error {
if m.sendFunc != nil {
return m.sendFunc(resp)
}
return nil
}

func (m *mockHeartbeatServer) Recv() (*pdpb.RegionHeartbeatRequest, error) {
if m.recvFunc != nil {
return m.recvFunc()
}
return nil, nil
}

func TestHeartbeatServerSendEOF(t *testing.T) {
stream := &mockHeartbeatServer{
sendFunc: func(_ *pdpb.RegionHeartbeatResponse) error {
return io.EOF
},
}
s := &heartbeatServer{stream: stream}
err := s.Send(&pdpb.RegionHeartbeatResponse{})

t.Logf("Error type: %T, value: %v", err, err)
if err != io.EOF {
t.Fatalf("Send() returned error %v (type %T), want exactly io.EOF", err, err)
}
}

func TestHeartbeatServerRecvEOF(t *testing.T) {
stream := &mockHeartbeatServer{
recvFunc: func() (*pdpb.RegionHeartbeatRequest, error) {
return nil, io.EOF
},
}
s := &heartbeatServer{stream: stream}
_, err := s.Recv()

t.Logf("Error type: %T, value: %v", err, err)
if err != io.EOF {
t.Fatalf("Recv() returned error %v (type %T), want exactly io.EOF", err, err)
}
if atomic.LoadInt32(&s.closed) != 1 {
t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed))
}
if atomic.LoadInt32(&s.closed) != 1 {
t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed))
}
}

type mockTsoServer struct {
grpc.ServerStream
sendFunc func(*pdpb.TsoResponse) error
recvFunc func() (*pdpb.TsoRequest, error)
}

func (m *mockTsoServer) Send(resp *pdpb.TsoResponse) error {
if m.sendFunc != nil {
return m.sendFunc(resp)
}
return nil
}

func (m *mockTsoServer) Recv() (*pdpb.TsoRequest, error) {
if m.recvFunc != nil {
return m.recvFunc()
}
return nil, nil
}

func TestTsoServerSendEOF(t *testing.T) {
stream := &mockTsoServer{
sendFunc: func(_ *pdpb.TsoResponse) error {
return io.EOF
},
}
s := &tsoServer{stream: stream}
err := s.Send(&pdpb.TsoResponse{})

t.Logf("Error type: %T, value: %v", err, err)
if err != io.EOF {
t.Fatalf("Send() returned error %v (type %T), want exactly io.EOF", err, err)
}
}

func TestTsoServerRecvEOF(t *testing.T) {
stream := &mockTsoServer{
recvFunc: func() (*pdpb.TsoRequest, error) {
return nil, io.EOF
},
}
s := &tsoServer{stream: stream}
_, err := s.recv(time.Second)

t.Logf("Error type: %T, value: %v", err, err)
if err != io.EOF {
t.Fatalf("recv() returned error %v (type %T), want exactly io.EOF", err, err)
}
if atomic.LoadInt32(&s.closed) != 1 {
t.Errorf("expected closed to be 1, got %d", atomic.LoadInt32(&s.closed))
}
}
92 changes: 92 additions & 0 deletions tests/server/grpc_service_regr_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package server_test

import (
"context"
"io"
"testing"

"github.com/pingcap/kvproto/pkg/pdpb"
"github.com/stretchr/testify/require"
"github.com/tikv/pd/pkg/utils/testutil"
"github.com/tikv/pd/tests"
)

func TestRegionHeartbeatEOF(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cluster, err := tests.NewTestCluster(ctx, 1)
defer cluster.Destroy()
re.NoError(err)

err = cluster.RunInitialServers()
re.NoError(err)

leader := cluster.WaitLeader()
re.NotEmpty(leader)
leaderServer := cluster.GetServer(leader)
addr := leaderServer.GetAddr()
grpcPDClient, conn := testutil.MustNewGrpcClient(re, addr)
defer conn.Close()

stream, err := grpcPDClient.RegionHeartbeat(ctx)
re.NoError(err)

// Send one heartbeat to establish connection
err = stream.Send(&pdpb.RegionHeartbeatRequest{
Header: &pdpb.RequestHeader{ClusterId: leaderServer.GetClusterID()},
})
re.NoError(err)

// Close the send direction. This sends io.EOF to the server.
// The server should handle this as a clean shutdown and return nil.
err = stream.CloseSend()
re.NoError(err)

// The server should close the stream cleanly, resulting in io.EOF on the client side.
// We might receive responses before EOF if the server sent any.
for {
_, err = stream.Recv()
if err != nil {
re.Equal(io.EOF, err)
break
}
}
}

func TestReportBucketsEOF(t *testing.T) {
re := require.New(t)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cluster, err := tests.NewTestCluster(ctx, 1)
defer cluster.Destroy()
re.NoError(err)

err = cluster.RunInitialServers()
re.NoError(err)

leader := cluster.WaitLeader()
re.NotEmpty(leader)
leaderServer := cluster.GetServer(leader)
addr := leaderServer.GetAddr()
grpcPDClient, conn := testutil.MustNewGrpcClient(re, addr)
defer conn.Close()

stream, err := grpcPDClient.ReportBuckets(ctx)
re.NoError(err)

// Send one bucket report to establish connection
err = stream.Send(&pdpb.ReportBucketsRequest{
Header: &pdpb.RequestHeader{ClusterId: leaderServer.GetClusterID()},
})
re.NoError(err)

// Close the send direction.
_, err = stream.CloseAndRecv()
// If the server handles io.EOF correctly, it should process the close and return response (or EOF if no response).
// ReportBuckets returns a response on CloseAndRecv usually.
// If it fails with wrapped error, we might get an error here.
if err != nil && err != io.EOF {
re.NoError(err)
}
}
Loading