Skip to content

Commit ce6ba25

Browse files
authored
infra: use fixture for Python version in MXNet integ tests (#1613)
1 parent c211417 commit ce6ba25

10 files changed

+125
-79
lines changed

tests/conftest.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def pytest_addoption(parser):
4444
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
4545
parser.addoption("--boto-config", action="store", default=None)
4646
parser.addoption("--chainer-full-version", action="store", default="5.0.0")
47-
parser.addoption("--mxnet-full-version", action="store", default="1.6.0")
4847
parser.addoption("--ei-mxnet-full-version", action="store", default="1.5.1")
4948
parser.addoption(
5049
"--rl-coach-mxnet-full-version",
@@ -255,8 +254,13 @@ def chainer_full_version(request):
255254

256255

257256
@pytest.fixture(scope="module")
258-
def mxnet_full_version(request):
259-
return request.config.getoption("--mxnet-full-version")
257+
def mxnet_full_version():
258+
return "1.6.0"
259+
260+
261+
@pytest.fixture(scope="module")
262+
def mxnet_full_py_version():
263+
return "py3"
260264

261265

262266
@pytest.fixture(scope="module")

tests/integ/test_airflow_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def test_chainer_airflow_config_uploads_data_source_to_s3(
442442

443443
@pytest.mark.canary_quick
444444
def test_mxnet_airflow_config_uploads_data_source_to_s3(
445-
sagemaker_session, cpu_instance_type, mxnet_full_version
445+
sagemaker_session, cpu_instance_type, mxnet_full_version, mxnet_full_py_version
446446
):
447447
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
448448
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -452,7 +452,7 @@ def test_mxnet_airflow_config_uploads_data_source_to_s3(
452452
entry_point=script_path,
453453
role=ROLE,
454454
framework_version=mxnet_full_version,
455-
py_version=PYTHON_VERSION,
455+
py_version=mxnet_full_py_version,
456456
train_instance_count=SINGLE_INSTANCE_COUNT,
457457
train_instance_type=cpu_instance_type,
458458
sagemaker_session=sagemaker_session,
@@ -573,7 +573,7 @@ def test_xgboost_airflow_config_uploads_data_source_to_s3(
573573

574574
@pytest.mark.canary_quick
575575
def test_pytorch_airflow_config_uploads_data_source_to_s3_when_inputs_not_provided(
576-
sagemaker_session, cpu_instance_type, pytorch_full_version, pytorch_full_py_version,
576+
sagemaker_session, cpu_instance_type, pytorch_full_version, pytorch_full_py_version
577577
):
578578
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
579579
estimator = PyTorch(

tests/integ/test_debugger.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,9 @@
1717

1818
import pytest
1919

20-
from sagemaker.debugger import Rule
21-
from sagemaker.debugger import DebuggerHookConfig
22-
from sagemaker.debugger import TensorBoardOutputConfig
23-
24-
from sagemaker.debugger import rule_configs
20+
from sagemaker.debugger import DebuggerHookConfig, Rule, rule_configs, TensorBoardOutputConfig
2521
from sagemaker.mxnet.estimator import MXNet
26-
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
22+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2723
from tests.integ.retry import retries
2824
from tests.integ.timeout import timeout
2925

@@ -60,7 +56,9 @@
6056
# TODO-reinvent-2019: test get_debugger_artifacts_path and get_tensorboard_artifacts_path
6157

6258

63-
def test_mxnet_with_rules(sagemaker_session, mxnet_full_version, cpu_instance_type):
59+
def test_mxnet_with_rules(
60+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
61+
):
6462
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
6563
rules = [
6664
Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -77,7 +75,7 @@ def test_mxnet_with_rules(sagemaker_session, mxnet_full_version, cpu_instance_ty
7775
entry_point=script_path,
7876
role="SageMakerRole",
7977
framework_version=mxnet_full_version,
80-
py_version=PYTHON_VERSION,
78+
py_version=mxnet_full_py_version,
8179
train_instance_count=1,
8280
train_instance_type=cpu_instance_type,
8381
sagemaker_session=sagemaker_session,
@@ -119,7 +117,9 @@ def test_mxnet_with_rules(sagemaker_session, mxnet_full_version, cpu_instance_ty
119117
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
120118

121119

122-
def test_mxnet_with_custom_rule(sagemaker_session, mxnet_full_version, cpu_instance_type):
120+
def test_mxnet_with_custom_rule(
121+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
122+
):
123123
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
124124
rules = [_get_custom_rule(sagemaker_session)]
125125

@@ -130,7 +130,7 @@ def test_mxnet_with_custom_rule(sagemaker_session, mxnet_full_version, cpu_insta
130130
entry_point=script_path,
131131
role="SageMakerRole",
132132
framework_version=mxnet_full_version,
133-
py_version=PYTHON_VERSION,
133+
py_version=mxnet_full_py_version,
134134
train_instance_count=1,
135135
train_instance_type=cpu_instance_type,
136136
sagemaker_session=sagemaker_session,
@@ -166,7 +166,9 @@ def test_mxnet_with_custom_rule(sagemaker_session, mxnet_full_version, cpu_insta
166166
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
167167

168168

169-
def test_mxnet_with_debugger_hook_config(sagemaker_session, mxnet_full_version, cpu_instance_type):
169+
def test_mxnet_with_debugger_hook_config(
170+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
171+
):
170172
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
171173
debugger_hook_config = DebuggerHookConfig(
172174
s3_output_path=os.path.join(
@@ -181,7 +183,7 @@ def test_mxnet_with_debugger_hook_config(sagemaker_session, mxnet_full_version,
181183
entry_point=script_path,
182184
role="SageMakerRole",
183185
framework_version=mxnet_full_version,
184-
py_version=PYTHON_VERSION,
186+
py_version=mxnet_full_py_version,
185187
train_instance_count=1,
186188
train_instance_type=cpu_instance_type,
187189
sagemaker_session=sagemaker_session,
@@ -204,7 +206,7 @@ def test_mxnet_with_debugger_hook_config(sagemaker_session, mxnet_full_version,
204206

205207

206208
def test_mxnet_with_rules_and_debugger_hook_config(
207-
sagemaker_session, mxnet_full_version, cpu_instance_type
209+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
208210
):
209211
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
210212
rules = [
@@ -227,7 +229,7 @@ def test_mxnet_with_rules_and_debugger_hook_config(
227229
entry_point=script_path,
228230
role="SageMakerRole",
229231
framework_version=mxnet_full_version,
230-
py_version=PYTHON_VERSION,
232+
py_version=mxnet_full_py_version,
231233
train_instance_count=1,
232234
train_instance_type=cpu_instance_type,
233235
sagemaker_session=sagemaker_session,
@@ -272,7 +274,7 @@ def test_mxnet_with_rules_and_debugger_hook_config(
272274

273275

274276
def test_mxnet_with_custom_rule_and_debugger_hook_config(
275-
sagemaker_session, mxnet_full_version, cpu_instance_type
277+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
276278
):
277279
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
278280
rules = [_get_custom_rule(sagemaker_session)]
@@ -289,7 +291,7 @@ def test_mxnet_with_custom_rule_and_debugger_hook_config(
289291
entry_point=script_path,
290292
role="SageMakerRole",
291293
framework_version=mxnet_full_version,
292-
py_version=PYTHON_VERSION,
294+
py_version=mxnet_full_py_version,
293295
train_instance_count=1,
294296
train_instance_type=cpu_instance_type,
295297
sagemaker_session=sagemaker_session,
@@ -328,7 +330,7 @@ def test_mxnet_with_custom_rule_and_debugger_hook_config(
328330

329331

330332
def test_mxnet_with_tensorboard_output_config(
331-
sagemaker_session, mxnet_full_version, cpu_instance_type
333+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
332334
):
333335
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
334336
tensorboard_output_config = TensorBoardOutputConfig(
@@ -344,7 +346,7 @@ def test_mxnet_with_tensorboard_output_config(
344346
entry_point=script_path,
345347
role="SageMakerRole",
346348
framework_version=mxnet_full_version,
347-
py_version=PYTHON_VERSION,
349+
py_version=mxnet_full_py_version,
348350
train_instance_count=1,
349351
train_instance_type=cpu_instance_type,
350352
sagemaker_session=sagemaker_session,
@@ -370,7 +372,9 @@ def test_mxnet_with_tensorboard_output_config(
370372

371373

372374
@pytest.mark.canary_quick
373-
def test_mxnet_with_all_rules_and_configs(sagemaker_session, mxnet_full_version, cpu_instance_type):
375+
def test_mxnet_with_all_rules_and_configs(
376+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
377+
):
374378
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
375379
rules = [
376380
Rule.sagemaker(rule_configs.vanishing_gradient()),
@@ -398,7 +402,7 @@ def test_mxnet_with_all_rules_and_configs(sagemaker_session, mxnet_full_version,
398402
entry_point=script_path,
399403
role="SageMakerRole",
400404
framework_version=mxnet_full_version,
401-
py_version=PYTHON_VERSION,
405+
py_version=mxnet_full_py_version,
402406
train_instance_count=1,
403407
train_instance_type=cpu_instance_type,
404408
sagemaker_session=sagemaker_session,
@@ -441,7 +445,7 @@ def test_mxnet_with_all_rules_and_configs(sagemaker_session, mxnet_full_version,
441445

442446

443447
def test_mxnet_with_debugger_hook_config_disabled(
444-
sagemaker_session, mxnet_full_version, cpu_instance_type
448+
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
445449
):
446450
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
447451
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
@@ -451,7 +455,7 @@ def test_mxnet_with_debugger_hook_config_disabled(
451455
entry_point=script_path,
452456
role="SageMakerRole",
453457
framework_version=mxnet_full_version,
454-
py_version=PYTHON_VERSION,
458+
py_version=mxnet_full_py_version,
455459
train_instance_count=1,
456460
train_instance_type=cpu_instance_type,
457461
sagemaker_session=sagemaker_session,

tests/integ/test_git.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.pytorch.estimator import PyTorch
2525
from sagemaker.sklearn.estimator import SKLearn
2626
from sagemaker.sklearn.model import SKLearnModel
27-
from tests.integ import DATA_DIR, PYTHON_VERSION
27+
from tests.integ import DATA_DIR
2828

2929

3030
GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git"
@@ -81,7 +81,7 @@ def test_github(sagemaker_local_session, pytorch_full_version, pytorch_full_py_v
8181

8282
@pytest.mark.local_mode
8383
@pytest.mark.skip("needs a secure authentication approach")
84-
def test_private_github(sagemaker_local_session, mxnet_full_version):
84+
def test_private_github(sagemaker_local_session, mxnet_full_version, mxnet_full_py_version):
8585
script_path = "mnist.py"
8686
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
8787
git_config = {
@@ -100,7 +100,7 @@ def test_private_github(sagemaker_local_session, mxnet_full_version):
100100
source_dir=source_dir,
101101
dependencies=dependencies,
102102
framework_version=mxnet_full_version,
103-
py_version=PYTHON_VERSION,
103+
py_version=mxnet_full_py_version,
104104
train_instance_count=1,
105105
train_instance_type="local",
106106
sagemaker_session=sagemaker_local_session,
@@ -219,7 +219,7 @@ def test_github_with_ssh_passphrase_not_configured(sagemaker_local_session, skle
219219

220220
@pytest.mark.local_mode
221221
@pytest.mark.skip("needs a secure authentication approach")
222-
def test_codecommit(sagemaker_local_session, mxnet_full_version):
222+
def test_codecommit(sagemaker_local_session, mxnet_full_version, mxnet_full_py_version):
223223
script_path = "mnist.py"
224224
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
225225
git_config = {
@@ -236,7 +236,7 @@ def test_codecommit(sagemaker_local_session, mxnet_full_version):
236236
source_dir=source_dir,
237237
dependencies=dependencies,
238238
framework_version=mxnet_full_version,
239-
py_version=PYTHON_VERSION,
239+
py_version=mxnet_full_py_version,
240240
train_instance_count=1,
241241
train_instance_type="local",
242242
sagemaker_session=sagemaker_local_session,

tests/integ/test_local_mode.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import stopit
2424

2525
import tests.integ.lock as lock
26-
from tests.integ import DATA_DIR, PYTHON_VERSION
26+
from tests.integ import DATA_DIR
2727

2828
from sagemaker.local import LocalSession, LocalSagemakerRuntimeClient, LocalSagemakerClient
2929
from sagemaker.mxnet import MXNet
@@ -54,7 +54,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
5454

5555

5656
@pytest.fixture(scope="module")
57-
def mxnet_model(sagemaker_local_session, mxnet_full_version):
57+
def mxnet_model(sagemaker_local_session, mxnet_full_version, mxnet_full_py_version):
5858
def _create_model(output_path):
5959
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
6060
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
@@ -66,7 +66,7 @@ def _create_model(output_path):
6666
train_instance_type="local",
6767
output_path=output_path,
6868
framework_version=mxnet_full_version,
69-
py_version=PYTHON_VERSION,
69+
py_version=mxnet_full_py_version,
7070
sagemaker_session=sagemaker_local_session,
7171
)
7272

@@ -85,7 +85,9 @@ def _create_model(output_path):
8585

8686

8787
@pytest.mark.local_mode
88-
def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model, mxnet_full_version):
88+
def test_local_mode_serving_from_s3_model(
89+
sagemaker_local_session, mxnet_model, mxnet_full_version, mxnet_full_py_version
90+
):
8991
path = "s3://%s" % sagemaker_local_session.default_bucket()
9092
s3_model = mxnet_model(path)
9193
s3_model.sagemaker_session = sagemaker_local_session
@@ -119,14 +121,14 @@ def test_local_mode_serving_from_local_model(tmpdir, sagemaker_local_session, mx
119121

120122

121123
@pytest.mark.local_mode
122-
def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version):
124+
def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version, mxnet_full_py_version):
123125
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
124126
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
125127

126128
mx = MXNet(
127129
entry_point=script_path,
128130
role="SageMakerRole",
129-
py_version=PYTHON_VERSION,
131+
py_version=mxnet_full_py_version,
130132
train_instance_count=1,
131133
train_instance_type="local",
132134
sagemaker_session=sagemaker_local_session,
@@ -153,14 +155,16 @@ def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version):
153155

154156

155157
@pytest.mark.local_mode
156-
def test_mxnet_distributed_local_mode(sagemaker_local_session, mxnet_full_version):
158+
def test_mxnet_distributed_local_mode(
159+
sagemaker_local_session, mxnet_full_version, mxnet_full_py_version
160+
):
157161
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
158162
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
159163

160164
mx = MXNet(
161165
entry_point=script_path,
162166
role="SageMakerRole",
163-
py_version=PYTHON_VERSION,
167+
py_version=mxnet_full_py_version,
164168
train_instance_count=2,
165169
train_instance_type="local",
166170
sagemaker_session=sagemaker_local_session,
@@ -179,7 +183,7 @@ def test_mxnet_distributed_local_mode(sagemaker_local_session, mxnet_full_versio
179183

180184

181185
@pytest.mark.local_mode
182-
def test_mxnet_local_data_local_script(mxnet_full_version):
186+
def test_mxnet_local_data_local_script(mxnet_full_version, mxnet_full_py_version):
183187
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
184188
script_path = os.path.join(data_path, "mnist.py")
185189

@@ -189,7 +193,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
189193
train_instance_count=1,
190194
train_instance_type="local",
191195
framework_version=mxnet_full_version,
192-
py_version=PYTHON_VERSION,
196+
py_version=mxnet_full_py_version,
193197
sagemaker_session=LocalNoS3Session(),
194198
)
195199

@@ -209,14 +213,16 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
209213

210214

211215
@pytest.mark.local_mode
212-
def test_mxnet_training_failure(sagemaker_local_session, mxnet_full_version, tmpdir):
216+
def test_mxnet_training_failure(
217+
sagemaker_local_session, mxnet_full_version, mxnet_full_py_version, tmpdir
218+
):
213219
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "failure_script.py")
214220

215221
mx = MXNet(
216222
entry_point=script_path,
217223
role="SageMakerRole",
218224
framework_version=mxnet_full_version,
219-
py_version=PYTHON_VERSION,
225+
py_version=mxnet_full_py_version,
220226
train_instance_count=1,
221227
train_instance_type="local",
222228
sagemaker_session=sagemaker_local_session,
@@ -233,7 +239,7 @@ def test_mxnet_training_failure(sagemaker_local_session, mxnet_full_version, tmp
233239

234240
@pytest.mark.local_mode
235241
def test_local_transform_mxnet(
236-
sagemaker_local_session, tmpdir, mxnet_full_version, cpu_instance_type
242+
sagemaker_local_session, tmpdir, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
237243
):
238244
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
239245
script_path = os.path.join(data_path, "mnist.py")
@@ -244,7 +250,7 @@ def test_local_transform_mxnet(
244250
train_instance_count=1,
245251
train_instance_type="local",
246252
framework_version=mxnet_full_version,
247-
py_version=PYTHON_VERSION,
253+
py_version=mxnet_full_py_version,
248254
sagemaker_session=sagemaker_local_session,
249255
)
250256

0 commit comments

Comments
 (0)