Skip to content

Commit b8b1448

Browse files
authored
Fix YARPC context propagation (#116)
1 parent 05d04dd commit b8b1448

File tree

2 files changed

+49
-36
lines changed

2 files changed

+49
-36
lines changed

service/frontend/workflowHandler.go

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ func (wh *WorkflowHandler) PollForActivityTask(
397397
DomainUUID: common.StringPtr(domainID),
398398
PollerID: common.StringPtr(pollerID),
399399
PollRequest: pollRequest,
400-
})
400+
}, versionHeaders(ctx)...)
401401
return err
402402
}
403403

@@ -487,7 +487,7 @@ func (wh *WorkflowHandler) PollForDecisionTask(
487487
DomainUUID: common.StringPtr(domainID),
488488
PollerID: common.StringPtr(pollerID),
489489
PollRequest: pollRequest,
490-
})
490+
}, versionHeaders(ctx)...)
491491
return err
492492
}
493493

@@ -547,7 +547,7 @@ func (wh *WorkflowHandler) cancelOutstandingPoll(ctx context.Context, err error,
547547
TaskListType: common.Int32Ptr(taskListType),
548548
TaskList: taskList,
549549
PollerID: common.StringPtr(pollerID),
550-
})
550+
}, versionHeaders(ctx)...)
551551
// We can not do much if this call fails. Just log the error and move on
552552
if err != nil {
553553
wh.GetLogger().Warn("Failed to cancel outstanding poller.",
@@ -629,7 +629,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat(
629629
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
630630
DomainUUID: common.StringPtr(taskToken.DomainID),
631631
FailedRequest: failRequest,
632-
})
632+
}, versionHeaders(ctx)...)
633633
if err != nil {
634634
return nil, wh.error(err, scope)
635635
}
@@ -638,7 +638,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeat(
638638
resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{
639639
DomainUUID: common.StringPtr(taskToken.DomainID),
640640
HeartbeatRequest: heartbeatRequest,
641-
})
641+
}, versionHeaders(ctx)...)
642642
if err != nil {
643643
return nil, wh.error(err, scope)
644644
}
@@ -730,7 +730,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID(
730730
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
731731
DomainUUID: common.StringPtr(taskToken.DomainID),
732732
FailedRequest: failRequest,
733-
})
733+
}, versionHeaders(ctx)...)
734734
if err != nil {
735735
return nil, wh.error(err, scope)
736736
}
@@ -745,7 +745,7 @@ func (wh *WorkflowHandler) RecordActivityTaskHeartbeatByID(
745745
resp, err = wh.GetHistoryClient().RecordActivityTaskHeartbeat(ctx, &h.RecordActivityTaskHeartbeatRequest{
746746
DomainUUID: common.StringPtr(taskToken.DomainID),
747747
HeartbeatRequest: req,
748-
})
748+
}, versionHeaders(ctx)...)
749749
if err != nil {
750750
return nil, wh.error(err, scope)
751751
}
@@ -823,15 +823,15 @@ func (wh *WorkflowHandler) RespondActivityTaskCompleted(
823823
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
824824
DomainUUID: common.StringPtr(taskToken.DomainID),
825825
FailedRequest: failRequest,
826-
})
826+
}, versionHeaders(ctx)...)
827827
if err != nil {
828828
return wh.error(err, scope)
829829
}
830830
} else {
831831
err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{
832832
DomainUUID: common.StringPtr(taskToken.DomainID),
833833
CompleteRequest: completeRequest,
834-
})
834+
}, versionHeaders(ctx)...)
835835
if err != nil {
836836
return wh.error(err, scope)
837837
}
@@ -926,7 +926,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID(
926926
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
927927
DomainUUID: common.StringPtr(taskToken.DomainID),
928928
FailedRequest: failRequest,
929-
})
929+
}, versionHeaders(ctx)...)
930930
if err != nil {
931931
return wh.error(err, scope)
932932
}
@@ -940,7 +940,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCompletedByID(
940940
err = wh.GetHistoryClient().RespondActivityTaskCompleted(ctx, &h.RespondActivityTaskCompletedRequest{
941941
DomainUUID: common.StringPtr(taskToken.DomainID),
942942
CompleteRequest: req,
943-
})
943+
}, versionHeaders(ctx)...)
944944
if err != nil {
945945
return wh.error(err, scope)
946946
}
@@ -1017,7 +1017,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailed(
10171017
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
10181018
DomainUUID: common.StringPtr(taskToken.DomainID),
10191019
FailedRequest: failedRequest,
1020-
})
1020+
}, versionHeaders(ctx)...)
10211021
if err != nil {
10221022
return wh.error(err, scope)
10231023
}
@@ -1114,7 +1114,7 @@ func (wh *WorkflowHandler) RespondActivityTaskFailedByID(
11141114
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
11151115
DomainUUID: common.StringPtr(taskToken.DomainID),
11161116
FailedRequest: req,
1117-
})
1117+
}, versionHeaders(ctx)...)
11181118
if err != nil {
11191119
return wh.error(err, scope)
11201120
}
@@ -1191,15 +1191,15 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceled(
11911191
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
11921192
DomainUUID: common.StringPtr(taskToken.DomainID),
11931193
FailedRequest: failRequest,
1194-
})
1194+
}, versionHeaders(ctx)...)
11951195
if err != nil {
11961196
return wh.error(err, scope)
11971197
}
11981198
} else {
11991199
err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{
12001200
DomainUUID: common.StringPtr(taskToken.DomainID),
12011201
CancelRequest: cancelRequest,
1202-
})
1202+
}, versionHeaders(ctx)...)
12031203
if err != nil {
12041204
return wh.error(err, scope)
12051205
}
@@ -1293,7 +1293,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID(
12931293
err = wh.GetHistoryClient().RespondActivityTaskFailed(ctx, &h.RespondActivityTaskFailedRequest{
12941294
DomainUUID: common.StringPtr(taskToken.DomainID),
12951295
FailedRequest: failRequest,
1296-
})
1296+
}, versionHeaders(ctx)...)
12971297
if err != nil {
12981298
return wh.error(err, scope)
12991299
}
@@ -1307,7 +1307,7 @@ func (wh *WorkflowHandler) RespondActivityTaskCanceledByID(
13071307
err = wh.GetHistoryClient().RespondActivityTaskCanceled(ctx, &h.RespondActivityTaskCanceledRequest{
13081308
DomainUUID: common.StringPtr(taskToken.DomainID),
13091309
CancelRequest: req,
1310-
})
1310+
}, versionHeaders(ctx)...)
13111311
if err != nil {
13121312
return wh.error(err, scope)
13131313
}
@@ -1362,6 +1362,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskCompleted(
13621362
histResp, err := wh.GetHistoryClient().RespondDecisionTaskCompleted(ctx, &h.RespondDecisionTaskCompletedRequest{
13631363
DomainUUID: common.StringPtr(taskToken.DomainID),
13641364
CompleteRequest: completeRequest},
1365+
versionHeaders(ctx)...,
13651366
)
13661367
if err != nil {
13671368
return nil, wh.error(err, scope)
@@ -1464,7 +1465,7 @@ func (wh *WorkflowHandler) RespondDecisionTaskFailed(
14641465
err = wh.GetHistoryClient().RespondDecisionTaskFailed(ctx, &h.RespondDecisionTaskFailedRequest{
14651466
DomainUUID: common.StringPtr(taskToken.DomainID),
14661467
FailedRequest: failedRequest,
1467-
})
1468+
}, versionHeaders(ctx)...)
14681469
if err != nil {
14691470
return wh.error(err, scope)
14701471
}
@@ -1547,7 +1548,7 @@ func (wh *WorkflowHandler) RespondQueryTaskCompleted(
15471548
CompletedRequest: completeRequest,
15481549
}
15491550

1550-
err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest)
1551+
err = wh.GetMatchingClient().RespondQueryTaskCompleted(ctx, matchingRequest, versionHeaders(ctx)...)
15511552
if err != nil {
15521553
return wh.error(err, scope)
15531554
}
@@ -1692,7 +1693,7 @@ func (wh *WorkflowHandler) StartWorkflowExecution(
16921693
}
16931694

