Skip to content

Commit 7f223e8

Browse files
authored
fix: Retrier and Catcher passed to constructor for Task, Parallel and Map states are not added to the state's Retriers and Catchers (#169)
1 parent 58bed5d commit 7f223e8

File tree

3 files changed

+239
-64
lines changed

3 files changed

+239
-64
lines changed

src/stepfunctions/steps/states.py

Lines changed: 37 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -254,27 +254,29 @@ def accept(self, visitor):
254254

255255
def add_retry(self, retry):
256256
"""
257-
Add a Retry block to the tail end of the list of retriers for the state.
257+
Add a retrier or a list of retriers to the tail end of the list of retriers for the state.
258+
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
258259
259260
Args:
260-
retry (Retry): Retry block to add.
261+
retry (Retry or list(Retry)): A retrier or list of retriers to add.
261262
"""
262263
if Field.Retry in self.allowed_fields():
263-
self.retries.append(retry)
264+
self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry)
264265
else:
265-
raise ValueError("{state_type} state does not support retry field. ".format(state_type=type(self).__name__))
266+
raise ValueError(f"{type(self).__name__} state does not support retry field. ")
266267

267268
def add_catch(self, catch):
268269
"""
269-
Add a Catch block to the tail end of the list of catchers for the state.
270+
Add a catcher or a list of catchers to the tail end of the list of catchers for the state.
271+
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
270272
271273
Args:
272-
catch (Catch): Catch block to add.
274+
catch (Catch or list(Catch)): A catcher or list of catchers to add.
273275
"""
274276
if Field.Catch in self.allowed_fields():
275-
self.catches.append(catch)
277+
self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch)
276278
else:
277-
raise ValueError("{state_type} state does not support catch field. ".format(state_type=type(self).__name__))
279+
raise ValueError(f"{type(self).__name__} state does not support catch field. ")
278280

279281
def to_dict(self):
280282
result = super(State, self).to_dict()
@@ -487,10 +489,12 @@ class Parallel(State):
487489
A Parallel state causes the interpreter to execute each branch as concurrently as possible, and wait until each branch terminates (reaches a terminal state) before processing the next state in the Chain.
488490
"""
489491

490-
def __init__(self, state_id, **kwargs):
492+
def __init__(self, state_id, retry=None, catch=None, **kwargs):
491493
"""
492494
Args:
493495
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
496+
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
497+
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
494498
comment (str, optional): Human-readable comment or description. (default: None)
495499
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
496500
parameters (dict, optional): The value of this field becomes the effective input for the state.
@@ -500,6 +504,12 @@ def __init__(self, state_id, **kwargs):
500504
super(Parallel, self).__init__(state_id, 'Parallel', **kwargs)
501505
self.branches = []
502506

