From 7a3887753b853d2eefddfc2657c260cc3711da67 Mon Sep 17 00:00:00 2001 From: Bowen Yuan Date: Tue, 4 May 2021 20:33:36 -0700 Subject: [PATCH 1/6] fix: make arns of all task resources aws-partition aware --- src/stepfunctions/steps/compute.py | 17 +++--- src/stepfunctions/steps/sagemaker.py | 26 ++++---- src/stepfunctions/steps/service.py | 37 ++++++------ src/stepfunctions/steps/utils.py | 29 ++++++++- tests/integ/conftest.py | 5 +- tests/integ/test_state_machine_definition.py | 7 ++- tests/unit/test_steps.py | 62 +++++++++++++++++--- 7 files changed, 130 insertions(+), 53 deletions(-) diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index da29b4c..871d200 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -14,6 +14,7 @@ from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field +from stepfunctions.steps.utils import get_aws_partition class LambdaStep(Task): @@ -38,9 +39,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke.waitForTaskToken' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke.waitForTaskToken' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke' super(LambdaStep, self).__init__(state_id, **kwargs) @@ -67,9 +68,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun' super(GlueStartJobRunStep, self).__init__(state_id, **kwargs) @@ -96,9 +97,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob' super(BatchSubmitJobStep, self).__init__(state_id, **kwargs) @@ -125,8 +126,8 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask' super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index acc7506..93b826a 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -15,7 +15,7 @@ from stepfunctions.inputs import ExecutionInput, StepInput from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import tags_dict_to_kv_list +from stepfunctions.steps.utils import tags_dict_to_kv_list, get_aws_partition from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel @@ -58,9 +58,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non self.job_name = job_name if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob' if isinstance(job_name, str): parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) @@ -141,9 +141,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob' if isinstance(job_name, str): parameters = transform_config( @@ -225,7 +225,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No parameters['Tags'] = tags_dict_to_kv_list(tags) kwargs[Field.Parameters.value] = parameters - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createModel' super(ModelStep, self).__init__(state_id, **kwargs) @@ -266,7 +266,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpointConfig' kwargs[Field.Parameters.value] = parameters super(EndpointConfigStep, self).__init__(state_id, **kwargs) @@ -298,9 +298,9 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd parameters['Tags'] = tags_dict_to_kv_list(tags) if update: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:updateEndpoint' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpoint' kwargs[Field.Parameters.value] = parameters @@ -338,9 +338,9 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob' parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() @@ -387,9 +387,9 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob' if isinstance(job_name, str): parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) diff --git a/src/stepfunctions/steps/service.py b/src/stepfunctions/steps/service.py index 20509f8..4952f70 100644 --- a/src/stepfunctions/steps/service.py +++ b/src/stepfunctions/steps/service.py @@ -14,6 +14,7 @@ from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field +from stepfunctions.steps.utils import get_aws_partition class DynamoDBGetItemStep(Task): @@ -35,7 +36,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:getItem' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:getItem' super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs) @@ -59,7 +60,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:putItem' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:putItem' super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs) @@ -83,7 +84,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:deleteItem' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:deleteItem' super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs) @@ -107,7 +108,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:updateItem' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:updateItem' super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs) @@ -133,9 +134,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish.waitForTaskToken' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish.waitForTaskToken' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish' super(SnsPublishStep, self).__init__(state_id, **kwargs) @@ -162,9 +163,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage.waitForTaskToken' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage.waitForTaskToken' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage' super(SqsSendMessageStep, self).__init__(state_id, **kwargs) @@ -190,9 +191,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster' super(EmrCreateClusterStep, self).__init__(state_id, **kwargs) @@ -218,9 +219,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster' super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs) @@ -246,9 +247,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep.sync' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep.sync' else: - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep' super(EmrAddStepStep, self).__init__(state_id, **kwargs) @@ -272,7 +273,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:cancelStep' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:cancelStep' super(EmrCancelStepStep, self).__init__(state_id, **kwargs) @@ -296,7 +297,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:setClusterTerminationProtection' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:setClusterTerminationProtection' super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs) @@ -320,7 +321,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceFleetByName' super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs) @@ -344,7 +345,7 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName' + kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceGroupByName' super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 8e35d36..dff22f4 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -12,6 +12,31 @@ # permissions and limitations under the License. from __future__ import absolute_import +import boto3 +import logging + +logger = logging.getLogger('stepfunctions') + + def tags_dict_to_kv_list(tags_dict): - kv_list = [{"Key": k, "Value": v} for k,v in tags_dict.items()] - return kv_list \ No newline at end of file + kv_list = [{"Key": k, "Value": v} for k, v in tags_dict.items()] + return kv_list + + +# Obtain matching aws partition name based on region +def get_aws_partition(): + partitions = boto3.session.Session().get_available_partitions() + cur_region = boto3.session.Session().region_name + cur_partition = "aws" + + if cur_region is None: + logger.warning("No region detected for the session, will use default partition: aws") + return cur_partition + + for partition in partitions: + regions = boto3.session.Session().get_available_regions("stepfunctions", partition) + if cur_region in regions: + cur_partition = partition + return cur_partition + + return cur_partition diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index cb3bee0..aa0f7bc 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -21,6 +21,7 @@ from sagemaker import Session from sagemaker.amazon import pca from sagemaker.sklearn.processing import SKLearnProcessor +from stepfunctions.steps.utils import get_aws_partition from tests.integ import DATA_DIR @pytest.fixture(scope="session") @@ -43,11 +44,11 @@ def aws_account_id(): @pytest.fixture(scope="session") def sfn_role_arn(aws_account_id): - return "arn:aws:iam::{}:role/StepFunctionsMLWorkflowExecutionFullAccess".format(aws_account_id) + return "arn:" + get_aws_partition() + ":iam::{}:role/StepFunctionsMLWorkflowExecutionFullAccess".format(aws_account_id) @pytest.fixture(scope="session") def sagemaker_role_arn(aws_account_id): - return "arn:aws:iam::{}:role/SageMakerRole".format(aws_account_id) + return "arn:" + get_aws_partition() + ":iam::{}:role/SageMakerRole".format(aws_account_id) @pytest.fixture(scope="session") def pca_estimator_fixture(sagemaker_role_arn): diff --git a/tests/integ/test_state_machine_definition.py b/tests/integ/test_state_machine_definition.py index ac5e7ac..fe79b8a 100644 --- a/tests/integ/test_state_machine_definition.py +++ b/tests/integ/test_state_machine_definition.py @@ -19,6 +19,7 @@ from sagemaker.image_uris import retrieve from stepfunctions import steps from stepfunctions.workflow import Workflow +from stepfunctions.steps.utils import get_aws_partition from tests.integ.utils import state_machine_delete_wait @pytest.fixture(scope="module") @@ -380,7 +381,7 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn): def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): task_state_name = "TaskState" final_state_name = "FinalState" - resource = "arn:aws:states:::sagemaker:createTrainingJob.sync" + resource = "arn:" + get_aws_partition() + ":states:::sagemaker:createTrainingJob.sync" task_state_result = "Task State Result" asl_state_machine_definition = { "StartAt": task_state_name, @@ -426,7 +427,7 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par task_failed_state_name = "Task Failed End" all_error_state_name = "Catch All End" catch_state_result = "Catch Result" - task_resource = "arn:aws:states:::sagemaker:createTrainingJob.sync" + task_resource = "arn:" + get_aws_partition() + ":states:::sagemaker:createTrainingJob.sync" # change the parameters to cause task state to fail training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" @@ -482,7 +483,7 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par interval_seconds = 1 max_attempts = 2 backoff_rate = 2 - task_resource = "arn:aws:states:::sagemaker:createTrainingJob.sync" + task_resource = "arn:" + get_aws_partition() + ":states:::sagemaker:createTrainingJob.sync" # change the parameters to cause task state to fail training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index 32b118f..d3886ab 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -13,24 +13,28 @@ from __future__ import absolute_import import pytest +import boto3 +import os from stepfunctions.exceptions import DuplicateStatesInChain -from stepfunctions.steps import Pass, Succeed, Fail, Wait, Choice, ChoiceRule, Parallel, Map, Task, Retry, Catch, Chain +from stepfunctions.steps import Pass, Succeed, Fail, Wait, Choice, ChoiceRule, Parallel, Map, Task, Retry, Catch, Chain, \ + utils from stepfunctions.steps.states import State, to_pascalcase def test_to_pascalcase(): assert 'InputPath' == to_pascalcase('input_path') + def test_state_creation(): state = State( state_id='StartState', state_type='Void', - comment = 'This is a comment', - input_path = '$.Input', - output_path = '$.Output', - parameters = {'Key': 'Value'}, - result_path = '$.Result' + comment='This is a comment', + input_path='$.Input', + output_path='$.Output', + parameters={'Key': 'Value'}, + result_path='$.Result' ) assert state.to_dict() == { @@ -48,6 +52,7 @@ def test_state_creation(): with pytest.raises(TypeError): State(state_id='State', unknown_attribute=True) + def test_pass_state_creation(): pass_state = Pass('Pass', result='Pass') assert pass_state.state_id == 'Pass' @@ -57,6 +62,7 @@ def test_pass_state_creation(): 'End': True } + def test_verify_pass_state_fields(): pass_state = Pass( state_id='Pass', @@ -76,6 +82,7 @@ def test_verify_pass_state_fields(): with pytest.raises(TypeError): Pass('Pass', unknown_field='Unknown Field') + def test_succeed_state_creation(): succeed_state = Succeed( state_id='Succeed', @@ -88,10 +95,12 @@ def test_succeed_state_creation(): 'Comment': 'This is a comment' } + def test_verify_succeed_state_fields(): with pytest.raises(TypeError): Succeed('Succeed', unknown_field='Unknown Field') + def test_fail_creation(): fail_state = Fail( state_id='Fail', @@ -110,10 +119,12 @@ def test_fail_creation(): 'Cause': 'Kaiju attack' } + def test_verify_fail_state_fields(): with pytest.raises(TypeError): Fail('Succeed', unknown_field='Unknown Field') + def test_wait_state_creation(): wait_state = Wait( state_id='Wait', @@ -139,6 +150,7 @@ def test_wait_state_creation(): 'End': True } + def test_verify_wait_state_fields(): with pytest.raises(ValueError): Wait( @@ -147,6 +159,7 @@ def test_verify_wait_state_fields(): seconds_path='$.SecondsPath' ) + def test_choice_state_creation(): choice_state = Choice('Choice', input_path='$.Input') choice_state.add_choice(ChoiceRule.IsPresent("$.StringVariable1", True), Pass("End State 1")) @@ -182,6 +195,7 @@ def test_choice_state_creation(): with pytest.raises(TypeError): Choice('Choice', unknown_field='Unknown Field') + def test_task_state_creation(): task_state = Task('Task', resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda') task_state.add_retry(Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2)) @@ -214,6 +228,7 @@ def test_task_state_creation(): 'End': True } + def test_task_state_creation_with_dynamic_timeout(): task_state = Task( 'Task', @@ -229,6 +244,7 @@ def test_task_state_creation_with_dynamic_timeout(): 'End': True } + def test_task_state_create_fail_for_duplicated_dynamic_timeout_fields(): with pytest.raises(ValueError): Task( @@ -246,6 +262,7 @@ def test_task_state_create_fail_for_duplicated_dynamic_timeout_fields(): heartbeat_seconds_path='$.heartbeat', ) + def test_parallel_state_creation(): parallel_state = Parallel('Parallel') parallel_state.add_branch(Pass('Branch 1')) @@ -287,6 +304,7 @@ def test_parallel_state_creation(): 'End': True } + def test_map_state_creation(): map_state = Map('Map', iterator=Pass('FirstIteratorState'), items_path='$', max_concurrency=0) assert map_state.to_dict() == { @@ -305,9 +323,11 @@ def test_map_state_creation(): 'End': True } + def test_nested_chain_is_now_allowed(): chain = Chain([Chain([Pass('S1')])]) + def test_catch_creation(): catch = Catch(error_equals=['States.ALL'], next_step=Fail('End')) assert catch.to_dict() == { @@ -315,6 +335,7 @@ def test_catch_creation(): 'Next': 'End' } + def test_append_states_after_terminal_state_will_fail(): with pytest.raises(ValueError): chain = Chain() @@ -372,6 +393,7 @@ def test_chaining_steps(): assert s1.next_step == s2 assert s2.next_step == s3 + def test_catch_fail_for_unsupported_state(): s1 = Pass('Step - One') @@ -380,7 +402,6 @@ def test_catch_fail_for_unsupported_state(): def test_retry_fail_for_unsupported_state(): - c1 = Choice('My Choice') with pytest.raises(ValueError): @@ -424,3 +445,30 @@ def test_default_paths_not_converted_to_null(): assert '"ResultPath": null' not in task_state.to_json() assert '"InputPath": null' not in task_state.to_json() assert '"OutputPath": null' not in task_state.to_json() + + +# Test if boto3 session can fetch correct aws partition info from test environment +def test_util_get_aws_partition(): + aws_partition = "aws" + aws_cn_partition = "aws-cn" + default_region = None + + # Boto3 used either info from ~/.aws/config or AWS_DEFAULT_REGION in environment + # to determine current region. We will replace/create AWS_DEFAULT_REGION with regions in + # different aws partition to test that when regions are changed, correct partition info + # can be retrieved. + if "AWS_DEFAULT_REGION" in os.environ: + default_region = os.getenv('AWS_DEFAULT_REGION') + + os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' + cur_partition = utils.get_aws_partition() + assert cur_partition == aws_partition + + os.environ['AWS_DEFAULT_REGION'] = 'cn-north-1' + cur_partition = utils.get_aws_partition() + assert cur_partition == aws_cn_partition + + if default_region is None: + del os.environ['AWS_DEFAULT_REGION'] + else: + os.environ['AWS_DEFAULT_REGION'] = default_region From 723214e9ee33af6aae12cc86b373b3fb5efb4d6e Mon Sep 17 00:00:00 2001 From: Bowen Yuan <83046180+yuan-bwn@users.noreply.github.com> Date: Fri, 7 May 2021 09:19:47 -0700 Subject: [PATCH 2/6] Update utils.py Add comment to method explanation. --- src/stepfunctions/steps/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index dff22f4..9fe0310 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -24,6 +24,7 @@ def tags_dict_to_kv_list(tags_dict): # Obtain matching aws partition name based on region +# Retrun "aws" as default if no region detected def get_aws_partition(): partitions = boto3.session.Session().get_available_partitions() cur_region = boto3.session.Session().region_name From bee295308f000abbcacf0477c01aed34dfdc9702 Mon Sep 17 00:00:00 2001 From: Bowen Yuan Date: Wed, 12 May 2021 10:26:47 -0700 Subject: [PATCH 3/6] fix: use boto3 mock for utils test and create arn builder for task integration resource --- src/stepfunctions/steps/compute.py | 56 +++++-- .../steps/integration_resources.py | 89 +++++++++++ src/stepfunctions/steps/sagemaker.py | 83 ++++++++-- src/stepfunctions/steps/service.py | 126 ++++++++++++--- src/stepfunctions/steps/utils.py | 20 ++- tests/integ/conftest.py | 4 +- tests/integ/test_state_machine_definition.py | 8 +- tests/unit/test_compute_steps.py | 9 ++ tests/unit/test_pipeline.py | 8 + tests/unit/test_sagemaker_steps.py | 30 ++++ tests/unit/test_service_steps.py | 20 +++ tests/unit/test_steps.py | 29 +--- tests/unit/test_steps_utils.py | 143 ++++++++++++++++++ 13 files changed, 546 insertions(+), 79 deletions(-) create mode 100644 src/stepfunctions/steps/integration_resources.py create mode 100644 tests/unit/test_steps_utils.py diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index 871d200..49ba89c 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -14,7 +14,8 @@ from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import get_aws_partition +from stepfunctions.steps.utils import resource_integration_arn_builder +from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, LambdaApi, GlueApi, BatchApi, EcsApi class LambdaStep(Task): @@ -38,10 +39,20 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ + if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke.waitForTaskToken' + """ + Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda, + LambdaApi.Invoke, + IntegrationPattern.WaitForTaskToken) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke' + """ + Example resource arn: arn:aws:states:::lambda:invoke + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke) + super(LambdaStep, self).__init__(state_id, **kwargs) @@ -68,9 +79,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun.sync' + """ + Example resource arn: arn:aws:states:::glue:startJobRun.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue, + GlueApi.StartJobRun, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun' + """ + Example resource arn: arn:aws:states:::glue:startJobRun + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue, + GlueApi.StartJobRun) super(GlueStartJobRunStep, self).__init__(state_id, **kwargs) @@ -97,9 +117,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob.sync' + """ + Example resource arn: arn:aws:states:::batch:submitJob.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch, + BatchApi.SubmitJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob' + """ + Example resource arn: arn:aws:states:::batch:submitJob + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch, + BatchApi.SubmitJob) super(BatchSubmitJobStep, self).__init__(state_id, **kwargs) @@ -126,8 +155,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask.sync' + """ + Example resource arn: arn:aws:states:::ecs:runTask.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS, + EcsApi.RunTask, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask' + """ + Example resource arn: arn:aws:states:::ecs:runTask + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS, + EcsApi.RunTask) super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/integration_resources.py b/src/stepfunctions/steps/integration_resources.py new file mode 100644 index 0000000..8784980 --- /dev/null +++ b/src/stepfunctions/steps/integration_resources.py @@ -0,0 +1,89 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import absolute_import + +from enum import Enum + +""" +Enum classes for task integration resource arn builder +""" + + +class IntegrationPattern(Enum): + WaitForTaskToken = "waitForTaskToken" + WaitForCompletion = "sync" + + +class IntegrationServices(Enum): + Lambda = "lambda" + SageMaker = "sagemaker" + Glue = "glue" + ECS = "ecs" + Batch = "batch" + DynamoDB = "dynamodb" + SNS = "sns" + SQS = "sqs" + ElasticMapReduce = "elasticmapreduce" + + +class LambdaApi(Enum): + Invoke = "invoke" + + +class SageMakerApi(Enum): + CreateTrainingJob = "createTrainingJob" + CreateTransformJob = "createTransformJob" + CreateModel = "createModel" + CreateEndpointConfig = "createEndpointConfig" + UpdateEndpoint = "updateEndpoint" + CreateEndpoint = "createEndpoint" + CreateHyperParameterTuningJob = "createHyperParameterTuningJob" + CreateProcessingJob = "createProcessingJob" + + +class GlueApi(Enum): + StartJobRun = "startJobRun" + + +class EcsApi(Enum): + RunTask = "runTask" + + +class BatchApi(Enum): + SubmitJob = "submitJob" + + +class DynamoDBApi(Enum): + GetItem = "getItem" + PutItem = "putItem" + DeleteItem = "deleteItem" + UpdateItem = "updateItem" + + +class SnsApi(Enum): + Publish = "publish" + + +class SqsApi(Enum): + SendMessage = "sendMessage" + + +class ElasticMapReduceApi(Enum): + CreateCluster = "createCluster" + TerminateCluster = "terminateCluster" + AddStep = "addStep" + CancelStep = "cancelStep" + SetClusterTerminationProtection = "setClusterTerminationProtection" + ModifyInstanceFleetByName = "modifyInstanceFleetByName" + ModifyInstanceGroupByName = "modifyInstanceGroupByName" diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 93b826a..abbebc9 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -15,7 +15,8 @@ from stepfunctions.inputs import ExecutionInput, StepInput from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import tags_dict_to_kv_list, get_aws_partition +from stepfunctions.steps.utils import tags_dict_to_kv_list, resource_integration_arn_builder +from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, SageMakerApi from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel @@ -58,9 +59,18 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non self.job_name = job_name if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateTrainingJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob' + """ + Example resource arn: arn:aws:states:::sagemaker:createTrainingJob + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateTrainingJob) if isinstance(job_name, str): parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) @@ -141,9 +151,18 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateTransformJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob' + """ + Example resource arn: arn:aws:states:::sagemaker:createTransformJob + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateTransformJob) if isinstance(job_name, str): parameters = transform_config( @@ -225,7 +244,12 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No parameters['Tags'] = tags_dict_to_kv_list(tags) kwargs[Field.Parameters.value] = parameters - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createModel' + + """ + Example resource arn: arn:aws:states:::sagemaker:createModel + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateModel) super(ModelStep, self).__init__(state_id, **kwargs) @@ -266,7 +290,12 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ if tags: parameters['Tags'] = tags_dict_to_kv_list(tags) - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpointConfig' + """ + Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateEndpointConfig) + kwargs[Field.Parameters.value] = parameters super(EndpointConfigStep, self).__init__(state_id, **kwargs) @@ -298,9 +327,17 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd parameters['Tags'] = tags_dict_to_kv_list(tags) if update: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:updateEndpoint' + """ + Example resource arn: arn:aws:states:::sagemaker:updateEndpoint + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.UpdateEndpoint) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpoint' + """ + Example resource arn: arn:aws:states:::sagemaker:createEndpoint + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateEndpoint) kwargs[Field.Parameters.value] = parameters @@ -338,9 +375,18 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateHyperParameterTuningJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob' + """ + Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateHyperParameterTuningJob) parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() @@ -387,10 +433,19 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp tags (list[dict], optional): `List to tags `_ to associate with the resource. """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob.sync' + """ + Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateProcessingJob, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob' - + """ + Example resource arn: arn:aws:states:::sagemaker:createProcessingJob + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, + SageMakerApi.CreateProcessingJob) + if isinstance(job_name, str): parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) else: diff --git a/src/stepfunctions/steps/service.py b/src/stepfunctions/steps/service.py index 4952f70..752f1a8 100644 --- a/src/stepfunctions/steps/service.py +++ b/src/stepfunctions/steps/service.py @@ -14,7 +14,9 @@ from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import get_aws_partition +from stepfunctions.steps.utils import resource_integration_arn_builder +from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, \ + DynamoDBApi, SnsApi, SqsApi, ElasticMapReduceApi class DynamoDBGetItemStep(Task): @@ -36,7 +38,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:getItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:getItem + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, + DynamoDBApi.GetItem) super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs) @@ -60,7 +67,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:putItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:putItem + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, + DynamoDBApi.PutItem) super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs) @@ -84,7 +96,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:deleteItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:deleteItem + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, + DynamoDBApi.DeleteItem) super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs) @@ -108,7 +125,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:updateItem' + + """ + Example resource arn: arn:aws:states:::dynamodb:updateItem + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, + DynamoDBApi.UpdateItem) super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs) @@ -134,9 +156,18 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish.waitForTaskToken' + """ + Example resource arn: arn:aws:states:::sns:publish.waitForTaskToken + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SNS, + SnsApi.Publish, + IntegrationPattern.WaitForTaskToken) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish' + """ + Example resource arn: arn:aws:states:::sns:publish + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SNS, + SnsApi.Publish) super(SnsPublishStep, self).__init__(state_id, **kwargs) @@ -163,9 +194,18 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ if wait_for_callback: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage.waitForTaskToken' + """ + Example resource arn: arn:aws:states:::sqs:sendMessage.waitForTaskToken + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SQS, + SqsApi.SendMessage, + IntegrationPattern.WaitForTaskToken) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage' + """ + Example resource arn: arn:aws:states:::sqs:sendMessage + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SQS, + SqsApi.SendMessage) super(SqsSendMessageStep, self).__init__(state_id, **kwargs) @@ -191,9 +231,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster.sync' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:createCluster.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.CreateCluster, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:createCluster + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.CreateCluster) super(EmrCreateClusterStep, self).__init__(state_id, **kwargs) @@ -219,9 +268,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster.sync' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.TerminateCluster, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.TerminateCluster) super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs) @@ -247,9 +305,18 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True) """ if wait_for_completion: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep.sync' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:addStep.sync + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.AddStep, + IntegrationPattern.WaitForCompletion) else: - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep' + """ + Example resource arn: arn:aws:states:::elasticmapreduce:addStep + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.AddStep) super(EmrAddStepStep, self).__init__(state_id, **kwargs) @@ -273,7 +340,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:cancelStep' + + """ + Example resource arn: arn:aws:states:::elasticmapreduce:cancelStep + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.CancelStep) super(EmrCancelStepStep, self).__init__(state_id, **kwargs) @@ -297,7 +369,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:setClusterTerminationProtection' + + """ + Example resource arn: arn:aws:states:::elasticmapreduce:setClusterTerminationProtection + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.SetClusterTerminationProtection) super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs) @@ -321,7 +398,12 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceFleetByName' + + """ + Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.ModifyInstanceFleetByName) super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs) @@ -345,7 +427,11 @@ def __init__(self, state_id, **kwargs): result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$') output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$') """ - kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceGroupByName' - super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) + """ + Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName + """ + kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, + ElasticMapReduceApi.ModifyInstanceGroupByName) + super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 9fe0310..89785f0 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -23,15 +23,17 @@ def tags_dict_to_kv_list(tags_dict): return kv_list -# Obtain matching aws partition name based on region -# Retrun "aws" as default if no region detected +""" +Obtain matching aws partition name based on region +Return "aws" as default if no region detected +""" def get_aws_partition(): partitions = boto3.session.Session().get_available_partitions() cur_region = boto3.session.Session().region_name cur_partition = "aws" if cur_region is None: - logger.warning("No region detected for the session, will use default partition: aws") + logger.warning("No region detected for the boto3 session. Using default partition: aws") return cur_partition for partition in partitions: @@ -41,3 +43,15 @@ def get_aws_partition(): return cur_partition return cur_partition + + +""" +ARN builder for task integration +""" +def resource_integration_arn_builder(service, api, integration_pattern=None): + arn = "" + if integration_pattern is None: + arn = f"arn:{get_aws_partition()}:states:::{service.value}:{api.value}" + else: + arn = f"arn:{get_aws_partition()}:states:::{service.value}:{api.value}.{integration_pattern.value}" + return arn diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index aa0f7bc..b2bb979 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -44,11 +44,11 @@ def aws_account_id(): @pytest.fixture(scope="session") def sfn_role_arn(aws_account_id): - return "arn:" + get_aws_partition() + ":iam::{}:role/StepFunctionsMLWorkflowExecutionFullAccess".format(aws_account_id) + return f"arn:{get_aws_partition()}:iam::{aws_account_id}:role/StepFunctionsMLWorkflowExecutionFullAccess" @pytest.fixture(scope="session") def sagemaker_role_arn(aws_account_id): - return "arn:" + get_aws_partition() + ":iam::{}:role/SageMakerRole".format(aws_account_id) + return f"arn:{get_aws_partition()}:iam::{aws_account_id}:role/SageMakerRole" @pytest.fixture(scope="session") def pca_estimator_fixture(sagemaker_role_arn): diff --git a/tests/integ/test_state_machine_definition.py b/tests/integ/test_state_machine_definition.py index fe79b8a..0fec9f5 100644 --- a/tests/integ/test_state_machine_definition.py +++ b/tests/integ/test_state_machine_definition.py @@ -22,6 +22,7 @@ from stepfunctions.steps.utils import get_aws_partition from tests.integ.utils import state_machine_delete_wait + @pytest.fixture(scope="module") def training_job_parameters(sagemaker_session, sagemaker_role_arn, record_set_fixture): parameters = { @@ -63,6 +64,7 @@ def training_job_parameters(sagemaker_session, sagemaker_role_arn, record_set_fi return parameters + def workflow_test_suite(sfn_client, workflow, asl_state_machine_definition, output_result, inputs=None): state_machine_arn = workflow.create() execution = workflow.execute(inputs=inputs) @@ -381,7 +383,7 @@ def test_choice_state_machine_creation(sfn_client, sfn_role_arn): def test_task_state_machine_creation(sfn_client, sfn_role_arn, training_job_parameters): task_state_name = "TaskState" final_state_name = "FinalState" - resource = "arn:" + get_aws_partition() + ":states:::sagemaker:createTrainingJob.sync" + resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" task_state_result = "Task State Result" asl_state_machine_definition = { "StartAt": task_state_name, @@ -427,7 +429,7 @@ def test_catch_state_machine_creation(sfn_client, sfn_role_arn, training_job_par task_failed_state_name = "Task Failed End" all_error_state_name = "Catch All End" catch_state_result = "Catch Result" - task_resource = "arn:" + get_aws_partition() + ":states:::sagemaker:createTrainingJob.sync" + task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" # change the parameters to cause task state to fail training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" @@ -483,7 +485,7 @@ def test_retry_state_machine_creation(sfn_client, sfn_role_arn, training_job_par interval_seconds = 1 max_attempts = 2 backoff_rate = 2 - task_resource = "arn:" + get_aws_partition() + ":states:::sagemaker:createTrainingJob.sync" + task_resource = f"arn:{get_aws_partition()}:states:::sagemaker:createTrainingJob.sync" # change the parameters to cause task state to fail training_job_parameters["AlgorithmSpecification"]["TrainingImage"] = "not_an_image" diff --git a/tests/unit/test_compute_steps.py b/tests/unit/test_compute_steps.py index 030cf35..368010a 100644 --- a/tests/unit/test_compute_steps.py +++ b/tests/unit/test_compute_steps.py @@ -13,10 +13,13 @@ from __future__ import absolute_import import pytest +import boto3 +from unittest.mock import patch from stepfunctions.steps.compute import LambdaStep, GlueStartJobRunStep, BatchSubmitJobStep, EcsRunTaskStep +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_lambda_step_creation(): step = LambdaStep('Echo') @@ -45,6 +48,8 @@ def test_lambda_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_glue_start_job_run_step_creation(): step = GlueStartJobRunStep('Glue Job', wait_for_completion=False) @@ -67,6 +72,8 @@ def test_glue_start_job_run_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_batch_submit_job_step_creation(): step = BatchSubmitJobStep('Batch Job', wait_for_completion=False) @@ -91,6 +98,8 @@ def test_batch_submit_job_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_ecs_run_task_step_creation(): step = EcsRunTaskStep('Ecs Job', wait_for_completion=False) diff --git a/tests/unit/test_pipeline.py b/tests/unit/test_pipeline.py index 2123a3d..c7ab502 100644 --- a/tests/unit/test_pipeline.py +++ b/tests/unit/test_pipeline.py @@ -14,6 +14,7 @@ import pytest import sagemaker +import boto3 from sagemaker.sklearn.estimator import SKLearn from unittest.mock import MagicMock, patch @@ -27,6 +28,7 @@ PCA_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/pca:1' LINEAR_LEARNER_IMAGE = '382416733822.dkr.ecr.us-east-1.amazonaws.com/linear-learner:1' + @pytest.fixture def pca_estimator(): s3_output_location = 's3://sagemaker/models' @@ -52,6 +54,7 @@ def pca_estimator(): return pca + @pytest.fixture def sklearn_preprocessor(): script_path = 'sklearn_abalone_featurizer.py' @@ -75,6 +78,7 @@ def sklearn_preprocessor(): return sklearn_preprocessor + @pytest.fixture def linear_learner_estimator(): s3_output_location = 's3://sagemaker/models' @@ -101,7 +105,9 @@ def linear_learner_estimator(): return ll_estimator + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_pca_training_pipeline(pca_estimator): s3_inputs = { 'train': 's3://sagemaker/pca/train' @@ -228,7 +234,9 @@ def test_pca_training_pipeline(pca_estimator): workflow.execute.assert_called_with(name=job_name, inputs=inputs) + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_inference_pipeline(sklearn_preprocessor, linear_learner_estimator): s3_inputs = { 'train': 's3://sagemaker-us-east-1/inference/train' diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index d45ee79..2ce9d3b 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -37,6 +37,7 @@ DEFAULT_TAGS = {'Purpose': 'unittests'} DEFAULT_TAGS_LIST = [{'Key': 'Purpose', 'Value': 'unittests'}] + @pytest.fixture def pca_estimator(): s3_output_location = 's3://sagemaker/models' @@ -63,6 +64,7 @@ def pca_estimator(): return pca + @pytest.fixture def pca_estimator_with_debug_hook(): s3_output_location = 's3://sagemaker/models' @@ -139,6 +141,7 @@ def pca_estimator_with_falsy_debug_hook(): return pca + @pytest.fixture def pca_model(): model_data = 's3://sagemaker/models/pca.tar.gz' @@ -149,6 +152,7 @@ def pca_model(): name='pca-model' ) + @pytest.fixture def pca_transformer(pca_model): return Transformer( @@ -158,6 +162,7 @@ def pca_transformer(pca_model): output_path='s3://sagemaker/transform-output' ) + @pytest.fixture def tensorflow_estimator(): s3_output_location = 's3://sagemaker/models' @@ -190,6 +195,7 @@ def tensorflow_estimator(): return estimator + @pytest.fixture def sklearn_processor(): sagemaker_session = MagicMock() @@ -206,7 +212,9 @@ def sklearn_processor(): return processor + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation(pca_estimator): step = TrainingStep('Training', estimator=pca_estimator, @@ -256,7 +264,9 @@ def test_training_step_creation(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): step = TrainingStep('Training', estimator=pca_estimator_with_debug_hook, @@ -315,7 +325,9 @@ def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_debug_hook): step = TrainingStep('Training', estimator=pca_estimator_with_falsy_debug_hook, @@ -352,7 +364,9 @@ def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_d 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob') model_step = ModelStep('Training - Save Model', training_step.get_expected_model(model_name=training_step.output()['TrainingJobName'])) @@ -404,7 +418,9 @@ def test_training_step_creation_with_model(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_training_step_creation_with_framework(tensorflow_estimator): step = TrainingStep('Training', estimator=tensorflow_estimator, @@ -466,6 +482,8 @@ def test_training_step_creation_with_framework(tensorflow_estimator): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_transform_step_creation(pca_transformer): step = TransformStep('Inference', transformer=pca_transformer, @@ -518,7 +536,9 @@ def test_transform_step_creation(pca_transformer): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_get_expected_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob') expected_model = training_step.get_expected_model() @@ -538,7 +558,9 @@ def test_get_expected_model(pca_estimator): 'End': True } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_get_expected_model_with_framework_estimator(tensorflow_estimator): training_step = TrainingStep('Training', estimator=tensorflow_estimator, @@ -569,6 +591,8 @@ def test_get_expected_model_with_framework_estimator(tensorflow_estimator): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_model_step_creation(pca_model): step = ModelStep('Create model', model=pca_model, model_name='pca-model', tags=DEFAULT_TAGS) assert step.to_dict() == { @@ -587,6 +611,8 @@ def test_model_step_creation(pca_model): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_endpoint_config_step_creation(pca_model): data_capture_config = DataCaptureConfig( enable_capture=True, @@ -629,6 +655,8 @@ def test_endpoint_config_step_creation(pca_model): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_endpoint_step_creation(pca_model): step = EndpointStep('Endpoint', endpoint_name='MyEndPoint', endpoint_config_name='MyEndpointConfig', tags=DEFAULT_TAGS) assert step.to_dict() == { @@ -654,6 +682,8 @@ def test_endpoint_step_creation(pca_model): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_processing_step_creation(sklearn_processor): inputs = [ProcessingInput(source='dataset.csv', destination='/opt/ml/processing/input')] outputs = [ diff --git a/tests/unit/test_service_steps.py b/tests/unit/test_service_steps.py index 64809e9..6576aaf 100644 --- a/tests/unit/test_service_steps.py +++ b/tests/unit/test_service_steps.py @@ -13,12 +13,15 @@ from __future__ import absolute_import import pytest +import boto3 +from unittest.mock import patch from stepfunctions.steps.service import DynamoDBGetItemStep, DynamoDBPutItemStep, DynamoDBUpdateItemStep, DynamoDBDeleteItemStep from stepfunctions.steps.service import SnsPublishStep, SqsSendMessageStep from stepfunctions.steps.service import EmrCreateClusterStep, EmrTerminateClusterStep, EmrAddStepStep, EmrCancelStepStep, EmrSetClusterTerminationProtectionStep, EmrModifyInstanceFleetByNameStep, EmrModifyInstanceGroupByNameStep +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_sns_publish_step_creation(): step = SnsPublishStep('Publish to SNS', parameters={ 'TopicArn': 'arn:aws:sns:us-east-1:123456789012:myTopic', @@ -57,6 +60,7 @@ def test_sns_publish_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_sqs_send_message_step_creation(): step = SqsSendMessageStep('Send to SQS', parameters={ 'QueueUrl': 'https://sqs.us-east-1.amazonaws.com/123456789012/myQueue', @@ -95,6 +99,7 @@ def test_sqs_send_message_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_get_item_step_creation(): step = DynamoDBGetItemStep('Read Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -120,6 +125,7 @@ def test_dynamodb_get_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_put_item_step_creation(): step = DynamoDBPutItemStep('Add Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -145,6 +151,7 @@ def test_dynamodb_put_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_delete_item_step_creation(): step = DynamoDBDeleteItemStep('Delete Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -170,6 +177,7 @@ def test_dynamodb_delete_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_dynamodb_update_item_step_creation(): step = DynamoDBUpdateItemStep('Update Message From DynamoDB', parameters={ 'TableName': 'TransferDataRecords-DDBTable-3I41R5L5EAGT', @@ -203,6 +211,7 @@ def test_dynamodb_update_item_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_create_cluster_step_creation(): step = EmrCreateClusterStep('Create EMR cluster', parameters={ 'Name': 'MyWorkflowCluster', @@ -372,6 +381,7 @@ def test_emr_create_cluster_step_creation(): } +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_terminate_cluster_step_creation(): step = EmrTerminateClusterStep('Terminate EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId' @@ -399,6 +409,8 @@ def test_emr_terminate_cluster_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_add_step_step_creation(): step = EmrAddStepStep('Add step to EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -498,6 +510,8 @@ def test_emr_add_step_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_cancel_step_step_creation(): step = EmrCancelStepStep('Cancel step from EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -514,6 +528,8 @@ def test_emr_cancel_step_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_set_cluster_termination_protection_step_creation(): step = EmrSetClusterTerminationProtectionStep('Set termination protection for EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -530,6 +546,8 @@ def test_emr_set_cluster_termination_protection_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_modify_instance_fleet_by_name_step_creation(): step = EmrModifyInstanceFleetByNameStep('Modify Instance Fleet by name for EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', @@ -554,6 +572,8 @@ def test_emr_modify_instance_fleet_by_name_step_creation(): 'End': True } + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_emr_modify_instance_group_by_name_step_creation(): step = EmrModifyInstanceGroupByNameStep('Modify Instance Group by name for EMR cluster', parameters={ 'ClusterId': 'MyWorkflowClusterId', diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index d3886ab..dfc6262 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -13,12 +13,9 @@ from __future__ import absolute_import import pytest -import boto3 -import os from stepfunctions.exceptions import DuplicateStatesInChain -from stepfunctions.steps import Pass, Succeed, Fail, Wait, Choice, ChoiceRule, Parallel, Map, Task, Retry, Catch, Chain, \ - utils +from stepfunctions.steps import Pass, Succeed, Fail, Wait, Choice, ChoiceRule, Parallel, Map, Task, Retry, Catch, Chain from stepfunctions.steps.states import State, to_pascalcase @@ -447,28 +444,4 @@ def test_default_paths_not_converted_to_null(): assert '"OutputPath": null' not in task_state.to_json() -# Test if boto3 session can fetch correct aws partition info from test environment -def test_util_get_aws_partition(): - aws_partition = "aws" - aws_cn_partition = "aws-cn" - default_region = None - # Boto3 used either info from ~/.aws/config or AWS_DEFAULT_REGION in environment - # to determine current region. We will replace/create AWS_DEFAULT_REGION with regions in - # different aws partition to test that when regions are changed, correct partition info - # can be retrieved. - if "AWS_DEFAULT_REGION" in os.environ: - default_region = os.getenv('AWS_DEFAULT_REGION') - - os.environ['AWS_DEFAULT_REGION'] = 'us-east-1' - cur_partition = utils.get_aws_partition() - assert cur_partition == aws_partition - - os.environ['AWS_DEFAULT_REGION'] = 'cn-north-1' - cur_partition = utils.get_aws_partition() - assert cur_partition == aws_cn_partition - - if default_region is None: - del os.environ['AWS_DEFAULT_REGION'] - else: - os.environ['AWS_DEFAULT_REGION'] = default_region diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py new file mode 100644 index 0000000..a2beb00 --- /dev/null +++ b/tests/unit/test_steps_utils.py @@ -0,0 +1,143 @@ +# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +# Test if boto3 session can fetch correct aws partition info from test environment + +from stepfunctions.steps.utils import get_aws_partition, resource_integration_arn_builder +from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, LambdaApi, SageMakerApi,\ + GlueApi, EcsApi, BatchApi, DynamoDBApi, SnsApi, SqsApi, ElasticMapReduceApi +import boto3 +from unittest.mock import patch + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_util_get_aws_partition_aws(): + cur_partition = get_aws_partition() + assert cur_partition == "aws" + + +@patch.object(boto3.session.Session, 'region_name', 'cn-north-1') +def test_util_get_aws_partition_aws_cn(): + cur_partition = get_aws_partition() + assert cur_partition == "aws-cn" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_lambda_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke) + assert arn == "arn:aws:states:::lambda:invoke" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_lambda_wait_token(): + arn = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke, + IntegrationPattern.WaitForTaskToken) + assert arn == "arn:aws:states:::lambda:invoke.waitForTaskToken" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sagemaker_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.SageMaker, SageMakerApi.CreateTrainingJob) + assert arn == "arn:aws:states:::sagemaker:createTrainingJob" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sagemaker_wait_completion(): + arn = resource_integration_arn_builder(IntegrationServices.SageMaker, SageMakerApi.CreateTrainingJob, + IntegrationPattern.WaitForCompletion) + assert arn == "arn:aws:states:::sagemaker:createTrainingJob.sync" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_glue_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.Glue, GlueApi.StartJobRun) + assert arn == "arn:aws:states:::glue:startJobRun" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_glue_wait_completion(): + arn = resource_integration_arn_builder(IntegrationServices.Glue, GlueApi.StartJobRun, + IntegrationPattern.WaitForCompletion) + assert arn == "arn:aws:states:::glue:startJobRun.sync" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_ecs_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.ECS, EcsApi.RunTask) + assert arn == "arn:aws:states:::ecs:runTask" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_ecs_wait_completion(): + arn = resource_integration_arn_builder(IntegrationServices.ECS, EcsApi.RunTask, + IntegrationPattern.WaitForCompletion) + assert arn == "arn:aws:states:::ecs:runTask.sync" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_batch_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.Batch, BatchApi.SubmitJob) + assert arn == "arn:aws:states:::batch:submitJob" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_batch_wait_completion(): + arn = resource_integration_arn_builder(IntegrationServices.Batch, BatchApi.SubmitJob, + IntegrationPattern.WaitForCompletion) + assert arn == "arn:aws:states:::batch:submitJob.sync" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_dynamodb_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.DynamoDB, DynamoDBApi.GetItem) + assert arn == "arn:aws:states:::dynamodb:getItem" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sns_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.SNS, SnsApi.Publish) + assert arn == "arn:aws:states:::sns:publish" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sns_wait_token(): + arn = resource_integration_arn_builder(IntegrationServices.SNS, SnsApi.Publish, + IntegrationPattern.WaitForTaskToken) + assert arn == "arn:aws:states:::sns:publish.waitForTaskToken" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sqs_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.SQS, SqsApi.SendMessage) + assert arn == "arn:aws:states:::sqs:sendMessage" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_sqs_wait_token(): + arn = resource_integration_arn_builder(IntegrationServices.SQS, SqsApi.SendMessage, + IntegrationPattern.WaitForTaskToken) + assert arn == "arn:aws:states:::sqs:sendMessage.waitForTaskToken" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_elasticmapreduce_no_wait(): + arn = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, ElasticMapReduceApi.CreateCluster) + assert arn == "arn:aws:states:::elasticmapreduce:createCluster" + + +@patch.object(boto3.session.Session, 'region_name', 'us-east-1') +def test_arn_builder_elasticmapreduce_wait_completion(): + arn = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, ElasticMapReduceApi.CreateCluster, + IntegrationPattern.WaitForCompletion) + assert arn == "arn:aws:states:::elasticmapreduce:createCluster.sync" + From 0c9b0f0469a9968c1d826f440658cc195e90c9ad Mon Sep 17 00:00:00 2001 From: Bowen Yuan Date: Thu, 13 May 2021 22:03:47 -0700 Subject: [PATCH 4/6] fix: update arn builder method and its usage --- src/stepfunctions/steps/compute.py | 72 +++++++--- .../steps/integration_resources.py | 69 +-------- src/stepfunctions/steps/sagemaker.py | 87 +++++++---- src/stepfunctions/steps/service.py | 136 ++++++++++++------ src/stepfunctions/steps/utils.py | 32 +++-- tests/unit/test_steps_utils.py | 116 ++------------- 6 files changed, 239 insertions(+), 273 deletions(-) diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index 49ba89c..f8d17ac 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -12,10 +12,32 @@ # permissions and limitations under the License. from __future__ import absolute_import +from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import resource_integration_arn_builder -from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, LambdaApi, GlueApi, BatchApi, EcsApi +from stepfunctions.steps.utils import get_service_integration_arn +from stepfunctions.steps.integration_resources import IntegrationPattern + +Lambda = "lambda" +Glue = "glue" +Ecs = "ecs" +Batch = "batch" + + +class LambdaApi(Enum): + Invoke = "invoke" + + +class GlueApi(Enum): + StartJobRun = "startJobRun" + + +class EcsApi(Enum): + RunTask = "runTask" + + +class BatchApi(Enum): + SubmitJob = "submitJob" class LambdaStep(Task): @@ -44,14 +66,16 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): """ Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda, - LambdaApi.Invoke, - IntegrationPattern.WaitForTaskToken) + + kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, + LambdaApi.Invoke, + IntegrationPattern.WaitForTaskToken) else: """ Example resource arn: arn:aws:states:::lambda:invoke """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke) + + kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, LambdaApi.Invoke) super(LambdaStep, self).__init__(state_id, **kwargs) @@ -82,15 +106,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): """ Example resource arn: arn:aws:states:::glue:startJobRun.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue, - GlueApi.StartJobRun, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(Glue, + GlueApi.StartJobRun, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::glue:startJobRun """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Glue, - GlueApi.StartJobRun) + + kwargs[Field.Resource.value] = get_service_integration_arn(Glue, + GlueApi.StartJobRun) super(GlueStartJobRunStep, self).__init__(state_id, **kwargs) @@ -120,15 +146,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): """ Example resource arn: arn:aws:states:::batch:submitJob.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch, - BatchApi.SubmitJob, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(Batch, + BatchApi.SubmitJob, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::batch:submitJob """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.Batch, - BatchApi.SubmitJob) + + kwargs[Field.Resource.value] = get_service_integration_arn(Batch, + BatchApi.SubmitJob) super(BatchSubmitJobStep, self).__init__(state_id, **kwargs) @@ -158,14 +186,16 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): """ Example resource arn: arn:aws:states:::ecs:runTask.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS, - EcsApi.RunTask, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(Ecs, + EcsApi.RunTask, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::ecs:runTask """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ECS, - EcsApi.RunTask) + + kwargs[Field.Resource.value] = get_service_integration_arn(Ecs, + EcsApi.RunTask) super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/integration_resources.py b/src/stepfunctions/steps/integration_resources.py index 8784980..a2d56dc 100644 --- a/src/stepfunctions/steps/integration_resources.py +++ b/src/stepfunctions/steps/integration_resources.py @@ -15,75 +15,16 @@ from enum import Enum -""" -Enum classes for task integration resource arn builder -""" - class IntegrationPattern(Enum): + """ + Integration pattern enum classes for task integration resource arn builder + """ + WaitForTaskToken = "waitForTaskToken" WaitForCompletion = "sync" + RequestResponse = "" -class IntegrationServices(Enum): - Lambda = "lambda" - SageMaker = "sagemaker" - Glue = "glue" - ECS = "ecs" - Batch = "batch" - DynamoDB = "dynamodb" - SNS = "sns" - SQS = "sqs" - ElasticMapReduce = "elasticmapreduce" - - -class LambdaApi(Enum): - Invoke = "invoke" - - -class SageMakerApi(Enum): - CreateTrainingJob = "createTrainingJob" - CreateTransformJob = "createTransformJob" - CreateModel = "createModel" - CreateEndpointConfig = "createEndpointConfig" - UpdateEndpoint = "updateEndpoint" - CreateEndpoint = "createEndpoint" - CreateHyperParameterTuningJob = "createHyperParameterTuningJob" - CreateProcessingJob = "createProcessingJob" - - -class GlueApi(Enum): - StartJobRun = "startJobRun" - - -class EcsApi(Enum): - RunTask = "runTask" - - -class BatchApi(Enum): - SubmitJob = "submitJob" - - -class DynamoDBApi(Enum): - GetItem = "getItem" - PutItem = "putItem" - DeleteItem = "deleteItem" - UpdateItem = "updateItem" - - -class SnsApi(Enum): - Publish = "publish" - - -class SqsApi(Enum): - SendMessage = "sendMessage" -class ElasticMapReduceApi(Enum): - CreateCluster = "createCluster" - TerminateCluster = "terminateCluster" - AddStep = "addStep" - CancelStep = "cancelStep" - SetClusterTerminationProtection = "setClusterTerminationProtection" - ModifyInstanceFleetByName = "modifyInstanceFleetByName" - ModifyInstanceGroupByName = "modifyInstanceGroupByName" diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index abbebc9..812d6ad 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -12,16 +12,31 @@ # permissions and limitations under the License. from __future__ import absolute_import +from enum import Enum from stepfunctions.inputs import ExecutionInput, StepInput from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import tags_dict_to_kv_list, resource_integration_arn_builder -from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, SageMakerApi +from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn +from stepfunctions.steps.integration_resources import IntegrationPattern from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig +SageMaker = "sagemaker" + + +class SageMakerApi(Enum): + CreateTrainingJob = "createTrainingJob" + CreateTransformJob = "createTransformJob" + CreateModel = "createModel" + CreateEndpointConfig = "createEndpointConfig" + UpdateEndpoint = "updateEndpoint" + CreateEndpoint = "createEndpoint" + CreateHyperParameterTuningJob = "createHyperParameterTuningJob" + CreateProcessingJob = "createProcessingJob" + + class TrainingStep(Task): """ @@ -62,15 +77,17 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non """ Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateTrainingJob, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateTrainingJob, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::sagemaker:createTrainingJob """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateTrainingJob) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateTrainingJob) if isinstance(job_name, str): parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size) @@ -154,15 +171,17 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= """ Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateTransformJob, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateTransformJob, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::sagemaker:createTransformJob """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateTransformJob) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateTransformJob) if isinstance(job_name, str): parameters = transform_config( @@ -248,8 +267,9 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No """ Example resource arn: arn:aws:states:::sagemaker:createModel """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateModel) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateModel) super(ModelStep, self).__init__(state_id, **kwargs) @@ -293,8 +313,9 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ """ Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateEndpointConfig) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateEndpointConfig) kwargs[Field.Parameters.value] = parameters @@ -330,14 +351,16 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd """ Example resource arn: arn:aws:states:::sagemaker:updateEndpoint """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.UpdateEndpoint) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.UpdateEndpoint) else: """ Example resource arn: arn:aws:states:::sagemaker:createEndpoint """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateEndpoint) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateEndpoint) kwargs[Field.Parameters.value] = parameters @@ -378,15 +401,17 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta """ Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateHyperParameterTuningJob, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateHyperParameterTuningJob, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateHyperParameterTuningJob) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateHyperParameterTuningJob) parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() @@ -436,15 +461,17 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp """ Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateProcessingJob, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateProcessingJob, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::sagemaker:createProcessingJob """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SageMaker, - SageMakerApi.CreateProcessingJob) + + kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + SageMakerApi.CreateProcessingJob) if isinstance(job_name, str): parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name) diff --git a/src/stepfunctions/steps/service.py b/src/stepfunctions/steps/service.py index 752f1a8..1bc368f 100644 --- a/src/stepfunctions/steps/service.py +++ b/src/stepfunctions/steps/service.py @@ -12,11 +12,41 @@ # permissions and limitations under the License. from __future__ import absolute_import +from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import resource_integration_arn_builder -from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, \ - DynamoDBApi, SnsApi, SqsApi, ElasticMapReduceApi +from stepfunctions.steps.utils import get_service_integration_arn +from stepfunctions.steps.integration_resources import IntegrationPattern + +DynamoDB = "dynamodb" +Sns = "sns" +Sqs = "sqs" +ElasticMapReduce = "elasticmapreduce" + + +class DynamoDBApi(Enum): + GetItem = "getItem" + PutItem = "putItem" + DeleteItem = "deleteItem" + UpdateItem = "updateItem" + + +class SnsApi(Enum): + Publish = "publish" + + +class SqsApi(Enum): + SendMessage = "sendMessage" + + +class ElasticMapReduceApi(Enum): + CreateCluster = "createCluster" + TerminateCluster = "terminateCluster" + AddStep = "addStep" + CancelStep = "cancelStep" + SetClusterTerminationProtection = "setClusterTerminationProtection" + ModifyInstanceFleetByName = "modifyInstanceFleetByName" + ModifyInstanceGroupByName = "modifyInstanceGroupByName" class DynamoDBGetItemStep(Task): @@ -42,8 +72,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::dynamodb:getItem """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, - DynamoDBApi.GetItem) + + kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + DynamoDBApi.GetItem) super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs) @@ -71,8 +102,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::dynamodb:putItem """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, - DynamoDBApi.PutItem) + + kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + DynamoDBApi.PutItem) super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs) @@ -100,8 +132,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::dynamodb:deleteItem """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, - DynamoDBApi.DeleteItem) + + kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + DynamoDBApi.DeleteItem) super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs) @@ -129,8 +162,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::dynamodb:updateItem """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.DynamoDB, - DynamoDBApi.UpdateItem) + + kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + DynamoDBApi.UpdateItem) super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs) @@ -159,15 +193,17 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): """ Example resource arn: arn:aws:states:::sns:publish.waitForTaskToken """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SNS, - SnsApi.Publish, - IntegrationPattern.WaitForTaskToken) + + kwargs[Field.Resource.value] = get_service_integration_arn(Sns, + SnsApi.Publish, + IntegrationPattern.WaitForTaskToken) else: """ Example resource arn: arn:aws:states:::sns:publish """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SNS, - SnsApi.Publish) + + kwargs[Field.Resource.value] = get_service_integration_arn(Sns, + SnsApi.Publish) super(SnsPublishStep, self).__init__(state_id, **kwargs) @@ -197,15 +233,17 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): """ Example resource arn: arn:aws:states:::sqs:sendMessage.waitForTaskToken """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SQS, - SqsApi.SendMessage, - IntegrationPattern.WaitForTaskToken) + + kwargs[Field.Resource.value] = get_service_integration_arn(Sqs, + SqsApi.SendMessage, + IntegrationPattern.WaitForTaskToken) else: """ Example resource arn: arn:aws:states:::sqs:sendMessage """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.SQS, - SqsApi.SendMessage) + + kwargs[Field.Resource.value] = get_service_integration_arn(Sqs, + SqsApi.SendMessage) super(SqsSendMessageStep, self).__init__(state_id, **kwargs) @@ -234,15 +272,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:createCluster.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.CreateCluster, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.CreateCluster, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::elasticmapreduce:createCluster """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.CreateCluster) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.CreateCluster) super(EmrCreateClusterStep, self).__init__(state_id, **kwargs) @@ -271,15 +311,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.TerminateCluster, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.TerminateCluster, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.TerminateCluster) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.TerminateCluster) super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs) @@ -308,15 +350,17 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:addStep.sync """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.AddStep, - IntegrationPattern.WaitForCompletion) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.AddStep, + IntegrationPattern.WaitForCompletion) else: """ Example resource arn: arn:aws:states:::elasticmapreduce:addStep """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.AddStep) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.AddStep) super(EmrAddStepStep, self).__init__(state_id, **kwargs) @@ -344,8 +388,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:cancelStep """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.CancelStep) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.CancelStep) super(EmrCancelStepStep, self).__init__(state_id, **kwargs) @@ -373,8 +418,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:setClusterTerminationProtection """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.SetClusterTerminationProtection) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.SetClusterTerminationProtection) super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs) @@ -402,8 +448,9 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.ModifyInstanceFleetByName) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.ModifyInstanceFleetByName) super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs) @@ -431,7 +478,8 @@ def __init__(self, state_id, **kwargs): """ Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName """ - kwargs[Field.Resource.value] = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, - ElasticMapReduceApi.ModifyInstanceGroupByName) + + kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + ElasticMapReduceApi.ModifyInstanceGroupByName) super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 89785f0..36103f0 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -15,6 +15,8 @@ import boto3 import logging +from stepfunctions.steps.integration_resources import IntegrationPattern + logger = logging.getLogger('stepfunctions') @@ -23,11 +25,13 @@ def tags_dict_to_kv_list(tags_dict): return kv_list -""" -Obtain matching aws partition name based on region -Return "aws" as default if no region detected -""" def get_aws_partition(): + + """ + Returns the aws partition for the current boto3 session. + Defaults to 'aws' if the region could not be detected. + """ + partitions = boto3.session.Session().get_available_partitions() cur_region = boto3.session.Session().region_name cur_partition = "aws" @@ -45,13 +49,19 @@ def get_aws_partition(): return cur_partition -""" -ARN builder for task integration -""" -def resource_integration_arn_builder(service, api, integration_pattern=None): +def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse): + + """ + ARN builder for task integration + Args: + service(str): name of the task resource service + api(enum): api to be integrated of the task resource service + integration_pattern(enum, optional): integration pattern for the task resource. + Default as request response. + """ arn = "" - if integration_pattern is None: - arn = f"arn:{get_aws_partition()}:states:::{service.value}:{api.value}" + if integration_pattern == IntegrationPattern.RequestResponse: + arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}" else: - arn = f"arn:{get_aws_partition()}:states:::{service.value}:{api.value}.{integration_pattern.value}" + arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}" return arn diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py index a2beb00..446c1d8 100644 --- a/tests/unit/test_steps_utils.py +++ b/tests/unit/test_steps_utils.py @@ -13,11 +13,18 @@ # Test if boto3 session can fetch correct aws partition info from test environment -from stepfunctions.steps.utils import get_aws_partition, resource_integration_arn_builder -from stepfunctions.steps.integration_resources import IntegrationPattern, IntegrationServices, LambdaApi, SageMakerApi,\ - GlueApi, EcsApi, BatchApi, DynamoDBApi, SnsApi, SqsApi, ElasticMapReduceApi +from stepfunctions.steps.utils import get_aws_partition, get_service_integration_arn +from stepfunctions.steps.integration_resources import IntegrationPattern import boto3 from unittest.mock import patch +from enum import Enum + + +testService = "sagemaker" + + +class TestApi(Enum): + CreateTrainingJob = "createTrainingJob" @patch.object(boto3.session.Session, 'region_name', 'us-east-1') @@ -32,112 +39,15 @@ def test_util_get_aws_partition_aws_cn(): assert cur_partition == "aws-cn" -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_lambda_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke) - assert arn == "arn:aws:states:::lambda:invoke" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_lambda_wait_token(): - arn = resource_integration_arn_builder(IntegrationServices.Lambda, LambdaApi.Invoke, - IntegrationPattern.WaitForTaskToken) - assert arn == "arn:aws:states:::lambda:invoke.waitForTaskToken" - - @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_arn_builder_sagemaker_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.SageMaker, SageMakerApi.CreateTrainingJob) + arn = get_service_integration_arn(testService, TestApi.CreateTrainingJob) assert arn == "arn:aws:states:::sagemaker:createTrainingJob" @patch.object(boto3.session.Session, 'region_name', 'us-east-1') def test_arn_builder_sagemaker_wait_completion(): - arn = resource_integration_arn_builder(IntegrationServices.SageMaker, SageMakerApi.CreateTrainingJob, - IntegrationPattern.WaitForCompletion) + arn = get_service_integration_arn(testService, TestApi.CreateTrainingJob, + IntegrationPattern.WaitForCompletion) assert arn == "arn:aws:states:::sagemaker:createTrainingJob.sync" - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_glue_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.Glue, GlueApi.StartJobRun) - assert arn == "arn:aws:states:::glue:startJobRun" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_glue_wait_completion(): - arn = resource_integration_arn_builder(IntegrationServices.Glue, GlueApi.StartJobRun, - IntegrationPattern.WaitForCompletion) - assert arn == "arn:aws:states:::glue:startJobRun.sync" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_ecs_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.ECS, EcsApi.RunTask) - assert arn == "arn:aws:states:::ecs:runTask" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_ecs_wait_completion(): - arn = resource_integration_arn_builder(IntegrationServices.ECS, EcsApi.RunTask, - IntegrationPattern.WaitForCompletion) - assert arn == "arn:aws:states:::ecs:runTask.sync" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_batch_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.Batch, BatchApi.SubmitJob) - assert arn == "arn:aws:states:::batch:submitJob" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_batch_wait_completion(): - arn = resource_integration_arn_builder(IntegrationServices.Batch, BatchApi.SubmitJob, - IntegrationPattern.WaitForCompletion) - assert arn == "arn:aws:states:::batch:submitJob.sync" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_dynamodb_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.DynamoDB, DynamoDBApi.GetItem) - assert arn == "arn:aws:states:::dynamodb:getItem" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_sns_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.SNS, SnsApi.Publish) - assert arn == "arn:aws:states:::sns:publish" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_sns_wait_token(): - arn = resource_integration_arn_builder(IntegrationServices.SNS, SnsApi.Publish, - IntegrationPattern.WaitForTaskToken) - assert arn == "arn:aws:states:::sns:publish.waitForTaskToken" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_sqs_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.SQS, SqsApi.SendMessage) - assert arn == "arn:aws:states:::sqs:sendMessage" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_sqs_wait_token(): - arn = resource_integration_arn_builder(IntegrationServices.SQS, SqsApi.SendMessage, - IntegrationPattern.WaitForTaskToken) - assert arn == "arn:aws:states:::sqs:sendMessage.waitForTaskToken" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_elasticmapreduce_no_wait(): - arn = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, ElasticMapReduceApi.CreateCluster) - assert arn == "arn:aws:states:::elasticmapreduce:createCluster" - - -@patch.object(boto3.session.Session, 'region_name', 'us-east-1') -def test_arn_builder_elasticmapreduce_wait_completion(): - arn = resource_integration_arn_builder(IntegrationServices.ElasticMapReduce, ElasticMapReduceApi.CreateCluster, - IntegrationPattern.WaitForCompletion) - assert arn == "arn:aws:states:::elasticmapreduce:createCluster.sync" - From 162f9db42e11ba9ba5771461388817a8254d75e1 Mon Sep 17 00:00:00 2001 From: Bowen Yuan Date: Mon, 17 May 2021 14:15:03 -0700 Subject: [PATCH 5/6] fix: move arn builder method into integration_resources module --- src/stepfunctions/steps/compute.py | 31 ++++++------ .../steps/integration_resources.py | 18 +++++++ src/stepfunctions/steps/sagemaker.py | 30 ++++++------ src/stepfunctions/steps/service.py | 47 +++++++++---------- src/stepfunctions/steps/utils.py | 20 -------- tests/unit/test_steps_utils.py | 4 +- 6 files changed, 73 insertions(+), 77 deletions(-) diff --git a/src/stepfunctions/steps/compute.py b/src/stepfunctions/steps/compute.py index f8d17ac..203ed47 100644 --- a/src/stepfunctions/steps/compute.py +++ b/src/stepfunctions/steps/compute.py @@ -15,13 +15,12 @@ from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import get_service_integration_arn -from stepfunctions.steps.integration_resources import IntegrationPattern +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn -Lambda = "lambda" -Glue = "glue" -Ecs = "ecs" -Batch = "batch" +LAMBDA_SERVICE_NAME = "lambda" +GLUE_SERVICE_NAME = "glue" +ECS_SERVICE_NAME = "ecs" +BATCH_SERVICE_NAME = "batch" class LambdaApi(Enum): @@ -67,15 +66,15 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken """ - kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, - LambdaApi.Invoke, - IntegrationPattern.WaitForTaskToken) + kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, + LambdaApi.Invoke, + IntegrationPattern.WaitForTaskToken) else: """ Example resource arn: arn:aws:states:::lambda:invoke """ - kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, LambdaApi.Invoke) + kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, LambdaApi.Invoke) super(LambdaStep, self).__init__(state_id, **kwargs) @@ -107,7 +106,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::glue:startJobRun.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(Glue, + kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME, GlueApi.StartJobRun, IntegrationPattern.WaitForCompletion) else: @@ -115,7 +114,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::glue:startJobRun """ - kwargs[Field.Resource.value] = get_service_integration_arn(Glue, + kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME, GlueApi.StartJobRun) super(GlueStartJobRunStep, self).__init__(state_id, **kwargs) @@ -147,7 +146,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::batch:submitJob.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(Batch, + kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME, BatchApi.SubmitJob, IntegrationPattern.WaitForCompletion) else: @@ -155,7 +154,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::batch:submitJob """ - kwargs[Field.Resource.value] = get_service_integration_arn(Batch, + kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME, BatchApi.SubmitJob) super(BatchSubmitJobStep, self).__init__(state_id, **kwargs) @@ -187,7 +186,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::ecs:runTask.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(Ecs, + kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, EcsApi.RunTask, IntegrationPattern.WaitForCompletion) else: @@ -195,7 +194,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::ecs:runTask """ - kwargs[Field.Resource.value] = get_service_integration_arn(Ecs, + kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME, EcsApi.RunTask) super(EcsRunTaskStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/integration_resources.py b/src/stepfunctions/steps/integration_resources.py index a2d56dc..ff654f5 100644 --- a/src/stepfunctions/steps/integration_resources.py +++ b/src/stepfunctions/steps/integration_resources.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from enum import Enum +from stepfunctions.steps.utils import get_aws_partition class IntegrationPattern(Enum): @@ -26,5 +27,22 @@ class IntegrationPattern(Enum): RequestResponse = "" +def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse): + + """ + ARN builder for task integration + Args: + service(str): name of the task resource service + api(Api): api to be integrated of the task resource service + integration_pattern(IntegrationPattern, optional): integration pattern for the task resource. + Default as request response. + """ + arn = "" + if integration_pattern == IntegrationPattern.RequestResponse: + arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}" + else: + arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}" + return arn + diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 812d6ad..deb5176 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -16,14 +16,14 @@ from stepfunctions.inputs import ExecutionInput, StepInput from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn -from stepfunctions.steps.integration_resources import IntegrationPattern +from stepfunctions.steps.utils import tags_dict_to_kv_list +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config from sagemaker.model import Model, FrameworkModel from sagemaker.model_monitor import DataCaptureConfig -SageMaker = "sagemaker" +SAGEMAKER_SERVICE_NAME = "sagemaker" class SageMakerApi(Enum): @@ -78,7 +78,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateTrainingJob, IntegrationPattern.WaitForCompletion) else: @@ -86,7 +86,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non Example resource arn: arn:aws:states:::sagemaker:createTrainingJob """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateTrainingJob) if isinstance(job_name, str): @@ -172,7 +172,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateTransformJob, IntegrationPattern.WaitForCompletion) else: @@ -180,7 +180,7 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type= Example resource arn: arn:aws:states:::sagemaker:createTransformJob """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateTransformJob) if isinstance(job_name, str): @@ -268,7 +268,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No Example resource arn: arn:aws:states:::sagemaker:createModel """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateModel) super(ModelStep, self).__init__(state_id, **kwargs) @@ -314,7 +314,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_ Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateEndpointConfig) kwargs[Field.Parameters.value] = parameters @@ -352,14 +352,14 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd Example resource arn: arn:aws:states:::sagemaker:updateEndpoint """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.UpdateEndpoint) else: """ Example resource arn: arn:aws:states:::sagemaker:createEndpoint """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateEndpoint) kwargs[Field.Parameters.value] = parameters @@ -402,7 +402,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateHyperParameterTuningJob, IntegrationPattern.WaitForCompletion) else: @@ -410,7 +410,7 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateHyperParameterTuningJob) parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy() @@ -462,7 +462,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateProcessingJob, IntegrationPattern.WaitForCompletion) else: @@ -470,7 +470,7 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp Example resource arn: arn:aws:states:::sagemaker:createProcessingJob """ - kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker, + kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME, SageMakerApi.CreateProcessingJob) if isinstance(job_name, str): diff --git a/src/stepfunctions/steps/service.py b/src/stepfunctions/steps/service.py index 1bc368f..6bf155f 100644 --- a/src/stepfunctions/steps/service.py +++ b/src/stepfunctions/steps/service.py @@ -15,13 +15,12 @@ from enum import Enum from stepfunctions.steps.states import Task from stepfunctions.steps.fields import Field -from stepfunctions.steps.utils import get_service_integration_arn -from stepfunctions.steps.integration_resources import IntegrationPattern +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn -DynamoDB = "dynamodb" -Sns = "sns" -Sqs = "sqs" -ElasticMapReduce = "elasticmapreduce" +DYNAMODB_SERVICE_NAME = "dynamodb" +SNS_SERVICE_NAME = "sns" +SQS_SERVICE_NAME = "sqs" +ELASTICMAPREDUCE_SERVICE_NAME = "elasticmapreduce" class DynamoDBApi(Enum): @@ -73,7 +72,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::dynamodb:getItem """ - kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, DynamoDBApi.GetItem) super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs) @@ -103,7 +102,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::dynamodb:putItem """ - kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, DynamoDBApi.PutItem) super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs) @@ -133,7 +132,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::dynamodb:deleteItem """ - kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, DynamoDBApi.DeleteItem) super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs) @@ -163,7 +162,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::dynamodb:updateItem """ - kwargs[Field.Resource.value] = get_service_integration_arn(DynamoDB, + kwargs[Field.Resource.value] = get_service_integration_arn(DYNAMODB_SERVICE_NAME, DynamoDBApi.UpdateItem) super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs) @@ -194,7 +193,7 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): Example resource arn: arn:aws:states:::sns:publish.waitForTaskToken """ - kwargs[Field.Resource.value] = get_service_integration_arn(Sns, + kwargs[Field.Resource.value] = get_service_integration_arn(SNS_SERVICE_NAME, SnsApi.Publish, IntegrationPattern.WaitForTaskToken) else: @@ -202,7 +201,7 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): Example resource arn: arn:aws:states:::sns:publish """ - kwargs[Field.Resource.value] = get_service_integration_arn(Sns, + kwargs[Field.Resource.value] = get_service_integration_arn(SNS_SERVICE_NAME, SnsApi.Publish) super(SnsPublishStep, self).__init__(state_id, **kwargs) @@ -234,7 +233,7 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): Example resource arn: arn:aws:states:::sqs:sendMessage.waitForTaskToken """ - kwargs[Field.Resource.value] = get_service_integration_arn(Sqs, + kwargs[Field.Resource.value] = get_service_integration_arn(SQS_SERVICE_NAME, SqsApi.SendMessage, IntegrationPattern.WaitForTaskToken) else: @@ -242,7 +241,7 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs): Example resource arn: arn:aws:states:::sqs:sendMessage """ - kwargs[Field.Resource.value] = get_service_integration_arn(Sqs, + kwargs[Field.Resource.value] = get_service_integration_arn(SQS_SERVICE_NAME, SqsApi.SendMessage) super(SqsSendMessageStep, self).__init__(state_id, **kwargs) @@ -273,7 +272,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:createCluster.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.CreateCluster, IntegrationPattern.WaitForCompletion) else: @@ -281,7 +280,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:createCluster """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.CreateCluster) super(EmrCreateClusterStep, self).__init__(state_id, **kwargs) @@ -312,7 +311,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.TerminateCluster, IntegrationPattern.WaitForCompletion) else: @@ -320,7 +319,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:terminateCluster """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.TerminateCluster) super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs) @@ -351,7 +350,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:addStep.sync """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.AddStep, IntegrationPattern.WaitForCompletion) else: @@ -359,7 +358,7 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:addStep """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.AddStep) super(EmrAddStepStep, self).__init__(state_id, **kwargs) @@ -389,7 +388,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:cancelStep """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.CancelStep) super(EmrCancelStepStep, self).__init__(state_id, **kwargs) @@ -419,7 +418,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:setClusterTerminationProtection """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.SetClusterTerminationProtection) super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs) @@ -449,7 +448,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.ModifyInstanceFleetByName) super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs) @@ -479,7 +478,7 @@ def __init__(self, state_id, **kwargs): Example resource arn: arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName """ - kwargs[Field.Resource.value] = get_service_integration_arn(ElasticMapReduce, + kwargs[Field.Resource.value] = get_service_integration_arn(ELASTICMAPREDUCE_SERVICE_NAME, ElasticMapReduceApi.ModifyInstanceGroupByName) super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs) diff --git a/src/stepfunctions/steps/utils.py b/src/stepfunctions/steps/utils.py index 36103f0..6f44481 100644 --- a/src/stepfunctions/steps/utils.py +++ b/src/stepfunctions/steps/utils.py @@ -15,8 +15,6 @@ import boto3 import logging -from stepfunctions.steps.integration_resources import IntegrationPattern - logger = logging.getLogger('stepfunctions') @@ -47,21 +45,3 @@ def get_aws_partition(): return cur_partition return cur_partition - - -def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse): - - """ - ARN builder for task integration - Args: - service(str): name of the task resource service - api(enum): api to be integrated of the task resource service - integration_pattern(enum, optional): integration pattern for the task resource. - Default as request response. - """ - arn = "" - if integration_pattern == IntegrationPattern.RequestResponse: - arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}" - else: - arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}" - return arn diff --git a/tests/unit/test_steps_utils.py b/tests/unit/test_steps_utils.py index 446c1d8..6eb0885 100644 --- a/tests/unit/test_steps_utils.py +++ b/tests/unit/test_steps_utils.py @@ -13,8 +13,8 @@ # Test if boto3 session can fetch correct aws partition info from test environment -from stepfunctions.steps.utils import get_aws_partition, get_service_integration_arn -from stepfunctions.steps.integration_resources import IntegrationPattern +from stepfunctions.steps.utils import get_aws_partition +from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn import boto3 from unittest.mock import patch from enum import Enum From ae8570682e7d97579e7062a80adfd8393e6113ac Mon Sep 17 00:00:00 2001 From: Bowen Yuan <83046180+yuan-bwn@users.noreply.github.com> Date: Wed, 19 May 2021 16:43:13 -0700 Subject: [PATCH 6/6] Update src/stepfunctions/steps/integration_resources.py Co-authored-by: Adam Wong <55506708+wong-a@users.noreply.github.com> --- src/stepfunctions/steps/integration_resources.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/stepfunctions/steps/integration_resources.py b/src/stepfunctions/steps/integration_resources.py index ff654f5..5223f14 100644 --- a/src/stepfunctions/steps/integration_resources.py +++ b/src/stepfunctions/steps/integration_resources.py @@ -32,10 +32,9 @@ def get_service_integration_arn(service, api, integration_pattern=IntegrationPat """ ARN builder for task integration Args: - service(str): name of the task resource service - api(Api): api to be integrated of the task resource service - integration_pattern(IntegrationPattern, optional): integration pattern for the task resource. - Default as request response. + service (str): The service name for the service integration + api (str): The api of the service integration + integration_pattern (IntegrationPattern, optional): The integration pattern for the task. (Default: IntegrationPattern.RequestResponse) """ arn = "" if integration_pattern == IntegrationPattern.RequestResponse: @@ -45,4 +44,3 @@ def get_service_integration_arn(service, api, integration_pattern=IntegrationPat return arn -