Skip to content

Commit abf2656

Browse files
committed
Write unit test for each state and update docstrings to use ASL retrier and catcher terms
1 parent 30e8da9 commit abf2656

File tree

2 files changed

+111
-49
lines changed

2 files changed

+111
-49
lines changed

src/stepfunctions/steps/states.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,11 @@ def accept(self, visitor):
254254

255255
def add_retry(self, retry):
256256
"""
257-
Add a Retry block or a list of Retry blocks to the tail end of the list of retriers for the state.
257+
Add a retrier or a list of retriers to the tail end of the list of retriers for the state.
258+
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
258259
259260
Args:
260-
retry (Retry or list(Retry)): Retry block(s) to add.
261+
retry (Retry or list(Retry)): A retrier or list of retriers to add.
261262
"""
262263
if Field.Retry in self.allowed_fields():
263264
self.retries.extend(retry) if isinstance(retry, list) else self.retries.append(retry)
@@ -266,10 +267,11 @@ def add_retry(self, retry):
266267

267268
def add_catch(self, catch):
268269
"""
269-
Add a Catch block or a list of Catch blocks to the tail end of the list of catchers for the state.
270+
Add a catcher or a list of catchers to the tail end of the list of catchers for the state.
271+
See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
270272
271273
Args:
272-
catch (Catch or list(Catch): Catch block(s) to add.
274+
catch (Catch or list(Catch): catcher or list of catchers to add.
273275
"""
274276
if Field.Catch in self.allowed_fields():
275277
self.catches.extend(catch) if isinstance(catch, list) else self.catches.append(catch)
@@ -491,8 +493,8 @@ def __init__(self, state_id, retry=None, catch=None, **kwargs):
491493
"""
492494
Args:
493495
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
494-
retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors
495-
catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined
496+
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
497+
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
496498
comment (str, optional): Human-readable comment or description. (default: None)
497499
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
498500
parameters (dict, optional): The value of this field becomes the effective input for the state.
@@ -549,8 +551,8 @@ def __init__(self, state_id, retry=None, catch=None, **kwargs):
549551
Args:
550552
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
551553
iterator (State or Chain): State or chain to execute for each of the items in `items_path`.
552-
retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors
553-
catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined
554+
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
555+
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
554556
items_path (str, optional): Path in the input for items to iterate over. (default: '$')
555557
max_concurrency (int, optional): Maximum number of iterations to have running at any given point in time. (default: 0)
556558
comment (str, optional): Human-readable comment or description. (default: None)
@@ -606,8 +608,8 @@ def __init__(self, state_id, retry=None, catch=None, **kwargs):
606608
"""
607609
Args:
608610
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
609-
retry (Retry or list(Retry), optional): Retry block(s) to add to list of Retriers that define a retry policy in case the state encounters runtime errors
610-
catch (Catch or list(Catch), optional): Catch block(s) to add to list of Catchers that define a fallback state that is executed if the state encounters runtime errors and its retry policy is exhausted or isn't defined
611+
retry (Retry or list(Retry), optional): A retrier or list of retriers that define the state's retry policy. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-retrying-after-an-error>`_ for more details.
612+
catch (Catch or list(Catch), optional): A catcher or list of catchers that define a fallback state. See `Error handling in Step Functions <https://docs.aws.amazon.com/step-functions/latest/dg/concepts-error-handling.html#error-handling-fallback-states>`_ for more details.
611613
resource (str): A URI that uniquely identifies the specific task to execute. The States language does not constrain the URI scheme nor any other part of the URI.
612614
timeout_seconds (int, optional): Positive integer specifying timeout for the state in seconds. If the state runs longer than the specified timeout, then the interpreter fails the state with a `States.Timeout` Error Name. (default: 60)
613615
timeout_seconds_path (str, optional): Path specifying the state's timeout value in seconds from the state input. When resolved, the path must select a field whose value is a positive integer.

tests/unit/test_steps.py

Lines changed: 99 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -474,61 +474,121 @@ def test_default_paths_not_converted_to_null():
474474
EXPECTED_RETRY = [{'ErrorEquals': ['ErrorA', 'ErrorB'], 'IntervalSeconds': 1, 'BackoffRate': 2, 'MaxAttempts': 2}]
475475
EXPECTED_RETRIES = EXPECTED_RETRY + [{'ErrorEquals': ['ErrorC'], 'IntervalSeconds': 5}]
476476

