Skip to content

Commit b9b0371

Browse files
authored
Merge pull request #83 from vaib-amz/master
Support boolean False for debugger_hook_config parameter
2 parents 8795aee + 86b2bcf commit b9b0371

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

src/stepfunctions/steps/sagemaker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non
6767
else:
6868
parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size)
6969

70-
if estimator.debugger_hook_config != None:
70+
if estimator.debugger_hook_config != None and estimator.debugger_hook_config is not False:
7171
parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict()
7272

7373
if estimator.rules != None:

tests/unit/test_sagemaker_steps.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,34 @@ def pca_estimator_with_debug_hook():
111111

112112
return pca
113113

114+
115+
@pytest.fixture
116+
def pca_estimator_with_falsy_debug_hook():
117+
s3_output_location = 's3://sagemaker/models'
118+
119+
pca = sagemaker.estimator.Estimator(
120+
PCA_IMAGE,
121+
role=EXECUTION_ROLE,
122+
train_instance_count=1,
123+
train_instance_type='ml.c4.xlarge',
124+
output_path=s3_output_location,
125+
debugger_hook_config = False
126+
)
127+
128+
pca.set_hyperparameters(
129+
feature_dim=50000,
130+
num_components=10,
131+
subtract_mean=True,
132+
algorithm_mode='randomized',
133+
mini_batch_size=200
134+
)
135+
136+
pca.sagemaker_session = MagicMock()
137+
pca.sagemaker_session.boto_region_name = 'us-east-1'
138+
pca.sagemaker_session._default_bucket = 'sagemaker'
139+
140+
return pca
141+
114142
@pytest.fixture
115143
def pca_model():
116144
model_data = 's3://sagemaker/models/pca.tar.gz'
@@ -283,6 +311,43 @@ def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook):
283311
'End': True
284312
}
285313

314+
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
315+
def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_debug_hook):
316+
step = TrainingStep('Training',
317+
estimator=pca_estimator_with_falsy_debug_hook,
318+
job_name='TrainingJob')
319+
assert step.to_dict() == {
320+
'Type': 'Task',
321+
'Parameters': {
322+
'AlgorithmSpecification': {
323+
'TrainingImage': PCA_IMAGE,
324+
'TrainingInputMode': 'File'
325+
},
326+
'OutputDataConfig': {
327+
'S3OutputPath': 's3://sagemaker/models'
328+
},
329+
'StoppingCondition': {
330+
'MaxRuntimeInSeconds': 86400
331+
},
332+
'ResourceConfig': {
333+
'InstanceCount': 1,
334+
'InstanceType': 'ml.c4.xlarge',
335+
'VolumeSizeInGB': 30
336+
},
337+
'RoleArn': EXECUTION_ROLE,
338+
'HyperParameters': {
339+
'feature_dim': '50000',
340+
'num_components': '10',
341+
'subtract_mean': 'True',
342+
'algorithm_mode': 'randomized',
343+
'mini_batch_size': '200'
344+
},
345+
'TrainingJobName': 'TrainingJob'
346+
},
347+
'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync',
348+
'End': True
349+
}
350+
286351
@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call)
287352
def test_training_step_creation_with_model(pca_estimator):
288353
training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')

0 commit comments

Comments
 (0)