Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions common/persistence/cassandra/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ type (
}
)

// ScyllaDB will return rows with null values to match # of queries in a batch query (see #2683).
// To support null values, fields type should be a pointer to pointer of underlying type (i.e. **int).
// Resulting value will be converted to a pointer of underlying type (i.e. *int) and stored in the map.
// We do it only for "type" field which is checked for `nil` value.
// All other fields are created automatically by gocql with non-pointer types (i.e. int).
func newConflictRecord() map[string]interface{} {
t := new(int)
return map[string]interface{}{
"type": &t,
}
}

func convertErrors(
conflictRecord map[string]interface{},
conflictIter gocql.Iter,
Expand All @@ -74,7 +86,7 @@ func convertErrors(
requestExecutionCASConditions,
)

conflictRecord = make(map[string]interface{})
conflictRecord = newConflictRecord()
for conflictIter.MapScan(conflictRecord) {
if conflictRecord["[applied]"].(bool) {
// Should never happen. All records in batch should have [applied]=false.
Expand All @@ -90,7 +102,7 @@ func convertErrors(
requestExecutionCASConditions,
)...)

conflictRecord = make(map[string]interface{})
conflictRecord = newConflictRecord()
}

if len(errors) == 0 {
Expand Down Expand Up @@ -120,7 +132,6 @@ func extractErrors(
) []error {

var errors []error

if err := extractShardOwnershipLostError(
conflictRecord,
requestShardID,
Expand Down Expand Up @@ -172,12 +183,12 @@ func extractShardOwnershipLostError(
requestShardID int32,
requestRangeID int64,
) error {
rowType, ok := conflictRecord["type"].(int)
if !ok {
// this case should not happen, maybe panic?
rowType, ok := conflictRecord["type"].(*int)
if !ok || rowType == nil {
// This can happen on ScyllaDB.
return nil
}
if rowType != rowTypeShard {
if *rowType != rowTypeShard {
return nil
}

Expand All @@ -198,12 +209,12 @@ func extractCurrentWorkflowConflictError(
conflictRecord map[string]interface{},
requestCurrentRunID string,
) error {
rowType, ok := conflictRecord["type"].(int)
if !ok {
// this case should not happen, maybe panic?
rowType, ok := conflictRecord["type"].(*int)
if !ok || rowType == nil {
// This can happen on ScyllaDB.
return nil
}
if rowType != rowTypeExecution {
if *rowType != rowTypeExecution {
return nil
}
if runID := gocql.UUIDToString(conflictRecord["run_id"]); runID != permanentRunID {
Expand Down Expand Up @@ -248,12 +259,12 @@ func extractWorkflowConflictError(
requestDBVersion int64,
requestNextEventID int64, // TODO deprecate this variable once DB version comparison is the default
) error {
rowType, ok := conflictRecord["type"].(int)
if !ok {
// this case should not happen, maybe panic?
rowType, ok := conflictRecord["type"].(*int)
if !ok || rowType == nil {
// This can happen on ScyllaDB.
return nil
}
if rowType != rowTypeExecution {
if *rowType != rowTypeExecution {
return nil
}
if runID := gocql.UUIDToString(conflictRecord["run_id"]); runID != requestRunID {
Expand Down
45 changes: 30 additions & 15 deletions common/persistence/cassandra/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,26 @@ func (s *cassandraErrorsSuite) TestExtractShardOwnershipLostError_Failed() {
err := extractShardOwnershipLostError(map[string]interface{}{}, rand.Int31(), rangeID)
s.NoError(err)

t := rowTypeExecution
err = extractShardOwnershipLostError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"range_id": rangeID,
}, rand.Int31(), rangeID)
s.NoError(err)

t = rowTypeShard
err = extractShardOwnershipLostError(map[string]interface{}{
"type": rowTypeShard,
"type": &t,
"range_id": rangeID,
}, rand.Int31(), rangeID)
s.NoError(err)
}

func (s *cassandraErrorsSuite) TestExtractShardOwnershipLostError_Success() {
rangeID := int64(1234)
t := rowTypeShard
record := map[string]interface{}{
"type": rowTypeShard,
"type": &t,
"range_id": rangeID,
}

Expand All @@ -195,22 +198,25 @@ func (s *cassandraErrorsSuite) TestExtractCurrentWorkflowConflictError_Failed()
err := extractCurrentWorkflowConflictError(map[string]interface{}{}, uuid.New().String())
s.NoError(err)

t := rowTypeShard
err = extractCurrentWorkflowConflictError(map[string]interface{}{
"type": rowTypeShard,
"type": &t,
"run_id": gocql.UUID(runID),
"current_run_id": gocql.UUID(currentRunID),
}, uuid.New().String())
s.NoError(err)

t = rowTypeExecution
err = extractCurrentWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID([16]byte{}),
"current_run_id": gocql.UUID(currentRunID),
}, uuid.New().String())
s.NoError(err)

t = rowTypeExecution
err = extractCurrentWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID(runID),
"current_run_id": gocql.UUID(currentRunID),
}, currentRunID.String())
Expand All @@ -223,8 +229,9 @@ func (s *cassandraErrorsSuite) TestExtractCurrentWorkflowConflictError_Success()
workflowState := &persistencespb.WorkflowExecutionState{}
blob, err := serialization.WorkflowExecutionStateToBlob(workflowState)
s.NoError(err)
t := rowTypeExecution
record := map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID(runID),
"current_run_id": gocql.UUID(currentRunID),
"execution_state": blob.Data,
Expand All @@ -243,22 +250,25 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed() {
err := extractWorkflowConflictError(map[string]interface{}{}, runID.String(), dbVersion, rand.Int63())
s.NoError(err)

t := rowTypeShard
err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeShard,
"type": &t,
"run_id": gocql.UUID(runID),
"db_record_version": dbVersion,
}, runID.String(), dbVersion+1, rand.Int63())
s.NoError(err)

t = rowTypeExecution
err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID([16]byte{}),
"db_record_version": dbVersion,
}, runID.String(), dbVersion+1, rand.Int63())
s.NoError(err)

t = rowTypeExecution
err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID(runID),
"db_record_version": dbVersion,
}, runID.String(), dbVersion, rand.Int63())
Expand All @@ -268,8 +278,9 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed() {
func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Success() {
runID := uuid.New()
dbVersion := rand.Int63() + 1
t := rowTypeExecution
record := map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID(runID),
"db_record_version": dbVersion,
}
Expand All @@ -286,22 +297,25 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed_NextEvent
err := extractWorkflowConflictError(map[string]interface{}{}, runID.String(), 0, nextEventID)
s.NoError(err)

t := rowTypeShard
err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeShard,
"type": &t,
"run_id": gocql.UUID(runID),
"next_event_id": nextEventID + 1,
}, runID.String(), 0, nextEventID)
s.NoError(err)

t = rowTypeExecution
err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID([16]byte{}),
"next_event_id": nextEventID + 1,
}, runID.String(), 0, nextEventID)
s.NoError(err)

t = rowTypeExecution
err = extractWorkflowConflictError(map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID(runID),
"next_event_id": nextEventID,
}, runID.String(), 0, nextEventID)
Expand All @@ -312,8 +326,9 @@ func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Failed_NextEvent
func (s *cassandraErrorsSuite) TestExtractWorkflowConflictError_Success_NextEventID() {
runID := uuid.New()
nextEventID := int64(1234)
t := rowTypeExecution
record := map[string]interface{}{
"type": rowTypeExecution,
"type": &t,
"run_id": gocql.UUID(runID),
"next_event_id": nextEventID,
}
Expand Down
8 changes: 4 additions & 4 deletions common/persistence/cassandra/mutable_state_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ func (d *MutableStateStore) CreateWorkflowExecution(
request.RangeID,
)

conflictRecord := make(map[string]interface{})
conflictRecord := newConflictRecord()
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return nil, gocql.ConvertError("CreateWorkflowExecution", err)
Expand Down Expand Up @@ -677,7 +677,7 @@ func (d *MutableStateStore) UpdateWorkflowExecution(
request.RangeID,
)

conflictRecord := make(map[string]interface{})
conflictRecord := newConflictRecord()
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return gocql.ConvertError("UpdateWorkflowExecution", err)
Expand Down Expand Up @@ -828,7 +828,7 @@ func (d *MutableStateStore) ConflictResolveWorkflowExecution(
request.RangeID,
)

conflictRecord := make(map[string]interface{})
conflictRecord := newConflictRecord()
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return gocql.ConvertError("ConflictResolveWorkflowExecution", err)
Expand Down Expand Up @@ -998,7 +998,7 @@ func (d *MutableStateStore) SetWorkflowExecution(
request.RangeID,
)

conflictRecord := make(map[string]interface{})
conflictRecord := newConflictRecord()
applied, conflictIter, err := d.Session.MapExecuteBatchCAS(batch, conflictRecord)
if err != nil {
return gocql.ConvertError("SetWorkflowExecution", err)
Expand Down