477+
478+
@pytest.mark.parametrize("retry, expected_retry", [
479+
(RETRY, EXPECTED_RETRY),
480+
(RETRIES, EXPECTED_RETRIES),
481+
])
482+
def test_parallel_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry):
483+
step = Parallel('Parallel', retry=retry)
484+
assert step.to_dict()['Retry'] == expected_retry
485+
486+
487+
@pytest.mark.parametrize("retry, expected_retry", [
488+
(RETRY, EXPECTED_RETRY),
489+
(RETRIES, EXPECTED_RETRIES),
490+
])
491+
def test_parallel_state_add_retry_adds_retrier_to_retriers(retry, expected_retry):
492+
step = Parallel('Parallel')
493+
step.add_retry(retry)
494+
assert step.to_dict()['Retry'] == expected_retry
495+
496+
497+
@pytest.mark.parametrize("retry, expected_retry", [
498+
(RETRY, EXPECTED_RETRY),
499+
(RETRIES, EXPECTED_RETRIES),
500+
])
501+
def test_map_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry):
502+
step = Map('Map', retry=retry, iterator=Pass('Iterator'))
503+
assert step.to_dict()['Retry'] == expected_retry
504+
505+
506+
@pytest.mark.parametrize("retry, expected_retry", [
507+
(RETRIES, EXPECTED_RETRIES),
508+
(RETRY, EXPECTED_RETRY),
509+
])
510+
def test_map_state_add_retry_adds_retrier_to_retriers(retry, expected_retry):
511+
step = Map('Map', iterator=Pass('Iterator'))
512+
step.add_retry(retry)
513+
assert step.to_dict()['Retry'] == expected_retry
514+
515+
516+
@pytest.mark.parametrize("retry, expected_retry", [
517+
(RETRY, EXPECTED_RETRY),
518+
(RETRIES, EXPECTED_RETRIES)
519+
])
520+
def test_task_state_constructor_with_retry_adds_retrier_to_retriers(retry, expected_retry):
521+
step = Task('Task', retry=retry)
522+
assert step.to_dict()['Retry'] == expected_retry
523+
524+
525+
@pytest.mark.parametrize("retry, expected_retry", [
526+
(RETRY, EXPECTED_RETRY),
527+
(RETRIES, EXPECTED_RETRIES)
528+
])
529+
def test_task_state_add_retry_adds_retrier_to_retriers(retry, expected_retry):
530+
step = Task('Task')
531+
step.add_retry(retry)
532+
assert step.to_dict()['Retry'] == expected_retry
533+
534+
477535
CATCH = Catch(error_equals=['States.ALL'], next_step=Pass('End State'))
478536
CATCHES = [CATCH, Catch(error_equals=['States.TaskFailed'], next_step=Pass('Next State'))]
479537
EXPECTED_CATCH = [{'ErrorEquals': ['States.ALL'], 'Next': 'End State'}]
480538
EXPECTED_CATCHES = EXPECTED_CATCH + [{'ErrorEquals': ['States.TaskFailed'], 'Next': 'Next State'}]
481539

482540

483-
@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [
484-
(Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY),
485-
(Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES),
486-
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY),
487-
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES),
488-
(Task, 'Task', {}, RETRY, EXPECTED_RETRY),
489-
(Task, 'Task', {}, RETRIES, EXPECTED_RETRIES)
541+
@pytest.mark.parametrize("catch, expected_catch", [
542+
(CATCH, EXPECTED_CATCH),
543+
(CATCHES, EXPECTED_CATCHES)
490544
])
491-
def test_state_creation_with_retry(state, state_id, extra_args, retry, expected_retry):
492-
step = state(state_id, retry=retry, **extra_args)
493-
assert step.to_dict()['Retry'] == expected_retry
545+
def test_parallel_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch):
546+
step = Parallel('Parallel', catch=catch)
547+
assert step.to_dict()['Catch'] == expected_catch
548+
549+
@pytest.mark.parametrize("catch, expected_catch", [
550+
(CATCH, EXPECTED_CATCH),
551+
(CATCHES, EXPECTED_CATCHES)
552+
])
553+
def test_parallel_state_add_catch_adds_catcher_to_catchers(catch, expected_catch):
554+
step = Parallel('Parallel')
555+
step.add_catch(catch)
556+
assert step.to_dict()['Catch'] == expected_catch
494557