16941695
wh.GetLogger().Debug("Start workflow execution request domainID", tag.WorkflowDomainID(domainID))
1695-
resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest))
1696+
resp, err = wh.GetHistoryClient().StartWorkflowExecution(ctx, common.CreateHistoryStartWorkflowRequest(domainID, startRequest), versionHeaders(ctx)...)
16961697

16971698
if err != nil {
16981699
return nil, wh.error(err, scope)
@@ -1763,7 +1764,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionRawHistory(
17631764
Execution: execution,
17641765
ExpectedNextEventId: nil,
17651766
CurrentBranchToken: currentBranchToken,
1766-
})
1767+
}, versionHeaders(ctx)...)
17671768

17681769
if err != nil {
17691770
return nil, "", 0, err
@@ -1919,7 +1920,7 @@ func (wh *WorkflowHandler) GetWorkflowExecutionHistory(
19191920
Execution: execution,
19201921
ExpectedNextEventId: common.Int64Ptr(expectedNextEventID),
19211922
CurrentBranchToken: currentBranchToken,
1922-
})
1923+
}, versionHeaders(ctx)...)
19231924

19241925
if err != nil {
19251926
return nil, "", 0, 0, false, err
@@ -2131,7 +2132,7 @@ func (wh *WorkflowHandler) SignalWorkflowExecution(
21312132
err = wh.GetHistoryClient().SignalWorkflowExecution(ctx, &h.SignalWorkflowExecutionRequest{
21322133
DomainUUID: common.StringPtr(domainID),
21332134
SignalRequest: signalRequest,
2134-
})
2135+
}, versionHeaders(ctx)...)
21352136
if err != nil {
21362137
return wh.error(err, scope)
21372138
}
@@ -2266,7 +2267,7 @@ func (wh *WorkflowHandler) SignalWithStartWorkflowExecution(
22662267
resp, err = wh.GetHistoryClient().SignalWithStartWorkflowExecution(ctx, &h.SignalWithStartWorkflowExecutionRequest{
22672268
DomainUUID: common.StringPtr(domainID),
22682269
SignalWithStartRequest: signalWithStartRequest,
2269-
})
2270+
}, versionHeaders(ctx)...)
22702271
return err
22712272
}
22722273

