Skip to content

Commit ea3e482

Browse files
authored
Merge branch 'main' into fix-model-step-env
2 parents 04cd813 + d2ce83d commit ea3e482

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

src/stepfunctions/workflow/cloudformation.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818

1919
logger = logging.getLogger('stepfunctions')
2020

21+
2122
def repr_str(dumper, data):
2223
if '\n' in data:
2324
return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
2425
return dumper.org_represent_str(data)
2526

27+
2628
yaml.SafeDumper.org_represent_str = yaml.SafeDumper.represent_str
2729
yaml.add_representer(dict, lambda self, data: yaml.representer.SafeRepresenter.represent_dict(self, data.items()), Dumper=yaml.SafeDumper)
2830
yaml.add_representer(str, repr_str, Dumper=yaml.SafeDumper)
@@ -42,12 +44,19 @@ def repr_str(dumper, data):
4244
}
4345
}
4446

45-
def build_cloudformation_template(workflow):
47+
48+
def build_cloudformation_template(workflow, description=None):
49+
"""
50+
Creates a CloudFormation template from the provided Workflow
51+
Args:
52+
workflow (Workflow): Step Functions workflow instance
53+
description (str, optional): Description of the template. If none provided, the default description will be used: "CloudFormation template for AWS Step Functions - State Machine"
54+
"""
4655
logger.warning('To reuse the CloudFormation template in different regions, please make sure to update the region specific AWS resources in the StateMachine definition.')
4756

4857
template = CLOUDFORMATION_BASE_TEMPLATE.copy()
4958

50-
template["Description"] = "CloudFormation template for AWS Step Functions - State Machine"
59+
template["Description"] = description if description else "CloudFormation template for AWS Step Functions - State Machine"
5160
template["Resources"]["StateMachineComponent"]["Properties"]["StateMachineName"] = workflow.name
5261

5362
definition = workflow.definition.to_dict()

src/stepfunctions/workflow/stepfunctions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,13 @@ def render_graph(self, portrait=False):
376376
widget = WorkflowGraphWidget(self.definition.to_json())
377377
return widget.show(portrait=portrait)
378378

379-
def get_cloudformation_template(self):
379+
def get_cloudformation_template(self, description=None):
380380
"""
381381
Returns a CloudFormation template that contains only the StateMachine resource. To reuse the CloudFormation template in a different region, please make sure to update the region specific AWS resources (e.g: Lambda ARN, Training Image) in the StateMachine definition.
382+
Args:
383+
description (str, optional): Description of the template
382384
"""
383-
return build_cloudformation_template(self)
385+
return build_cloudformation_template(self, description)
384386

385387
def __repr__(self):
386388
return '{}(name={!r}, role={!r}, state_machine_arn={!r})'.format(

tests/unit/test_workflow.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def client():
5656
})
5757
return sfn
5858

59+
5960
@pytest.fixture
6061
def workflow(client):
6162
workflow = Workflow(
@@ -67,9 +68,11 @@ def workflow(client):
6768
workflow.create()
6869
return workflow
6970

71+
7072
def test_workflow_creation(client, workflow):
7173
assert workflow.state_machine_arn == state_machine_arn
7274

75+
7376
def test_workflow_creation_failure_duplicate_state_ids(client):
7477
improper_definition = steps.Chain([steps.Pass('HelloWorld'), steps.Succeed('HelloWorld')])
7578
with pytest.raises(ValueError):
@@ -80,6 +83,7 @@ def test_workflow_creation_failure_duplicate_state_ids(client):
8083
client=client
8184
)
8285