495558

496-
@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [
497-
(Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH),
498-
(Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES),
499-
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH),
500-
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES),
501-
(Task, 'Task', {}, CATCH, EXPECTED_CATCH),
502-
(Task, 'Task', {}, CATCHES, EXPECTED_CATCHES)
559+
@pytest.mark.parametrize("catch, expected_catch", [
560+
(CATCH, EXPECTED_CATCH),
561+
(CATCHES, EXPECTED_CATCHES)
503562
])
504-
def test_state_creation_with_catch(state, state_id, extra_args, catch, expected_catch):
505-
step = state(state_id, catch=catch, **extra_args)
563+
def test_map_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch):
564+
step = Map('Map', catch=catch, iterator=Pass('Iterator'))
506565
assert step.to_dict()['Catch'] == expected_catch
507566

508567

509-
@pytest.mark.parametrize("state, state_id, extra_args, retry, expected_retry", [
510-
(Parallel, 'Parallel', {}, RETRY, EXPECTED_RETRY),
511-
(Parallel, 'Parallel', {}, RETRIES, EXPECTED_RETRIES),
512-
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRIES, EXPECTED_RETRIES),
513-
(Map, 'Map', {'iterator': Pass('Iterator')}, RETRY, EXPECTED_RETRY),
514-
(Task, 'Task', {}, RETRY, EXPECTED_RETRY),
515-
(Task, 'Task', {}, RETRIES, EXPECTED_RETRIES)
568+
@pytest.mark.parametrize("catch, expected_catch", [
569+
(CATCH, EXPECTED_CATCH),
570+
(CATCHES, EXPECTED_CATCHES)
516571
])
517-
def test_state_with_added_retry(state, state_id, extra_args, retry, expected_retry):
518-
step = state(state_id, **extra_args)
519-
step.add_retry(retry)
520-
assert step.to_dict()['Retry'] == expected_retry
572+
def test_map_state_add_catch_adds_catcher_to_catchers(catch, expected_catch):
573+
step = Map('Map', iterator=Pass('Iterator'))
574+
step.add_catch(catch)
575+
assert step.to_dict()['Catch'] == expected_catch
576+
577+
578+
@pytest.mark.parametrize("catch, expected_catch", [
579+
(CATCH, EXPECTED_CATCH),
580+
(CATCHES, EXPECTED_CATCHES)
581+
])
582+
def test_task_state_constructor_with_catch_adds_catcher_to_catchers(catch, expected_catch):
583+
step = Task('Task', catch=catch)
584+
assert step.to_dict()['Catch'] == expected_catch
521585

522586

523-
@pytest.mark.parametrize("state, state_id, extra_args, catch, expected_catch", [
524-
(Parallel, 'Parallel', {}, CATCH, EXPECTED_CATCH),
525-
(Parallel, 'Parallel', {}, CATCHES, EXPECTED_CATCHES),
526-
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCH, EXPECTED_CATCH),
527-
(Map, 'Map', {'iterator': Pass('Iterator')}, CATCHES, EXPECTED_CATCHES),
528-
(Task, 'Task', {}, CATCHES, EXPECTED_CATCHES),
529-
(Task, 'Task', {}, CATCHES, EXPECTED_CATCHES)
587+
@pytest.mark.parametrize("catch, expected_catch", [
588+
(CATCH, EXPECTED_CATCH),
589+
(CATCHES, EXPECTED_CATCHES)
530590
])
531-
def test_state_with_added_catch(state, state_id, extra_args, catch, expected_catch):
532-
step = state(state_id, **extra_args)
591+
def test_task_state_add_catch_adds_catcher_to_catchers(catch, expected_catch):
592+
step = Task('Task')
533593
step.add_catch(catch)
534594
assert step.to_dict()['Catch'] == expected_catch

0 commit comments

Comments
 (0)