@@ -2317,7 +2318,7 @@ func (wh *WorkflowHandler) TerminateWorkflowExecution(
23172318
err = wh.GetHistoryClient().TerminateWorkflowExecution(ctx, &h.TerminateWorkflowExecutionRequest{
23182319
DomainUUID: common.StringPtr(domainID),
23192320
TerminateRequest: terminateRequest,
2320-
})
2321+
}, versionHeaders(ctx)...)
23212322
if err != nil {
23222323
return wh.error(err, scope)
23232324
}
@@ -2364,7 +2365,7 @@ func (wh *WorkflowHandler) ResetWorkflowExecution(
23642365
resp, err = wh.GetHistoryClient().ResetWorkflowExecution(ctx, &h.ResetWorkflowExecutionRequest{
23652366
DomainUUID: common.StringPtr(domainID),
23662367
ResetRequest: resetRequest,
2367-
})
2368+
}, versionHeaders(ctx)...)
23682369
if err != nil {
23692370
return nil, wh.error(err, scope)
23702371
}
@@ -2410,7 +2411,7 @@ func (wh *WorkflowHandler) RequestCancelWorkflowExecution(
24102411
err = wh.GetHistoryClient().RequestCancelWorkflowExecution(ctx, &h.RequestCancelWorkflowExecutionRequest{
24112412
DomainUUID: common.StringPtr(domainID),
24122413
CancelRequest: cancelRequest,
2413-
})
2414+
}, versionHeaders(ctx)...)
24142415
if err != nil {
24152416
return wh.error(err, scope)
24162417
}
@@ -2973,7 +2974,7 @@ func (wh *WorkflowHandler) ResetStickyTaskList(
29732974
_, err = wh.GetHistoryClient().ResetStickyTaskList(ctx, &h.ResetStickyTaskListRequest{
29742975
DomainUUID: common.StringPtr(domainID),
29752976
Execution: resetRequest.Execution,
2976-
})
2977+
}, versionHeaders(ctx)...)
29772978
if err != nil {
29782979
return nil, wh.error(err, scope)
29792980
}
@@ -3041,7 +3042,7 @@ func (wh *WorkflowHandler) QueryWorkflow(
30413042
DomainUUID: common.StringPtr(domainID),
30423043
Request: queryRequest,
30433044
}
3044-
hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req)
3045+
hResponse, err := wh.GetHistoryClient().QueryWorkflow(ctx, req, versionHeaders(ctx)...)
30453046
if err != nil {
30463047
return nil, wh.error(err, scope)
30473048
}
@@ -3085,7 +3086,7 @@ func (wh *WorkflowHandler) DescribeWorkflowExecution(
30853086
response, err := wh.GetHistoryClient().DescribeWorkflowExecution(ctx, &h.DescribeWorkflowExecutionRequest{
30863087
DomainUUID: common.StringPtr(domainID),
30873088
Request: request,
3088-
})
3089+
}, versionHeaders(ctx)...)
30893090

