Skip to content

Commit 13c9d66

Browse files
committed
breaking: addressing comments from PR #1559
1 parent 00b4b48 commit 13c9d66

File tree

5 files changed

+34
-26
lines changed

5 files changed

+34
-26
lines changed

src/sagemaker/fw_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -686,17 +686,19 @@ def _region_supports_debugger(region_name):
686686
def validate_version_or_image_args(framework_version, py_version, image_name):
687687
"""Checks if version or image arguments are specified.
688688
689-
Used to validate framework and model arguments to enforce version or image specification.
690-
Raises ValueError if version or image arguments are not specified.
689+
Validates framework and model arguments to enforce version or image specification.
691690
692691
Args:
693-
framework_version (str): the version of the framework
694-
py_version (str): the version of python
695-
image_name (str): the uri of the image
692+
framework_version (str): The version of the framework.
693+
py_version (str): The version of python.
694+
image_name (str): The URI of the image.
695+
696+
Raises:
697+
ValueError: if `image_name` is None and either `framework_version` or `py_version` is
698+
None.
696699
"""
697700
if (framework_version is None or py_version is None) and image_name is None:
698701
raise ValueError(
699702
"framework_version or py_version was None, yet image_name was also None. "
700703
"Either specify both framework_version and py_version, or specify image_name."
701704
)
702-
return True

src/sagemaker/mxnet/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@ def __init__(
7474
If ``source_dir`` is specified, then ``entry_point``
7575
must point to a file located at the root of ``source_dir``.
7676
framework_version (str): MXNet version you want to use for executing
77-
your model training code. List of supported versions. Defaults to ``None``.
77+
your model training code. Defaults to `None`. Required unless
78+
``image_name`` is provided. List of supported versions.
7879
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
7980
py_version (str): Python version you want to use for executing your
80-
model training code. One of 'py2' or 'py3'. Defaults to ``None``.
81+
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
82+
unless ``image_name`` is provided.
8183
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
8284
with any other training source code dependencies aside from the entry
8385
point file (default: None). If ``source_dir`` is an S3 URI, it must

src/sagemaker/mxnet/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,11 @@ def __init__(
8787
hosting. If ``source_dir`` is specified, then ``entry_point``
8888
must point to a file located at the root of ``source_dir``.
8989
framework_version (str): MXNet version you want to use for executing
90-
your model training code. Defaults to ``None``.
90+
your model training code. Defaults to ``None``. Required unless
91+
``image_name`` is provided.
9192
py_version (str): Python version you want to use for executing your
92-
model training code. Defaults to ``None``.
93+
model training code. Defaults to ``None``. Required unless
94+
``image_name`` is provided.
9395
image (str): A Docker image URI (default: None). If not specified, a
9496
default image for MXNet will be used.
9597
@@ -112,7 +114,6 @@ def __init__(
112114
:class:`~sagemaker.model.FrameworkModel` and
113115
:class:`~sagemaker.model.Model`.
114116
"""
115-
# TODO: rename/unify image attribute to match across code base
116117
validate_version_or_image_args(framework_version, py_version, image)
117118
if py_version and py_version == "py2":
118119
logger.warning(

tests/unit/test_fw_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,13 +1275,16 @@ def test_warn_if_parameter_server_with_multi_gpu(caplog):
12751275

12761276

12771277
def test_validate_version_or_image_args():
1278-
for good_args in [("", "", None), (None, "", ""), ("", None, "")]:
1279-
kwargs = dict(zip(("framework_version", "py_version", "image_name"), good_args))
1280-
assert fw_utils.validate_version_or_image_args(**kwargs)
1278+
good_args = [("1.0", "py3", None), (None, "py3", "my:uri"), ("1.0", None, "my:uri")]
1279+
for framework_version, py_version, image_name in good_args:
1280+
assert (
1281+
fw_utils.validate_version_or_image_args(framework_version, py_version, image_name)
1282+
is None
1283+
)
12811284

12821285

12831286
def test_validate_version_or_image_args_raises():
1284-
for bad_args in [(None, None, None), (None, "", None), ("", None, None)]:
1285-
kwargs = dict(zip(("framework_version", "py_version", "image_name"), bad_args))
1287+
bad_args = [(None, None, None), (None, "py3", None), ("1.0", None, None)]
1288+
for framework_version, py_version, image_name in bad_args:
12861289
with pytest.raises(ValueError):
1287-
fw_utils.validate_version_or_image_args(**kwargs)
1290+
fw_utils.validate_version_or_image_args(framework_version, py_version, image_name)

tests/unit/test_mxnet.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
INSTANCE_COUNT = 1
3838
INSTANCE_TYPE = "ml.c4.4xlarge"
3939
ACCELERATOR_TYPE = "ml.eia.medium"
40-
IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.6.0-cpu-py3"
40+
IMAGE = "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.0-cpu-py3"
4141
COMPILATION_JOB_NAME = "{}-{}".format("compilation-sagemaker-mxnet", TIMESTAMP)
4242
FRAMEWORK = "mxnet"
4343
ROLE = "Dummy"
@@ -85,25 +85,25 @@ def sagemaker_session():
8585
return session
8686

8787

88-
def is_mms_version(mxnet_version):
88+
def _is_mms_version(mxnet_version):
8989
return parse_version(MXNetModel._LOWEST_MMS_VERSION) <= parse_version(mxnet_version)
9090

9191

9292
@pytest.fixture()
9393
def skip_if_mms_version(mxnet_version):
94-
if is_mms_version(mxnet_version):
94+
if _is_mms_version(mxnet_version):
9595
pytest.skip("Skipping because this version uses MMS")
9696

9797

9898
@pytest.fixture()
9999
def skip_if_not_mms_version(mxnet_version):
100-
if not is_mms_version(mxnet_version):
100+
if not _is_mms_version(mxnet_version):
101101
pytest.skip("Skipping because this version does not use MMS")
102102

103103

104-
def _get_train_args(job_name, image_name):
104+
def _get_train_args(job_name):
105105
return {
106-
"image": image_name,
106+
"image": IMAGE,
107107
"input_mode": "File",
108108
"input_config": [
109109
{
@@ -321,7 +321,7 @@ def test_mxnet(
321321

322322
actual_train_args = sagemaker_session.method_calls[0][2]
323323
job_name = actual_train_args["job_name"]
324-
expected_train_args = _get_train_args(job_name, IMAGE)
324+
expected_train_args = _get_train_args(job_name)
325325
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
326326
expected_train_args["experiment_config"] = EXPERIMENT_CONFIG
327327

@@ -338,7 +338,7 @@ def test_mxnet(
338338
assert "cpu" in model.prepare_container_def(CPU)["Image"]
339339
predictor = mx.deploy(1, GPU)
340340
assert isinstance(predictor, MXNetPredictor)
341-
assert is_mms_version(mxnet_version) ^ (create_tar_file.called and not repack_model.called)
341+
assert _is_mms_version(mxnet_version) ^ (create_tar_file.called and not repack_model.called)
342342

343343

344344
@patch("sagemaker.utils.create_tar_file", MagicMock())
@@ -464,7 +464,7 @@ def test_model_image_accelerator(
464464
)
465465
container_def = model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
466466
assert container_def["Image"] == IMAGE
467-
assert is_mms_version(mxnet_version) ^ (tar_and_upload.called and not repack_model.called)
467+
assert _is_mms_version(mxnet_version) ^ (tar_and_upload.called and not repack_model.called)
468468

469469

470470
def test_attach(sagemaker_session, mxnet_version, mxnet_py_version):

0 commit comments

Comments
 (0)