diff --git a/src/stepfunctions/steps/sagemaker.py b/src/stepfunctions/steps/sagemaker.py index 6321563..5be4a8f 100644 --- a/src/stepfunctions/steps/sagemaker.py +++ b/src/stepfunctions/steps/sagemaker.py @@ -67,7 +67,7 @@ def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=Non else: parameters = training_config(estimator=estimator, inputs=data, mini_batch_size=mini_batch_size) - if estimator.debugger_hook_config != None: + if estimator.debugger_hook_config != None and estimator.debugger_hook_config is not False: parameters['DebugHookConfig'] = estimator.debugger_hook_config._to_request_dict() if estimator.rules != None: diff --git a/tests/unit/test_sagemaker_steps.py b/tests/unit/test_sagemaker_steps.py index bdc7a57..763e9e6 100644 --- a/tests/unit/test_sagemaker_steps.py +++ b/tests/unit/test_sagemaker_steps.py @@ -111,6 +111,34 @@ def pca_estimator_with_debug_hook(): return pca + +@pytest.fixture +def pca_estimator_with_falsy_debug_hook(): + s3_output_location = 's3://sagemaker/models' + + pca = sagemaker.estimator.Estimator( + PCA_IMAGE, + role=EXECUTION_ROLE, + train_instance_count=1, + train_instance_type='ml.c4.xlarge', + output_path=s3_output_location, + debugger_hook_config = False + ) + + pca.set_hyperparameters( + feature_dim=50000, + num_components=10, + subtract_mean=True, + algorithm_mode='randomized', + mini_batch_size=200 + ) + + pca.sagemaker_session = MagicMock() + pca.sagemaker_session.boto_region_name = 'us-east-1' + pca.sagemaker_session._default_bucket = 'sagemaker' + + return pca + @pytest.fixture def pca_model(): model_data = 's3://sagemaker/models/pca.tar.gz' @@ -283,6 +311,43 @@ def test_training_step_creation_with_debug_hook(pca_estimator_with_debug_hook): 'End': True } +@patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) +def test_training_step_creation_with_falsy_debug_hook(pca_estimator_with_falsy_debug_hook): + step = TrainingStep('Training', + estimator=pca_estimator_with_falsy_debug_hook, + job_name='TrainingJob') + assert step.to_dict() == { + 'Type': 'Task', + 'Parameters': { + 'AlgorithmSpecification': { + 'TrainingImage': PCA_IMAGE, + 'TrainingInputMode': 'File' + }, + 'OutputDataConfig': { + 'S3OutputPath': 's3://sagemaker/models' + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 86400 + }, + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.c4.xlarge', + 'VolumeSizeInGB': 30 + }, + 'RoleArn': EXECUTION_ROLE, + 'HyperParameters': { + 'feature_dim': '50000', + 'num_components': '10', + 'subtract_mean': 'True', + 'algorithm_mode': 'randomized', + 'mini_batch_size': '200' + }, + 'TrainingJobName': 'TrainingJob' + }, + 'Resource': 'arn:aws:states:::sagemaker:createTrainingJob.sync', + 'End': True + } + @patch('botocore.client.BaseClient._make_api_call', new=mock_boto_api_call) def test_training_step_creation_with_model(pca_estimator): training_step = TrainingStep('Training', estimator=pca_estimator, job_name='TrainingJob')