Skip to content

Commit ae5cad8

Browse files
authored
Fix workflow ID reuse when running on ScyllaDB (temporalio#3027)
1 parent 8d82a8c commit ae5cad8

File tree

3 files changed

+60
-34
lines changed

3 files changed

+60
-34
lines changed

common/persistence/cassandra/errors.go

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,18 @@ type (
5656
}
5757
)
5858

59+
// ScyllaDB will return rows with null values to match # of queries in a batch query (see #2683).
60+
// To support null values, fields type should be a pointer to pointer of underlying type (i.e. **int).
61+
// Resulting value will be converted to a pointer of underlying type (i.e. *int) and stored in the map.
62+
// We do it only for "type" field which is checked for `nil` value.
63+
// All other fields are created automatically by gocql with non-pointer types (i.e. int).
64+
func newConflictRecord() map[string]interface{} {
65+
t := new(int)
66+
return map[string]interface{}{
67+
"type": &t,
68+
}
69+
}
70+
5971
func convertErrors(
6072
conflictRecord map[string]interface{},
6173
conflictIter gocql.Iter,
@@ -74,7 +86,7 @@ func convertErrors(
7486
requestExecutionCASConditions,
7587
)
7688

77-
conflictRecord = make(map[string]interface{})
89+
conflictRecord = newConflictRecord()
7890
for conflictIter.MapScan(conflictRecord) {
7991
if conflictRecord["[applied]"].(bool) {
8092
// Should never happen. All records in batch should have [applied]=false.
@@ -90,7 +102,7 @@ func convertErrors(
90102
requestExecutionCASConditions,
91103
)...)
92104

93-
conflictRecord = make(map[string]interface{})
105+
conflictRecord = newConflictRecord()
94106
}
95107

96108
if len(errors) == 0 {
@@ -120,7 +132,6 @@ func extractErrors(
120132
) []error {
121133

122134
var errors []error
123-
124135
if err := extractShardOwnershipLostError(
125136
conflictRecord,
126137
requestShardID,
@@ -172,12 +183,12 @@ func extractShardOwnershipLostError(
172183
requestShardID int32,
173184
requestRangeID int64,
174185
) error {
175-
rowType, ok := conflictRecord["type"].(int)
176-
if !ok {
177-
// this case should not happen, maybe panic?
186+
rowType, ok := conflictRecord["type"].(*int)
187+
if !ok || rowType == nil {
188+
// This can happen on ScyllaDB.
178189
return nil
179190
}
180-
if rowType != rowTypeShard {
191+
if *rowType != rowTypeShard {
181192
return nil
182193
}
183194

@@ -198,12 +209,12 @@ func extractCurrentWorkflowConflictError(
198209
conflictRecord map[string]interface{},
199210
requestCurrentRunID string,
200211
) error {
201-
rowType, ok := conflictRecord["type"].(int)
202-
if !ok {
203-
// this case should not happen, maybe panic?
212+
rowType, ok := conflictRecord["type"].(*int)
213+
if !ok || rowType == nil {
214+
// This can happen on ScyllaDB.
204215
return nil
205216
}
206-
if rowType != rowTypeExecution {
217+
if *rowType != rowTypeExecution {
207218
return nil
208219
}
209220
if runID := gocql.UUIDToString(conflictRecord["run_id"]); runID != permanentRunID {
@@ -248,12 +259,12 @@ func extractWorkflowConflictError(
248259
requestDBVersion int64,
249260
requestNextEventID int64, // TODO deprecate this variable once DB version comparison is the default
250261
) error {
251-
rowType, ok := conflictRecord["type"].(int)
252-
if !ok {
253-
// this case should not happen, maybe panic?
262+
rowType, ok := conflictRecord["type"].(*int)
263+
if !ok || rowType == nil {
264+
// This can happen on ScyllaDB.
254265
return nil
255266
}
256-
if rowType != rowTypeExecution {
267+
if *rowType != rowTypeExecution {
257268
return nil
258269
}
259270
if runID := gocql.UUIDToString(conflictRecord["run_id"]); runID != requestRunID {

common/persistence/cassandra/errors_test.go

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -164,23 +164,26 @@ func (s *cassandraErrorsSuite) TestExtractShardOwnershipLostError_Failed() {
164164
err := extractShardOwnershipLostError(map[string]interface{}{}, rand.Int31(), rangeID)
165165
s.NoError(err)
166166

167+
t := rowTypeExecution
167168
err = extractShardOwnershipLostError(map[string]interface{}{
168-
"type": rowTypeExecution,
169+
"type": &t,
169170
"range_id": rangeID,
170171
}, rand.Int31(), rangeID)
171172
s.NoError(err)
172173

174+
t = rowTypeShard
173175
err = extractShardOwnershipLostError(map[string]interface{}{
174-
"type": rowTypeShard,
176+
"type": &t,
175177
"range_id": rangeID,
176178
}, rand.Int31(), rangeID)
177179
s.NoError(err)
178180
}
179181

180182
func (s *cassandraErrorsSuite) TestExtractShardOwnershipLostError_Success() {
181183
rangeID := int64(1234)
184+
t := rowTypeShard
182185
record := map[string]interface{}{
183-
"type": rowTypeShard,
186+
"type": &t,
184187
"range_id": rangeID,
185188
}
186189

@@ -195,22 +198,25 @@ func (s *cassandraErrorsSuite) TestExtractCurrentWorkflowConflictError_Failed()
195198
err := extractCurrentWorkflowConflictError(map[string]interface{}{}, uuid.New().String())
196199
s.NoError(err)
197200

201+
t := rowTypeShard
198202
err = extractCurrentWorkflowConflictError(map[string]interface{}{
199-
"type": rowTypeShard,
203+
"type": &t,
200204
"run_id": gocql.UUID(runID),
201205
"current_run_id": gocql.UUID(currentRunID),
202206
}, uuid.New().String())
203207
s.NoError(err)
204208

209+
t = rowTypeExecution
205210
err = extractCurrentWorkflowConflictError(map[string]interface{}{
206-
"type": rowTypeExecution,
211+
"type": &t,
207212
"run_id": gocql.UUID([16]byte{}),
208213
"current_run_id": gocql.UUID(currentRunID),
209214
}, uuid.New().String())
210215
s.NoError(err)
211216

217+
t = rowTypeExecution
212218
err = extractCurrentWorkflowConflictError(map[string]interface{}{
213-
"type": rowTypeExecution,
219+
"type": &t,
214220
"run_id": gocql.UUID(runID),
215221
"current_run_id": gocql.UUID(currentRunID),
216222
}, currentRunID.String())
@@ -223,8 +229,9 @@ func (s *cassandraErrorsSuite) TestExtractCurrentWorkflowConflictError_Success()
223229
workflowState := &persistencespb.WorkflowExecutionState{}
224230
blob, err := serialization.WorkflowExecutionStateToBlob(workflowState)
225231
s.NoError(err)
232+
t := rowTypeExecution
226233
record := map[string]interface{}{
227-
"type": rowTypeExecution,
234+
"type": &t,
228235
"run_id": gocql.UUID(runID),
229236
"current_run_id": gocql.UUID(currentRunID),
230237
"execution_state": blob.Data,
@@ -243,22 +250,25 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed() {
243250
err := extractWorkflowConflictError(map[string]interface{}{}, runID.String(), dbVersion, rand.Int63())
244251
s.NoError(err)
245252

253+
t := rowTypeShard
246254
err = extractWorkflowConflictError(map[string]interface{}{
247-
"type": rowTypeShard,
255+
"type": &t,
248256
"run_id": gocql.UUID(runID),
249257
"db_record_version": dbVersion,
250258
}, runID.String(), dbVersion+1, rand.Int63())
251259
s.NoError(err)
252260

261+
t = rowTypeExecution
253262
err = extractWorkflowConflictError(map[string]interface{}{
254-
"type": rowTypeExecution,
263+
"type": &t,
255264
"run_id": gocql.UUID([16]byte{}),
256265
"db_record_version": dbVersion,
257266
}, runID.String(), dbVersion+1, rand.Int63())
258267
s.NoError(err)
259268

269+
t = rowTypeExecution
260270
err = extractWorkflowConflictError(map[string]interface{}{
261-
"type": rowTypeExecution,
271+
"type": &t,
262272
"run_id": gocql.UUID(runID),
263273
"db_record_version": dbVersion,
264274
}, runID.String(), dbVersion, rand.Int63())
@@ -268,8 +278,9 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed() {
268278
func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Success() {
269279
runID := uuid.New()
270280
dbVersion := rand.Int63() + 1
281+
t := rowTypeExecution
271282
record := map[string]interface{}{
272-
"type": rowTypeExecution,
283+
"type": &t,
273284
"run_id": gocql.UUID(runID),
274285
"db_record_version": dbVersion,
275286
}
@@ -286,22 +297,25 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed_NextEvent
286297
err := extractWorkflowConflictError(map[string]interface{}{}, runID.String(), 0, nextEventID)
287298
s.NoError(err)
288299

300+
t := rowTypeShard
289301
err = extractWorkflowConflictError(map[string]interface{}{
290-
"type": rowTypeShard,
302+
"type": &t,
291303
"run_id": gocql.UUID(runID),
292304
"next_event_id": nextEventID + 1,
293305
}, runID.String(), 0, nextEventID)
294306
s.NoError(err)
295307

308+
t = rowTypeExecution
296309
err = extractWorkflowConflictError(map[string]interface{}{
297-
"type": rowTypeExecution,
310+
"type": &t,
298311
"run_id": gocql.UUID([16]byte{}),
299312
"next_event_id": nextEventID + 1,
300313
}, runID.String(), 0, nextEventID)
301314
s.NoError(err)
302315

316+
t = rowTypeExecution
303317
err = extractWorkflowConflictError(map[string]interface{}{
304-
"type": rowTypeExecution,
318+
"type": &t,
305319
"run_id": gocql.UUID(runID),
306320
"next_event_id": nextEventID,
307321
}, runID.String(), 0, nextEventID)
@@ -312,8 +326,9 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed_NextEvent
312326
func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Success_NextEventID() {
313327
runID := uuid.New()
314328
nextEventID := int64(1234)
329+
t := rowTypeExecution
315330
record := map[string]interface{}{
316-
"type": rowTypeExecution,
331+
"type": &t,
317332
"run_id": gocql.UUID(runID),
318333
"next_event_id": nextEventID,
319334
}

common/persistence/cassandra/mutable_state_store.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ func (d *MutableStateStore) CreateWorkflowExecution(
453453
request.RangeID,
454454
)
455455

456-
conflictRecord := make(map[string]interface{})
456+
conflictRecord := newConflictRecord()
457457
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
458458
if err != nil {
459459
return nil, gocql.ConvertError("CreateWorkflowExecution", err)
@@ -677,7 +677,7 @@ func (d *MutableStateStore) UpdateWorkflowExecution(
677677
request.RangeID,
678678
)
679679

680-
conflictRecord := make(map[string]interface{})
680+
conflictRecord := newConflictRecord()
681681
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
682682
if err != nil {
683683
return gocql.ConvertError("UpdateWorkflowExecution", err)
@@ -828,7 +828,7 @@ func (d *MutableStateStore) ConflictResolveWorkflowExecution(
828828
request.RangeID,
829829
)
830830

831-
conflictRecord := make(map[string]interface{})
831+
conflictRecord := newConflictRecord()
832832
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
833833
if err != nil {
834834
return gocql.ConvertError("ConflictResolveWorkflowExecution", err)
@@ -998,7 +998,7 @@ func (d *MutableStateStore) SetWorkflowExecution(
998998
request.RangeID,
999999
)
10001000

1001-
conflictRecord := make(map[string]interface{})
1001+
conflictRecord := newConflictRecord()
10021002
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
10031003
if err != nil {
10041004
return gocql.ConvertError("SetWorkflowExecution", err)

0 commit comments

Comments
 (0)