Skip to content

fix: make arns of all task resources aws-partition aware #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 21, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions src/stepfunctions/steps/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
from enum import Enum
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import get_service_integration_arn
from stepfunctions.steps.integration_resources import IntegrationPattern
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

Lambda = "lambda"
Glue = "glue"
Ecs = "ecs"
Batch = "batch"
LAMBDA_SERVICE_NAME = "lambda"
GLUE_SERVICE_NAME = "glue"
ECS_SERVICE_NAME = "ecs"
BATCH_SERVICE_NAME = "batch"


class LambdaApi(Enum):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit (not blocking): Not necessary to be a class

Expand Down Expand Up @@ -67,15 +66,15 @@ def __init__(self, state_id, wait_for_callback=False, **kwargs):
Example resource arn: arn:aws:states:::lambda:invoke.waitForTaskToken
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Lambda,
LambdaApi.Invoke,
IntegrationPattern.WaitForTaskToken)
kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME,
LambdaApi.Invoke,
IntegrationPattern.WaitForTaskToken)
else:
"""
Example resource arn: arn:aws:states:::lambda:invoke
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Lambda, LambdaApi.Invoke)
kwargs[Field.Resource.value] = get_service_integration_arn(LAMBDA_SERVICE_NAME, LambdaApi.Invoke)


super(LambdaStep, self).__init__(state_id, **kwargs)
Expand Down Expand Up @@ -107,15 +106,15 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
Example resource arn: arn:aws:states:::glue:startJobRun.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Glue,
kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME,
GlueApi.StartJobRun,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::glue:startJobRun
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Glue,
kwargs[Field.Resource.value] = get_service_integration_arn(GLUE_SERVICE_NAME,
GlueApi.StartJobRun)

super(GlueStartJobRunStep, self).__init__(state_id, **kwargs)
Expand Down Expand Up @@ -147,15 +146,15 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
Example resource arn: arn:aws:states:::batch:submitJob.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Batch,
kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME,
BatchApi.SubmitJob,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::batch:submitJob
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Batch,
kwargs[Field.Resource.value] = get_service_integration_arn(BATCH_SERVICE_NAME,
BatchApi.SubmitJob)

super(BatchSubmitJobStep, self).__init__(state_id, **kwargs)
Expand Down Expand Up @@ -187,15 +186,15 @@ def __init__(self, state_id, wait_for_completion=True, **kwargs):
Example resource arn: arn:aws:states:::ecs:runTask.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Ecs,
kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME,
EcsApi.RunTask,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::ecs:runTask
"""

kwargs[Field.Resource.value] = get_service_integration_arn(Ecs,
kwargs[Field.Resource.value] = get_service_integration_arn(ECS_SERVICE_NAME,
EcsApi.RunTask)

super(EcsRunTaskStep, self).__init__(state_id, **kwargs)
18 changes: 18 additions & 0 deletions src/stepfunctions/steps/integration_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

from enum import Enum
from stepfunctions.steps.utils import get_aws_partition


class IntegrationPattern(Enum):
Expand All @@ -26,5 +27,22 @@ class IntegrationPattern(Enum):
RequestResponse = ""


def get_service_integration_arn(service, api, integration_pattern=IntegrationPattern.RequestResponse):

"""
ARN builder for task integration
Args:
service(str): name of the task resource service
api(<Service>Api): api to be integrated of the task resource service
integration_pattern(IntegrationPattern, optional): integration pattern for the task resource.
Default as request response.
"""
arn = ""
if integration_pattern == IntegrationPattern.RequestResponse:
arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}"
else:
arn = f"arn:{get_aws_partition()}:states:::{service}:{api.value}.{integration_pattern.value}"
return arn
Comment on lines +40 to +44
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

really nice to see us leveraging f-strings!




30 changes: 15 additions & 15 deletions src/stepfunctions/steps/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list, get_service_integration_arn
from stepfunctions.steps.integration_resources import IntegrationPattern
from stepfunctions.steps.utils import tags_dict_to_kv_list
from stepfunctions.steps.integration_resources import IntegrationPattern, get_service_integration_arn

from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config, processing_config
from sagemaker.model import Model, FrameworkModel
from sagemaker.model_monitor import DataCaptureConfig

SageMaker = "sagemaker"
SAGEMAKER_SERVICE_NAME = "sagemaker"


class SageMakerApi(Enum):
Expand Down Expand Up @@ -78,15 +78,15 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateTrainingJob,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::sagemaker:createTrainingJob
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateTrainingJob)

if isinstance(job_name, str):
Expand Down Expand Up @@ -172,15 +172,15 @@ def __init__(self, state_id, transformer, job_name, model_name, data, data_type=
Example resource arn: arn:aws:states:::sagemaker:createTransformJob.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateTransformJob,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::sagemaker:createTransformJob
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateTransformJob)

if isinstance(job_name, str):
Expand Down Expand Up @@ -268,7 +268,7 @@ def __init__(self, state_id, model, model_name=None, instance_type=None, tags=No
Example resource arn: arn:aws:states:::sagemaker:createModel
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateModel)

super(ModelStep, self).__init__(state_id, **kwargs)
Expand Down Expand Up @@ -314,7 +314,7 @@ def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_
Example resource arn: arn:aws:states:::sagemaker:createEndpointConfig
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateEndpointConfig)

kwargs[Field.Parameters.value] = parameters
Expand Down Expand Up @@ -352,14 +352,14 @@ def __init__(self, state_id, endpoint_name, endpoint_config_name, tags=None, upd
Example resource arn: arn:aws:states:::sagemaker:updateEndpoint
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.UpdateEndpoint)
else:
"""
Example resource arn: arn:aws:states:::sagemaker:createEndpoint
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateEndpoint)

kwargs[Field.Parameters.value] = parameters
Expand Down Expand Up @@ -402,15 +402,15 @@ def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, ta
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateHyperParameterTuningJob,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::sagemaker:createHyperParameterTuningJob
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateHyperParameterTuningJob)

parameters = tuning_config(tuner=tuner, inputs=data, job_name=job_name).copy()
Expand Down Expand Up @@ -462,15 +462,15 @@ def __init__(self, state_id, processor, job_name, inputs=None, outputs=None, exp
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob.sync
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateProcessingJob,
IntegrationPattern.WaitForCompletion)
else:
"""
Example resource arn: arn:aws:states:::sagemaker:createProcessingJob
"""

kwargs[Field.Resource.value] = get_service_integration_arn(SageMaker,
kwargs[Field.Resource.value] = get_service_integration_arn(SAGEMAKER_SERVICE_NAME,
SageMakerApi.CreateProcessingJob)

if isinstance(job_name, str):
Expand Down
Loading