Skip to content

fix: make arns of all task resources aws-partition aware #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/stepfunctions/steps/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import get_aws_partition


class LambdaStep(Task):
Expand All @@ -38,9 +39,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_callback:
kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke.waitForTaskToken'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke.waitForTaskToken'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily blocking: I would like to at least consider class for ARN construction instead of building strings with repeated calls to get_aws_partition() inline everywhere. The AWS SDK in other languages have utilities for this, but python doesn't.

Something like:

kwards[Field.Resource.value] = ARN(service='states', resource='lambda:invoke')

Or even scoping it to service integration ARNs:

kwargs[Field.Resource.value] = ServiceIntegrationARN(service='lambda', api='invoke', integration_pattern=IntegrationPattern.WAIT_FOR_TASK_TOKEN)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will definitely make the code cleaner. I will look into it.
If I can not finish this quickly I will prepare a PR specific for it later.

else:
kwargs[Field.Resource.value] = 'arn:aws:states:::lambda:invoke'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::lambda:invoke'

super(LambdaStep, self).__init__(state_id, **kwargs)

Expand All @@ -67,9 +68,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun.sync'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice if we extracted the logic for generating ARNs into a utility.

additionally, consider string interpolation using f-strings rather than string concatenation, which is supported in Python 3.6+ (our current pre-requisites)

you can read more about it here

else:
kwargs[Field.Resource.value] = 'arn:aws:states:::glue:startJobRun'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::glue:startJobRun'

super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)

Expand All @@ -96,9 +97,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::batch:submitJob'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::batch:submitJob'

super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)

Expand All @@ -125,8 +126,8 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::ecs:runTask'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::ecs:runTask'

super(EcsRunTaskStep, self).__init__(state_id, **kwargs)
26 changes: 13 additions & 13 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
from stepfunctions.steps.utils import tags_dict_to_kv_list, get_aws_partition

from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
from sagemaker.model import Model, FrameworkModel
Expand Down Expand Up @@ -58,9 +58,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
self.job_name = job_name

if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTrainingJob'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTrainingJob'

if isinstance(job_name, str):
parameters = training_config(estimator=estimator, inputs=data, job_name=job_name, mini_batch_size=mini_batch_size)
Expand Down Expand Up @@ -141,9 +141,9 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
join_source (str): The source of data to be joined to the transform output. It can be set to ‘Input’ meaning the entire input record will be joined to the inference result. You can use OutputFilter to select the useful portion before uploading to S3. (default: None). Valid values: Input, None.
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createTransformJob'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createTransformJob'

if isinstance(job_name, str):
parameters = transform_config(
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
parameters['Tags'] = tags_dict_to_kv_list(tags)

kwargs[Field.Parameters.value] = parameters
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createModel'

super(ModelStep, self).__init__(state_id, **kwargs)

Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)

kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpointConfig'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpointConfig'
kwargs[Field.Parameters.value] = parameters

super(EndpointConfigStep, self).__init__(state_id, **kwargs)
Expand Down Expand Up @@ -298,9 +298,9 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
parameters['Tags'] = tags_dict_to_kv_list(tags)

if update:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:updateEndpoint'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createEndpoint'

kwargs[Field.Parameters.value] = parameters

Expand Down Expand Up @@ -338,9 +338,9 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createHyperParameterTuningJob'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createHyperParameterTuningJob'

parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()

Expand Down Expand Up @@ -387,9 +387,9 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
tags (list[dict], optional): `List to tags <https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html>`_ to associate with the resource.
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createProcessingJob'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sagemaker:createProcessingJob'

if isinstance(job_name, str):
parameters = processing_config(processor=processor, inputs=inputs, outputs=outputs, container_arguments=container_arguments, container_entrypoint=container_entrypoint, kms_key_id=kms_key_id, job_name=job_name)
Expand Down
37 changes: 19 additions & 18 deletions src/stepfunctions/steps/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import get_aws_partition


class DynamoDBGetItemStep(Task):
Expand All @@ -35,7 +36,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:getItem'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:getItem'
super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs)


