Skip to content

Commit 7501319

Browse files
authored
fix: Support placeholders for hyperparameters passed to TrainingStep (#159)
* Make hyperparameters compatible with placeholders
1 parent 74d0f07 commit 7501319

File tree

2 files changed

+52
-11
lines changed

2 files changed

+52
-11
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,10 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6969
* (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
7070
:class:`sagemaker.amazon.amazon_estimator.RecordSet` objects,
7171
where each instance is a different channel of training data.
72-
hyperparameters (dict, optional): Parameters used for training.
73-
Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
72+
hyperparameters: Parameters used for training.
73+
* (dict, optional) - Hyperparameters supplied will be merged with the Hyperparameters specified in the estimator.
7474
If there are duplicate entries, the value provided through this property will be used. (Default: Hyperparameters specified in the estimator.)
75+
* (Placeholder, optional) - The TrainingStep will use the hyperparameters specified by the Placeholder's value instead of the hyperparameters specified in the estimator; the Placeholder is passed through as-is and is not merged with the estimator's hyperparameters.
7576
mini_batch_size (int): Specify this argument only when estimator is a built-in estimator of an Amazon algorithm. For other estimators, batch size should be specified in the estimator.
7677
experiment_config (dict, optional): Specify the experiment config for the training. (Default: None)
7778
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait for the training job to complete before proceeding to the next step in the workflow. Set to `False` if the Task state should submit the training job and proceed to the next step. (Default: True)
@@ -127,8 +128,9 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
127128
parameters['InputDataConfig'][0]['DataSource']['S3DataSource']['S3Uri.$'] = data_uri
128129

129130
if hyperparameters is not None:
130-
if estimator.hyperparameters() is not None:
131-
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
131+
if not isinstance(hyperparameters, Placeholder):
132+
if estimator.hyperparameters() is not None:
133+
hyperparameters = self.__merge_hyperparameters(hyperparameters, estimator.hyperparameters())
132134
parameters['HyperParameters'] = hyperparameters
133135

134136
if experiment_config is not None:

tests/unit/test_sagemaker_steps.py

Lines changed: 46 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -275,6 +275,7 @@ def test_training_step_creation_with_placeholders(pca_estimator):
275275
execution_input = ExecutionInput(schema={
276276
'Data': str,
277277
'OutputPath': str,
278+
'HyperParameters': str
278279
})
279280

280281
step_input = StepInput(schema={
@@ -292,6 +293,7 @@ def test_training_step_creation_with_placeholders(pca_estimator):
292293
'TrialComponentDisplayName': 'Training'
293294
},
294295
tags=DEFAULT_TAGS,
296+
hyperparameters=execution_input['HyperParameters']
295297
)
296298
assert step.to_dict() == {
297299
'Type': 'Task',
@@ -312,13 +314,7 @@ def test_training_step_creation_with_placeholders(pca_estimator):
312314
'VolumeSizeInGB': 30
313315
},
314316
'RoleArn': EXECUTION_ROLE,
315-
'HyperParameters': {
316-
'feature_dim': '50000',
317-
'num_components': '10',
318-
'subtract_mean': 'True',
319-
'algorithm_mode': 'randomized',
320-
'mini_batch_size': '200'
321-
},
317+
'HyperParameters.$': "$$.Execution.Input['HyperParameters']",
322318
'InputDataConfig': [
323319
{
324320
'ChannelName': 'training',
@@ -344,6 +340,49 @@ def test_training_step_creation_with_placeholders(pca_estimator):
344340
}
345341

346342

343+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
344+
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
345+
def test_training_step_creation_with_hyperparameters_containing_placeholders(pca_estimator):
346+
execution_input = ExecutionInput(schema={
347+
'Data': str,
348+
'OutputPath': str,
349+
'num_components': str,
350+
'HyperParamA': str,
351+
'HyperParamB': str,
352+
})
353+
354+
step_input = StepInput(schema={
355+
'JobName': str,
356+
})
357+
358+
step = TrainingStep('Training',
359+
estimator=pca_estimator,
360+
job_name=step_input['JobName'],
361+
data=execution_input['Data'],
362+
output_data_config_path=execution_input['OutputPath'],
363+
experiment_config={
364+
'ExperimentName': 'pca_experiment',
365+
'TrialName': 'pca_trial',
366+
'TrialComponentDisplayName': 'Training'
367+
},
368+
tags=DEFAULT_TAGS,
369+
hyperparameters={
370+
'num_components': execution_input['num_components'], # This will overwrite the value that was defined in the pca_estimator
371+
'HyperParamA': execution_input['HyperParamA'],
372+
'HyperParamB': execution_input['HyperParamB']
373+
}
374+
)
375+
assert step.to_dict()['Parameters']['HyperParameters'] == {
376+
'HyperParamA.$': "$$.Execution.Input['HyperParamA']",
377+
'HyperParamB.$': "$$.Execution.Input['HyperParamB']",
378+
'algorithm_mode': 'randomized',
379+
'feature_dim': 50000,
380+
'mini_batch_size': 200,
381+
'num_components.$': "$$.Execution.Input['num_components']",
382+
'subtract_mean': True
383+
}
384+
385+
347386
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
348387
@patch.object(boto3.session.Session, 'region_name', 'us-east-1')
349388
def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):

0 commit comments

Comments
 (0)