86+
8387
# calling update() before create()
8488
def test_workflow_update_when_statemachinearn_is_none(client):
8589
workflow = Workflow(
@@ -92,11 +96,13 @@ def test_workflow_update_when_statemachinearn_is_none(client):
9296
with pytest.raises(WorkflowNotFound):
9397
workflow.update(definition=new_definition)
9498

99+
95100
# calling update() after create() without arguments
96101
def test_workflow_update_when_arguments_are_missing(client, workflow):
97102
with pytest.raises(MissingRequiredParameter):
98103
workflow.update()
99104

105+
100106
# calling update() after create()
101107
def test_workflow_update(client, workflow):
102108
client.update_state_machine = MagicMock(return_value={
@@ -106,12 +112,14 @@ def test_workflow_update(client, workflow):
106112
new_role = 'arn:aws:iam::1234567890:role/service-role/StepFunctionsRoleNew'
107113
assert workflow.update(definition=new_definition, role=new_role) == state_machine_arn
108114

115+
109116
def test_attach_existing_workflow(client):
110117
workflow = Workflow.attach(state_machine_arn, client)
111118
assert workflow.name == state_machine_name
112119
assert workflow.role == role_arn
113120
assert workflow.state_machine_arn == state_machine_arn
114121

122+
115123
def test_workflow_list_executions(client, workflow):
116124
paginator = client.get_paginator('list_executions')
117125
paginator.paginate = MagicMock(return_value=[
@@ -140,12 +148,14 @@ def test_workflow_list_executions(client, workflow):
140148
workflow.state_machine_arn = None
141149
assert workflow.list_executions() == []
142150

151+
143152
def test_workflow_makes_deletion_call(client, workflow):
144153
client.delete_state_machine = MagicMock(return_value=None)
145154
workflow.delete()
146155

147156
client.delete_state_machine.assert_called_once_with(stateMachineArn=state_machine_arn)
148157

158+
149159
def test_workflow_execute_creation(client, workflow):
150160
execution = workflow.execute()
151161
assert execution.workflow.state_machine_arn == state_machine_arn
@@ -164,11 +174,13 @@ def test_workflow_execute_creation(client, workflow):
164174
input='{}'
165175
)
166176

177+
167178
def test_workflow_execute_when_statemachinearn_is_none(client, workflow):
168179
workflow.state_machine_arn = None
169180
with pytest.raises(WorkflowNotFound):
170181
workflow.execute()
171182

183+
172184
def test_execution_makes_describe_call(client, workflow):
173185
execution = workflow.execute()
174186

@@ -177,6 +189,7 @@ def test_execution_makes_describe_call(client, workflow):
177189

178190
client.describe_execution.assert_called_once()
179191

192+
180193
def test_execution_makes_stop_call(client, workflow):
181194
execution = workflow.execute()
182195

@@ -194,6 +207,7 @@ def test_execution_makes_stop_call(client, workflow):
194207
error='Error'
195208
)
196209

210+
197211
def test_execution_list_events(client, workflow):
198212
paginator = client.get_paginator('get_execution_history')
199213
paginator.paginate = MagicMock(return_value=[
@@ -229,6 +243,7 @@ def test_execution_list_events(client, workflow):
229243
}
230244
)
231245

246+
232247
def test_list_workflows(client):
233248
paginator = client.get_paginator('list_state_machines')
234249
paginator.paginate = MagicMock(return_value=[
@@ -254,11 +269,14 @@ def test_list_workflows(client):
254269
}
255270
)
256271

272+
257273
def test_cloudformation_export_with_simple_definition(workflow):
258274
cfn_template = workflow.get_cloudformation_template()
259275
cfn_template = yaml.load(cfn_template)
260276
assert 'StateMachineComponent' in cfn_template['Resources']
261277
assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn']
278+
assert cfn_template['Description'] == "CloudFormation template for AWS Step Functions - State Machine"
279+
262280

263281
def test_cloudformation_export_with_sagemaker_execution_role(workflow):
264282
workflow.definition.to_dict = MagicMock(return_value={
@@ -281,7 +299,8 @@ def test_cloudformation_export_with_sagemaker_execution_role(workflow):
281299
}
282300
}
283301
})
284-
cfn_template = workflow.get_cloudformation_template()
302+
cfn_template = workflow.get_cloudformation_template(description="CloudFormation template with Sagemaker role")
285303
cfn_template = yaml.load(cfn_template)
286304
assert json.dumps(workflow.definition.to_dict(), indent=2) == cfn_template['Resources']['StateMachineComponent']['Properties']['DefinitionString']
287305
assert workflow.role == cfn_template['Resources']['StateMachineComponent']['Properties']['RoleArn']
306+
assert cfn_template['Description'] == "CloudFormation template with Sagemaker role"

0 commit comments

Comments
 (0)