Expand All @@ -59,7 +60,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:putItem'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:putItem'
super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs)


Expand All @@ -83,7 +84,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:deleteItem'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:deleteItem'
super(DynamoDBDeleteItemStep, self).__init__(state_id, **kwargs)


Expand All @@ -107,7 +108,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:updateItem'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::dynamodb:updateItem'
super(DynamoDBUpdateItemStep, self).__init__(state_id, **kwargs)


Expand All @@ -133,9 +134,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_callback:
kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish.waitForTaskToken'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish.waitForTaskToken'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sns:publish'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sns:publish'

super(SnsPublishStep, self).__init__(state_id, **kwargs)

Expand All @@ -162,9 +163,9 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_callback:
kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage.waitForTaskToken'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage.waitForTaskToken'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::sqs:sendMessage'

super(SqsSendMessageStep, self).__init__(state_id, **kwargs)

Expand All @@ -190,9 +191,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:createCluster'

super(EmrCreateClusterStep, self).__init__(state_id, **kwargs)

Expand All @@ -218,9 +219,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:terminateCluster'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:terminateCluster'

super(EmrTerminateClusterStep, self).__init__(state_id, **kwargs)

Expand All @@ -246,9 +247,9 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep.sync'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:addStep'

super(EmrAddStepStep, self).__init__(state_id, **kwargs)

Expand All @@ -272,7 +273,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:cancelStep'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:cancelStep'

super(EmrCancelStepStep, self).__init__(state_id, **kwargs)

Expand All @@ -296,7 +297,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:setClusterTerminationProtection'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:setClusterTerminationProtection'

super(EmrSetClusterTerminationProtectionStep, self).__init__(state_id, **kwargs)

Expand All @@ -320,7 +321,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceFleetByName'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceFleetByName'

super(EmrModifyInstanceFleetByNameStep, self).__init__(state_id, **kwargs)

Expand All @@ -344,7 +345,7 @@ def __init__(self, state_id, **kwargs):
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:modifyInstanceGroupByName'
kwargs[Field.Resource.value] = 'arn:' + get_aws_partition() + ':states:::elasticmapreduce:modifyInstanceGroupByName'

super(EmrModifyInstanceGroupByNameStep, self).__init__(state_id, **kwargs)

30 changes: 28 additions & 2 deletions src/stepfunctions/steps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,32 @@
# permissions and limitations under the License.
from __future__ import absolute_import

import boto3
import logging

logger = logging.getLogger('stepfunctions')


def tags_dict_to_kv_list(tags_dict):
kv_list = [{"Key": k, "Value": v} for k,v in tags_dict.items()]
return kv_list
kv_list = [{"Key": k, "Value": v} for k, v in tags_dict.items()]
return kv_list


# Obtain matching aws partition name based on region
# Retrun "aws" as default if no region detected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo in "Retrun". May be worth mentioning this is from the default boto3 session.

Suggested change
# Obtain matching aws partition name based on region
# Retrun "aws" as default if no region detected
# Obtain matching aws partition name based on region
"""
Returns the aws partition for the the current boto3 session.
Defaults to 'aws' if the region could not be detected.
"""

Python method comments are also usually inside the method blocks as a docstring.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing. Will make the fix in new commit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oups! Typo here: "Return"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! I will make the change in next commit

def get_aws_partition():
partitions = boto3.session.Session().get_available_partitions()
cur_region = boto3.session.Session().region_name
cur_partition = "aws"

if cur_region is None:
logger.warning("No region detected for the session, will use default partition: aws")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think it's worth mentioning this is for a boto3 session

Suggested change
logger.warning("No region detected for the session, will use default partition: aws")
logger.warning("No region detected for the boto3 session. Using default partition: aws")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. Will do

return cur_partition

for partition in partitions:
regions = boto3.session.Session().get_available_regions("stepfunctions", partition)
if cur_region in regions:
cur_partition = partition
return cur_partition

return cur_partition
Loading