507+
if retry:
508+
self.add_retry(retry)
509+
510+
if catch:
511+
self.add_catch(catch)
512+
503513
def allowed_fields(self):
504514
return [
505515
Field.Comment,
@@ -536,11 +546,13 @@ class Map(State):
536546
A Map state can accept an input with a list of items, execute a state or chain for each item in the list, and return a list, with all corresponding results of each execution, as its output.
537547
"""
538548

539-
def __init__(self, state_id, **kwargs):
549+
def __init__(self, state_id, retry=None, catch=None, **kwargs):
540550
"""
541551
Args:
542552
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
543553
iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
554+
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
555+
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
544556
items_path (str, optional): Path in the input for items to iterate over. (default: '$')
545557
max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
546558
comment (str, optional): Human-readable comment or description. (default: None)
@@ -551,6 +563,12 @@ def __init__(self, state_id, **kwargs):
551563
"""
552564
super(Map, self).__init__(state_id, 'Map', **kwargs)
553565

566+
if retry:
567+
self.add_retry(retry)
568+
569+
if catch:
570+
self.add_catch(catch)
571+
554572
def attach_iterator(self, iterator):
555573
"""
556574
Attach `State` or `Chain` as iterator to the Map state, that will execute for each of the items in `items_path`. If an iterator was attached previously with the Map state, it will be replaced.
@@ -586,10 +604,12 @@ class Task(State):
586604
Task State causes the interpreter to execute the work identified by the state’s `resource` field.
587605
"""
588606

589-
def __init__(self, state_id, **kwargs):
607+
def __init__(self, state_id, retry=None, catch=None, **kwargs):
590608
"""
591609
Args:
592610
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
611+
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
612+
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
593613
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
594614
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
595615
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.
@@ -608,6 +628,12 @@ def __init__(self, state_id, **kwargs):
608628
if self.heartbeat_seconds is not None and self.heartbeat_seconds_path is not None:
609629
raise ValueError("Only one of 'heartbeat_seconds' or 'heartbeat_seconds_path' can be provided.")
610630

631+
if retry:
632+
self.add_retry(retry)
633+
634+
if catch:
635+
self.add_catch(catch)
636+
611637
def allowed_fields(self):
612638
return [
613639
Field.Comment,

tests/integ/test_state_machine_definition.py

Lines changed: 79 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -422,18 +422,38 @@ def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_para
422422

423423
def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
424424
catch_state_name = "TaskWithCatchState"
425-
custom_error = "CustomError"
426425
task_failed_error = "States.TaskFailed"
427-
all_fail_error = "States.ALL"
428-
custom_error_state_name = "Custom Error End"
429-
task_failed_state_name = "Task Failed End"
430-
all_error_state_name = "Catch All End"
426+
timeout_error = "States.Timeout"
427+
task_failed_state_name = "Catch Task Failed End"
428+
timeout_state_name = "Catch Timeout End"
431429
catch_state_result = "Catch Result"
432430
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"
433431

434-
# change the parameters to cause task state to fail
432+
# Provide invalid TrainingImage to cause States.TaskFailed error
435433
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"
436434

435+
task = steps.Task(
436+
catch_state_name,
437+
parameters=training_job_parameters,
438+
resource=task_resource,
439+
catch=steps.Catch(
440+
error_equals=[timeout_error],
441+
next_step=steps.Pass(timeout_state_name, result=catch_state_result)
442+
)
443+
)
444+
task.add_catch(
445+
steps.Catch(
446+
error_equals=[task_failed_error],
447+
next_step=steps.Pass(task_failed_state_name, result=catch_state_result)
448+
)
449+
)
450+
451+
workflow = Workflow(
452+
unique_name_from_base('Test_Catch_Workflow'),
453+
definition=task,
454+
role=sfn_role_arn
455+
)
456+
437457
asl_state_machine_definition = {
438458
"StartAt": catch_state_name,
439459
"States": {
@@ -445,80 +465,61 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
445465
"Catch": [
446466
{
447467
"ErrorEquals": [
448-
all_fail_error
468+
timeout_error
449469
],
450-
"Next": all_error_state_name
470+
"Next": timeout_state_name
471+
},
472+
{
473+
"ErrorEquals": [
474+
task_failed_error
475+
],
476+
"Next": task_failed_state_name
451477
}
452478
]
453479
},
454-
all_error_state_name: {
480+
task_failed_state_name: {
455481
"Type": "Pass",
456482
"Result": catch_state_result,
457483
"End": True
458-
}
484+
},
485+
timeout_state_name: {
486+
"Type": "Pass",
487+
"Result": catch_state_result,
488+
"End": True
489+
},
459490
}
460491
}
461-
task = steps.Task(
462-
catch_state_name,
463-
parameters=training_job_parameters,
464-
resource=task_resource
465-
)
466-
task.add_catch(
467-
steps.Catch(
468-
error_equals=[all_fail_error],
469-
next_step=steps.Pass(all_error_state_name, result=catch_state_result)
470-
)
471-
)
472-
473-
workflow = Workflow(
474-
unique_name_from_base('Test_Catch_Workflow'),
475-
definition=task,
476-
role=sfn_role_arn
477-
)
478492

479493
workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, catch_state_result)
480494

481495

482496
def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters):
483497
retry_state_name = "RetryStateName"
484-
all_fail_error = "Starts.ALL"
498+
task_failed_error = "States.TaskFailed"
499+
timeout_error = "States.Timeout"
485500
interval_seconds = 1
486501
max_attempts = 2
487502
backoff_rate = 2
488503
task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync"
489504

490-
# change the parameters to cause task state to fail
505+
# Provide invalid TrainingImage to cause States.TaskFailed error
491506
training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image"
492507

493-
asl_state_machine_definition = {
494-
"StartAt": retry_state_name,
495-
"States": {
496-
retry_state_name: {
497-
"Resource": task_resource,
498-
"Parameters": training_job_parameters,
499-
"Type": "Task",
500-
"End": True,
501-
"Retry": [
502-
{
503-
"ErrorEquals": [all_fail_error],
504-
"IntervalSeconds": interval_seconds,
505-
"MaxAttempts": max_attempts,
506-
"BackoffRate": backoff_rate
507-
}
508-
]
509-
}
510-
}
511-
}
512-
513508
task = steps.Task(
514509
retry_state_name,
515510
parameters=training_job_parameters,
516-
resource=task_resource
511+
resource=task_resource,
512+
retry=steps.Retry(
513+
error_equals=[timeout_error],
514+
interval_seconds=interval_seconds,
515+
max_attempts=max_attempts,
516+
backoff_rate=backoff_rate
517+
)
517518
)
518519

519520
task.add_retry(
520521
steps.Retry(
521-
error_equals=[all_fail_error],
522+
error_equals=[task_failed_error],
522523
interval_seconds=interval_seconds,
523524
max_attempts=max_attempts,
524525
backoff_rate=backoff_rate
@@ -531,4 +532,30 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
531532
role=sfn_role_arn
532533
)
533534

534-
workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)
535+
asl_state_machine_definition = {
536+
"StartAt": retry_state_name,
537+
"States": {
538+
retry_state_name: {
539+
"Resource": task_resource,
540+
"Parameters": training_job_parameters,
541+
"Type": "Task",
542+
"End": True,
543+
"Retry": [
544+
{
545+
"ErrorEquals": [timeout_error],
546+
"IntervalSeconds": interval_seconds,
547+
"MaxAttempts": max_attempts,
548+
"BackoffRate": backoff_rate
549+
},
550+
{
551+
"ErrorEquals": [task_failed_error],
552+
"IntervalSeconds": interval_seconds,
553+
"MaxAttempts": max_attempts,
554+
"BackoffRate": backoff_rate
555+
}
556+
]
557+
}
558+
}
559+
}
560+
561+
workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, None)

0 commit comments

Comments
 (0)