1717
1818import pytest
1919
20- from sagemaker .debugger import Rule
21- from sagemaker .debugger import DebuggerHookConfig
22- from sagemaker .debugger import TensorBoardOutputConfig
23-
24- from sagemaker .debugger import rule_configs
20+ from sagemaker .debugger import DebuggerHookConfig , Rule , rule_configs , TensorBoardOutputConfig
2521from sagemaker .mxnet .estimator import MXNet
26- from tests .integ import DATA_DIR , PYTHON_VERSION , TRAINING_DEFAULT_TIMEOUT_MINUTES
22+ from tests .integ import DATA_DIR , TRAINING_DEFAULT_TIMEOUT_MINUTES
2723from tests .integ .retry import retries
2824from tests .integ .timeout import timeout
2925
6056# TODO-reinvent-2019: test get_debugger_artifacts_path and get_tensorboard_artifacts_path
6157
6258
63- def test_mxnet_with_rules (sagemaker_session , mxnet_full_version , cpu_instance_type ):
59+ def test_mxnet_with_rules (
60+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
61+ ):
6462 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
6563 rules = [
6664 Rule .sagemaker (rule_configs .vanishing_gradient ()),
@@ -77,7 +75,7 @@ def test_mxnet_with_rules(sagemaker_session, mxnet_full_version, cpu_instance_ty
7775 entry_point = script_path ,
7876 role = "SageMakerRole" ,
7977 framework_version = mxnet_full_version ,
80- py_version = PYTHON_VERSION ,
78+ py_version = mxnet_full_py_version ,
8179 train_instance_count = 1 ,
8280 train_instance_type = cpu_instance_type ,
8381 sagemaker_session = sagemaker_session ,
@@ -119,7 +117,9 @@ def test_mxnet_with_rules(sagemaker_session, mxnet_full_version, cpu_instance_ty
119117 _wait_and_assert_that_no_rule_jobs_errored (training_job = mx .latest_training_job )
120118
121119
122- def test_mxnet_with_custom_rule (sagemaker_session , mxnet_full_version , cpu_instance_type ):
120+ def test_mxnet_with_custom_rule (
121+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
122+ ):
123123 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
124124 rules = [_get_custom_rule (sagemaker_session )]
125125
@@ -130,7 +130,7 @@ def test_mxnet_with_custom_rule(sagemaker_session, mxnet_full_version, cpu_insta
130130 entry_point = script_path ,
131131 role = "SageMakerRole" ,
132132 framework_version = mxnet_full_version ,
133- py_version = PYTHON_VERSION ,
133+ py_version = mxnet_full_py_version ,
134134 train_instance_count = 1 ,
135135 train_instance_type = cpu_instance_type ,
136136 sagemaker_session = sagemaker_session ,
@@ -166,7 +166,9 @@ def test_mxnet_with_custom_rule(sagemaker_session, mxnet_full_version, cpu_insta
166166 _wait_and_assert_that_no_rule_jobs_errored (training_job = mx .latest_training_job )
167167
168168
169- def test_mxnet_with_debugger_hook_config (sagemaker_session , mxnet_full_version , cpu_instance_type ):
169+ def test_mxnet_with_debugger_hook_config (
170+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
171+ ):
170172 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
171173 debugger_hook_config = DebuggerHookConfig (
172174 s3_output_path = os .path .join (
@@ -181,7 +183,7 @@ def test_mxnet_with_debugger_hook_config(sagemaker_session, mxnet_full_version,
181183 entry_point = script_path ,
182184 role = "SageMakerRole" ,
183185 framework_version = mxnet_full_version ,
184- py_version = PYTHON_VERSION ,
186+ py_version = mxnet_full_py_version ,
185187 train_instance_count = 1 ,
186188 train_instance_type = cpu_instance_type ,
187189 sagemaker_session = sagemaker_session ,
@@ -204,7 +206,7 @@ def test_mxnet_with_debugger_hook_config(sagemaker_session, mxnet_full_version,
204206
205207
206208def test_mxnet_with_rules_and_debugger_hook_config (
207- sagemaker_session , mxnet_full_version , cpu_instance_type
209+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
208210):
209211 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
210212 rules = [
@@ -227,7 +229,7 @@ def test_mxnet_with_rules_and_debugger_hook_config(
227229 entry_point = script_path ,
228230 role = "SageMakerRole" ,
229231 framework_version = mxnet_full_version ,
230- py_version = PYTHON_VERSION ,
232+ py_version = mxnet_full_py_version ,
231233 train_instance_count = 1 ,
232234 train_instance_type = cpu_instance_type ,
233235 sagemaker_session = sagemaker_session ,
@@ -272,7 +274,7 @@ def test_mxnet_with_rules_and_debugger_hook_config(
272274
273275
274276def test_mxnet_with_custom_rule_and_debugger_hook_config (
275- sagemaker_session , mxnet_full_version , cpu_instance_type
277+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
276278):
277279 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
278280 rules = [_get_custom_rule (sagemaker_session )]
@@ -289,7 +291,7 @@ def test_mxnet_with_custom_rule_and_debugger_hook_config(
289291 entry_point = script_path ,
290292 role = "SageMakerRole" ,
291293 framework_version = mxnet_full_version ,
292- py_version = PYTHON_VERSION ,
294+ py_version = mxnet_full_py_version ,
293295 train_instance_count = 1 ,
294296 train_instance_type = cpu_instance_type ,
295297 sagemaker_session = sagemaker_session ,
@@ -328,7 +330,7 @@ def test_mxnet_with_custom_rule_and_debugger_hook_config(
328330
329331
330332def test_mxnet_with_tensorboard_output_config (
331- sagemaker_session , mxnet_full_version , cpu_instance_type
333+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
332334):
333335 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
334336 tensorboard_output_config = TensorBoardOutputConfig (
@@ -344,7 +346,7 @@ def test_mxnet_with_tensorboard_output_config(
344346 entry_point = script_path ,
345347 role = "SageMakerRole" ,
346348 framework_version = mxnet_full_version ,
347- py_version = PYTHON_VERSION ,
349+ py_version = mxnet_full_py_version ,
348350 train_instance_count = 1 ,
349351 train_instance_type = cpu_instance_type ,
350352 sagemaker_session = sagemaker_session ,
@@ -370,7 +372,9 @@ def test_mxnet_with_tensorboard_output_config(
370372
371373
372374@pytest .mark .canary_quick
373- def test_mxnet_with_all_rules_and_configs (sagemaker_session , mxnet_full_version , cpu_instance_type ):
375+ def test_mxnet_with_all_rules_and_configs (
376+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
377+ ):
374378 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
375379 rules = [
376380 Rule .sagemaker (rule_configs .vanishing_gradient ()),
@@ -398,7 +402,7 @@ def test_mxnet_with_all_rules_and_configs(sagemaker_session, mxnet_full_version,
398402 entry_point = script_path ,
399403 role = "SageMakerRole" ,
400404 framework_version = mxnet_full_version ,
401- py_version = PYTHON_VERSION ,
405+ py_version = mxnet_full_py_version ,
402406 train_instance_count = 1 ,
403407 train_instance_type = cpu_instance_type ,
404408 sagemaker_session = sagemaker_session ,
@@ -441,7 +445,7 @@ def test_mxnet_with_all_rules_and_configs(sagemaker_session, mxnet_full_version,
441445
442446
443447def test_mxnet_with_debugger_hook_config_disabled (
444- sagemaker_session , mxnet_full_version , cpu_instance_type
448+ sagemaker_session , mxnet_full_version , mxnet_full_py_version , cpu_instance_type
445449):
446450 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
447451 script_path = os .path .join (DATA_DIR , "mxnet_mnist" , "mnist_gluon.py" )
@@ -451,7 +455,7 @@ def test_mxnet_with_debugger_hook_config_disabled(
451455 entry_point = script_path ,
452456 role = "SageMakerRole" ,
453457 framework_version = mxnet_full_version ,
454- py_version = PYTHON_VERSION ,
458+ py_version = mxnet_full_py_version ,
455459 train_instance_count = 1 ,
456460 train_instance_type = cpu_instance_type ,
457461 sagemaker_session = sagemaker_session ,
0 commit comments