Skip to content

Commit bf23997

Browse files
authored
Better handle admin handler stream replication API lifecycle (#4647)
Make sure admin handler stream replication API is able to return if client -> server or server -> client link breaks
1 parent 8a75aed commit bf23997

File tree

3 files changed

+146
-27
lines changed

3 files changed

+146
-27
lines changed

service/frontend/admin_handler.go

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ import (
3333
"sync/atomic"
3434
"time"
3535

36-
"golang.org/x/sync/errgroup"
36+
"google.golang.org/grpc/metadata"
3737

38+
"go.temporal.io/server/client/history"
39+
"go.temporal.io/server/common/channel"
3840
"go.temporal.io/server/common/clock"
3941
"go.temporal.io/server/common/primitives"
4042
"go.temporal.io/server/common/util"
@@ -1950,62 +1952,84 @@ func (adh *AdminHandler) getWorkflowCompletionEvent(
19501952
}
19511953

19521954
func (adh *AdminHandler) StreamWorkflowReplicationMessages(
1953-
targetCluster adminservice.AdminService_StreamWorkflowReplicationMessagesServer,
1955+
clientCluster adminservice.AdminService_StreamWorkflowReplicationMessagesServer,
19541956
) (retError error) {
19551957
defer log.CapturePanic(adh.logger, &retError)
19561958

1957-
ctx := targetCluster.Context()
1958-
sourceCluster, err := adh.historyClient.StreamWorkflowReplicationMessages(ctx)
1959+
ctxMetadata, ok := metadata.FromIncomingContext(clientCluster.Context())
1960+
if !ok {
1961+
return serviceerror.NewInvalidArgument("missing cluster & shard ID metadata")
1962+
}
1963+
_, serverClusterShardID, err := history.DecodeClusterShardMD(ctxMetadata)
1964+
if err != nil {
1965+
return err
1966+
}
1967+
1968+
logger := log.With(adh.logger, tag.ShardID(serverClusterShardID.ShardID))
1969+
logger.Info("AdminStreamReplicationMessages started.")
1970+
defer logger.Info("AdminStreamReplicationMessages stopped.")
1971+
1972+
ctx := clientCluster.Context()
1973+
serverCluster, err := adh.historyClient.StreamWorkflowReplicationMessages(ctx)
19591974
if err != nil {
19601975
return err
19611976
}
19621977

1963-
errGroup, ctx := errgroup.WithContext(ctx)
1964-
errGroup.Go(func() error {
1965-
for ctx.Err() == nil {
1966-
req, err := targetCluster.Recv()
1978+
shutdownChan := channel.NewShutdownOnce()
1979+
go func() {
1980+
defer shutdownChan.Shutdown()
1981+
1982+
for !shutdownChan.IsShutdown() {
1983+
req, err := clientCluster.Recv()
19671984
if err != nil {
1968-
return err
1985+
logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(err))
1986+
return
19691987
}
19701988
switch attr := req.GetAttributes().(type) {
19711989
case *adminservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState:
1972-
if err = sourceCluster.Send(&historyservice.StreamWorkflowReplicationMessagesRequest{
1990+
if err = serverCluster.Send(&historyservice.StreamWorkflowReplicationMessagesRequest{
19731991
Attributes: &historyservice.StreamWorkflowReplicationMessagesRequest_SyncReplicationState{
19741992
SyncReplicationState: attr.SyncReplicationState,
19751993
},
19761994
}); err != nil {
1977-
return err
1995+
logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(err))
1996+
return
19781997
}
19791998
default:
1980-
return serviceerror.NewInternal(fmt.Sprintf(
1999+
logger.Info("AdminStreamReplicationMessages client -> server encountered error", tag.Error(serviceerror.NewInternal(fmt.Sprintf(
19812000
"StreamWorkflowReplicationMessages encountered unknown type: %T %v", attr, attr,
1982-
))
2001+
))))
2002+
return
19832003
}
19842004
}
1985-
return ctx.Err()
1986-
})
1987-
errGroup.Go(func() error {
1988-
for ctx.Err() == nil {
1989-
resp, err := sourceCluster.Recv()
2005+
}()
2006+
go func() {
2007+
defer shutdownChan.Shutdown()
2008+
2009+
for !shutdownChan.IsShutdown() {
2010+
resp, err := serverCluster.Recv()
19902011
if err != nil {
1991-
return err
2012+
logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(err))
2013+
return
19922014
}
19932015
switch attr := resp.GetAttributes().(type) {
19942016
case *historyservice.StreamWorkflowReplicationMessagesResponse_Messages:
1995-
if err = targetCluster.Send(&adminservice.StreamWorkflowReplicationMessagesResponse{
2017+
if err = clientCluster.Send(&adminservice.StreamWorkflowReplicationMessagesResponse{
19962018
Attributes: &adminservice.StreamWorkflowReplicationMessagesResponse_Messages{
19972019
Messages: attr.Messages,
19982020
},
19992021
}); err != nil {
2000-
return err
2022+
logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(err))
2023+
return
20012024
}
20022025
default:
2003-
return serviceerror.NewInternal(fmt.Sprintf(
2026+
logger.Info("AdminStreamReplicationMessages server -> client encountered error", tag.Error(serviceerror.NewInternal(fmt.Sprintf(
20042027
"StreamWorkflowReplicationMessages encountered unknown type: %T %v", attr, attr,
2005-
))
2028+
))))
2029+
return
20062030
}
20072031
}
2008-
return ctx.Err()
2009-
})
2010-
return errGroup.Wait()
2032+
}()
2033+
<-shutdownChan.Channel()
2034+
return nil
20112035
}

service/frontend/admin_handler_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,14 @@ import (
2828
"context"
2929
"errors"
3030
"fmt"
31+
"math/rand"
32+
"sync"
3133
"testing"
3234
"time"
3335

36+
"google.golang.org/grpc/metadata"
37+
38+
historyclient "go.temporal.io/server/client/history"
3439
"go.temporal.io/server/common/clock"
3540
"go.temporal.io/server/common/persistence/visibility/store/standard/cassandra"
3641
"go.temporal.io/server/common/primitives"
@@ -1445,3 +1450,93 @@ func (s *adminHandlerSuite) TestDeleteWorkflowExecution_CassandraVisibilityBacke
14451450
_, err = s.handler.DeleteWorkflowExecution(context.Background(), request)
14461451
s.NoError(err)
14471452
}
1453+
1454+
func (s *adminHandlerSuite) TestStreamWorkflowReplicationMessages_ClientToServerBroken() {
1455+
clientClusterShardID := historyclient.ClusterShardID{
1456+
ClusterID: rand.Int31(),
1457+
ShardID: rand.Int31(),
1458+
}
1459+
serverClusterShardID := historyclient.ClusterShardID{
1460+
ClusterID: rand.Int31(),
1461+
ShardID: rand.Int31(),
1462+
}
1463+
clusterShardMD := historyclient.EncodeClusterShardMD(
1464+
clientClusterShardID,
1465+
serverClusterShardID,
1466+
)
1467+
ctx := metadata.NewIncomingContext(context.Background(), clusterShardMD)
1468+
clientCluster := adminservicemock.NewMockAdminService_StreamWorkflowReplicationMessagesServer(s.controller)
1469+
clientCluster.EXPECT().Context().Return(ctx).AnyTimes()
1470+
serverCluster := historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesClient(s.controller)
1471+
s.mockHistoryClient.EXPECT().StreamWorkflowReplicationMessages(ctx).Return(serverCluster, nil)
1472+
1473+
waitGroupStart := sync.WaitGroup{}
1474+
waitGroupStart.Add(2)
1475+
waitGroupEnd := sync.WaitGroup{}
1476+
waitGroupEnd.Add(2)
1477+
channel := make(chan struct{})
1478+
1479+
clientCluster.EXPECT().Recv().DoAndReturn(func() (*adminservice.StreamWorkflowReplicationMessagesRequest, error) {
1480+
waitGroupStart.Done()
1481+
waitGroupStart.Wait()
1482+
1483+
defer waitGroupEnd.Done()
1484+
return nil, serviceerror.NewUnavailable("random error")
1485+
})
1486+
serverCluster.EXPECT().Recv().DoAndReturn(func() (*historyservice.StreamWorkflowReplicationMessagesResponse, error) {
1487+
waitGroupStart.Done()
1488+
waitGroupStart.Wait()
1489+
1490+
defer waitGroupEnd.Done()
1491+
<-channel
1492+
return nil, serviceerror.NewUnavailable("random error")
1493+
})
1494+
_ = s.handler.StreamWorkflowReplicationMessages(clientCluster)
1495+
close(channel)
1496+
waitGroupEnd.Wait()
1497+
}
1498+
1499+
func (s *adminHandlerSuite) TestStreamWorkflowReplicationMessages_ServerToClientBroken() {
1500+
clientClusterShardID := historyclient.ClusterShardID{
1501+
ClusterID: rand.Int31(),
1502+
ShardID: rand.Int31(),
1503+
}
1504+
serverClusterShardID := historyclient.ClusterShardID{
1505+
ClusterID: rand.Int31(),
1506+
ShardID: rand.Int31(),
1507+
}
1508+
clusterShardMD := historyclient.EncodeClusterShardMD(
1509+
clientClusterShardID,
1510+
serverClusterShardID,
1511+
)
1512+
ctx := metadata.NewIncomingContext(context.Background(), clusterShardMD)
1513+
clientCluster := adminservicemock.NewMockAdminService_StreamWorkflowReplicationMessagesServer(s.controller)
1514+
clientCluster.EXPECT().Context().Return(ctx).AnyTimes()
1515+
serverCluster := historyservicemock.NewMockHistoryService_StreamWorkflowReplicationMessagesClient(s.controller)
1516+
s.mockHistoryClient.EXPECT().StreamWorkflowReplicationMessages(ctx).Return(serverCluster, nil)
1517+
1518+
waitGroupStart := sync.WaitGroup{}
1519+
waitGroupStart.Add(2)
1520+
waitGroupEnd := sync.WaitGroup{}
1521+
waitGroupEnd.Add(2)
1522+
channel := make(chan struct{})
1523+
1524+
clientCluster.EXPECT().Recv().DoAndReturn(func() (*adminservice.StreamWorkflowReplicationMessagesRequest, error) {
1525+
waitGroupStart.Done()
1526+
waitGroupStart.Wait()
1527+
1528+
defer waitGroupEnd.Done()
1529+
<-channel
1530+
return nil, serviceerror.NewUnavailable("random error")
1531+
})
1532+
serverCluster.EXPECT().Recv().DoAndReturn(func() (*historyservice.StreamWorkflowReplicationMessagesResponse, error) {
1533+
waitGroupStart.Done()
1534+
waitGroupStart.Wait()
1535+
1536+
defer waitGroupEnd.Done()
1537+
return nil, serviceerror.NewUnavailable("random error")
1538+
})
1539+
_ = s.handler.StreamWorkflowReplicationMessages(clientCluster)
1540+
close(channel)
1541+
waitGroupEnd.Wait()
1542+
}

service/history/replication/stream_receiver_monitor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import (
3636
)
3737

3838
const (
39-
streamReceiverMonitorInterval = 5 * time.Second
39+
streamReceiverMonitorInterval = 2 * time.Second
4040
)
4141

4242
type (

0 commit comments

Comments
 (0)