30903091
if err != nil {
30913092
return nil, wh.error(err, scope)
@@ -3140,7 +3141,7 @@ func (wh *WorkflowHandler) DescribeTaskList(
31403141
response, err = wh.GetMatchingClient().DescribeTaskList(ctx, &m.DescribeTaskListRequest{
31413142
DomainUUID: common.StringPtr(domainID),
31423143
DescRequest: request,
3143-
})
3144+
}, versionHeaders(ctx)...)
31443145
return err
31453146
}
31463147

@@ -3178,7 +3179,7 @@ func (wh *WorkflowHandler) ListTaskListPartitions(ctx context.Context, request *
31783179
resp, err := wh.GetMatchingClient().ListTaskListPartitions(ctx, &m.ListTaskListPartitionsRequest{
31793180
Domain: request.Domain,
31803181
TaskList: request.TaskList,
3181-
})
3182+
}, versionHeaders(ctx)...)
31823183
return resp, err
31833184
}
31843185

@@ -3651,7 +3652,7 @@ func (wh *WorkflowHandler) historyArchived(ctx context.Context, request *gen.Get
36513652
DomainUUID: common.StringPtr(domainID),
36523653
Execution: request.Execution,
36533654
}
3654-
_, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest)
3655+
_, err := wh.GetHistoryClient().GetMutableState(ctx, getMutableStateRequest, versionHeaders(ctx)...)
36553656
if err == nil {
36563657
return false
36573658
}
@@ -3777,3 +3778,14 @@ type domainWrapper struct {
37773778
func (d domainWrapper) GetDomain() string {
37783779
return d.domain
37793780
}
3781+
3782+
// TODO: Remove this func after history and matching services gRPC migration is complete
3783+
// It sets version headers in YARPC format
3784+
func versionHeaders(ctx context.Context) []yarpc.CallOption {
3785+
headers := client.GetHeadersValue(ctx, common.LibraryVersionHeaderName, common.FeatureVersionHeaderName, common.ClientImplHeaderName)
3786+
return []yarpc.CallOption{
3787+
yarpc.WithHeader(common.LibraryVersionHeaderName, headers[0]),
3788+
yarpc.WithHeader(common.FeatureVersionHeaderName, headers[1]),
3789+
yarpc.WithHeader(common.ClientImplHeaderName, headers[2]),
3790+
}
3791+
}

service/frontend/workflowHandler_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,8 @@ func (s *workflowHandlerSuite) TestHistoryArchived() {
834834
}
835835
s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain"))
836836

837-
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
837+
// TODO: remove last 3 `gomock.Any()` after YARPC migration
838+
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).Times(1)
838839
getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{
839840
Execution: &shared.WorkflowExecution{
840841
WorkflowId: common.StringPtr(testWorkflowID),
@@ -843,7 +844,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() {
843844
}
844845
s.False(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain"))
845846

846-
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1)
847+
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, &shared.EntityNotExistsError{Message: "got archival indication error"}).Times(1)
847848
getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{
848849
Execution: &shared.WorkflowExecution{
849850
WorkflowId: common.StringPtr(testWorkflowID),
@@ -852,7 +853,7 @@ func (s *workflowHandlerSuite) TestHistoryArchived() {
852853
}
853854
s.True(wh.historyArchived(context.Background(), getHistoryRequest, "test-domain"))
854855

855-
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1)
856+
s.mockHistoryClient.EXPECT().GetMutableState(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, errors.New("got non-archival indication error")).Times(1)
856857
getHistoryRequest = &shared.GetWorkflowExecutionHistoryRequest{
857858
Execution: &shared.WorkflowExecution{
858859
WorkflowId: common.StringPtr(testWorkflowID),

0 commit comments

Comments
 (0)