@@ -479,6 +479,59 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
479
479
workflow_test_suite (sfn_client , workflow , asl_state_machine_definition , catch_state_result )
480
480
481
481
482
+ def test_state_machine_creation_with_catch_in_constructor (sfn_client , sfn_role_arn , training_job_parameters ):
483
+ catch_state_name = "TaskWithCatchState"
484
+ all_fail_error = "States.ALL"
485
+ all_error_state_name = "Catch All End"
486
+ catch_state_result = "Catch Result"
487
+ task_resource = f"arn:{ get_aws_partition ()} :states:::sagemaker:createTrainingJob.sync"
488
+
489
+ # change the parameters to cause task state to fail
490
+ training_job_parameters ["AlgorithmSpecification" ]["TrainingImage" ] = "not_an_image"
491
+
492
+ asl_state_machine_definition = {
493
+ "StartAt" : catch_state_name ,
494
+ "States" : {
495
+ catch_state_name : {
496
+ "Resource" : task_resource ,
497
+ "Parameters" : training_job_parameters ,
498
+ "Type" : "Task" ,
499
+ "End" : True ,
500
+ "Catch" : [
501
+ {
502
+ "ErrorEquals" : [
503
+ all_fail_error
504
+ ],
505
+ "Next" : all_error_state_name
506
+ }
507
+ ]
508
+ },
509
+ all_error_state_name : {
510
+ "Type" : "Pass" ,
511
+ "Result" : catch_state_result ,
512
+ "End" : True
513
+ }
514
+ }
515
+ }
516
+ task = steps .Task (
517
+ catch_state_name ,
518
+ parameters = training_job_parameters ,
519
+ resource = task_resource ,
520
+ catch = steps .Catch (
521
+ error_equals = [all_fail_error ],
522
+ next_step = steps .Pass (all_error_state_name , result = catch_state_result )
523
+ )
524
+ )
525
+
526
+ workflow = Workflow (
527
+ unique_name_from_base ('Test_Catch_In_Constructor_Workflow' ),
528
+ definition = task ,
529
+ role = sfn_role_arn
530
+ )
531
+
532
+ workflow_test_suite (sfn_client , workflow , asl_state_machine_definition , catch_state_result )
533
+
534
+
482
535
def test_retry_state_machine_creation (sfn_client , sfn_role_arn , training_job_parameters ):
483
536
retry_state_name = "RetryStateName"
484
537
all_fail_error = "Starts.ALL"
@@ -531,4 +584,56 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par
531
584
role = sfn_role_arn
532
585
)
533
586
534
- workflow_test_suite (sfn_client , workflow , asl_state_machine_definition , None )
587
+ workflow_test_suite (sfn_client , workflow , asl_state_machine_definition , None )
588
+
589
+
590
+ def test_state_machine_creation_with_retry_in_constructor (sfn_client , sfn_role_arn , training_job_parameters ):
591
+ retry_state_name = "RetryStateName"
592
+ all_fail_error = "Starts.ALL"
593
+ interval_seconds = 1
594
+ max_attempts = 2
595
+ backoff_rate = 2
596
+ task_resource = f"arn:{ get_aws_partition ()} :states:::sagemaker:createTrainingJob.sync"
597
+
598
+ # change the parameters to cause task state to fail
599
+ training_job_parameters ["AlgorithmSpecification" ]["TrainingImage" ] = "not_an_image"
600
+
601
+ asl_state_machine_definition = {
602
+ "StartAt" : retry_state_name ,
603
+ "States" : {
604
+ retry_state_name : {
605
+ "Resource" : task_resource ,
606
+ "Parameters" : training_job_parameters ,
607
+ "Type" : "Task" ,
608
+ "End" : True ,
609
+ "Retry" : [
610
+ {
611
+ "ErrorEquals" : [all_fail_error ],
612
+ "IntervalSeconds" : interval_seconds ,
613
+ "MaxAttempts" : max_attempts ,
614
+ "BackoffRate" : backoff_rate
615
+ }
616
+ ]
617
+ }
618
+ }
619
+ }
620
+
621
+ task = steps .Task (
622
+ retry_state_name ,
623
+ parameters = training_job_parameters ,
624
+ resource = task_resource ,
625
+ retry = steps .Retry (
626
+ error_equals = [all_fail_error ],
627
+ interval_seconds = interval_seconds ,
628
+ max_attempts = max_attempts ,
629
+ backoff_rate = backoff_rate
630
+ )
631
+ )
632
+
633
+ workflow = Workflow (
634
+ unique_name_from_base ('Test_Retry_In_Constructor_Workflow' ),
635
+ definition = task ,
636
+ role = sfn_role_arn
637
+ )
638
+
639
+ workflow_test_suite (sfn_client , workflow , asl_state_machine_definition , None )
0 commit comments