diff --git a/tests/conftest.py b/tests/conftest.py
index d617c48d77..0be798471a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -133,6 +133,11 @@ def chainer_version(request):
     return request.param
 
 
+@pytest.fixture(scope="module", params=["py2", "py3"])
+def chainer_py_version(request):
+    return request.param
+
+
 # TODO: current version fixtures are legacy fixtures that aren't useful
 # and no longer verify whether images are valid
 @pytest.fixture(
diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py
index 7cbb966533..82fe36738b 100644
--- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py
+++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py
@@ -12,8 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import sys
-
 import pasta
 import pytest
 
@@ -106,13 +104,6 @@ def constructors(versions=False, image=False):
     return [ctr for template in TEMPLATES for ctr in template.constructors(versions, image)]
 
 
-@pytest.fixture(autouse=True)
-def skip_if_py2():
-    # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
-    if sys.version_info.major < 3:
-        pytest.skip("v2 migration script doesn't support Python 2.")
-
-
 @pytest.fixture
 def constructors_empty():
     return constructors()
diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py
index affb2940c4..6b049a510a 100644
--- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py
+++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py
@@ -12,10 +12,7 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import sys
-
 import pasta
-import pytest
 from mock import MagicMock, patch
 
 from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
@@ -25,13 +22,6 @@
 REGION_NAME = "us-west-2"
 
 
-@pytest.fixture(autouse=True)
-def skip_if_py2():
-    # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
-    if sys.version_info.major < 3:
-        pytest.skip("v2 migration script doesn't support Python 2.")
-
-
 def test_node_should_be_modified_tf_constructor_legacy_mode():
     tf_legacy_mode_constructors = (
         "TensorFlow(script_mode=False)",
diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py
index 169aef555d..a6e50e224a 100644
--- a/tests/unit/test_chainer.py
+++ b/tests/unit/test_chainer.py
@@ -15,18 +15,16 @@
 import logging
 import json
 import os
-import pytest
-import sys
 from distutils.util import strtobool
+
+import pytest
 from mock import MagicMock, Mock
 from mock import patch
 
-
 from sagemaker.chainer import defaults
 from sagemaker.chainer import Chainer
 from sagemaker.chainer import ChainerPredictor, ChainerModel
 
-
 DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
 SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
 SERVING_SCRIPT_FILE = "another_dummy_script.py"
@@ -38,7 +36,6 @@
 INSTANCE_COUNT = 1
 INSTANCE_TYPE = "ml.c4.4xlarge"
 ACCELERATOR_TYPE = "ml.eia.medium"
-PYTHON_VERSION = "py" + str(sys.version_info.major)
 IMAGE_NAME = "sagemaker-chainer"
 JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP)
 IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}"
@@ -77,21 +74,22 @@ def sagemaker_session():
     return session
 
 
-def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION):
+def _get_full_cpu_image_uri(version, py_version):
     return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "cpu", py_version)
 
 
-def _get_full_gpu_image_uri(version, py_version=PYTHON_VERSION):
+def _get_full_gpu_image_uri(version, py_version):
     return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "gpu", py_version)
 
 
-def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
+def _get_full_cpu_image_uri_with_ei(version, py_version):
     return _get_full_cpu_image_uri(version, py_version=py_version) + "-eia"
 
 
 def _chainer_estimator(
     sagemaker_session,
     framework_version,
+    py_version,
     train_instance_type=None,
     base_job_name=None,
     use_mpi=None,
@@ -103,6 +101,7 @@ def _chainer_estimator(
     return Chainer(
         entry_point=SCRIPT_PATH,
         framework_version=framework_version,
+        py_version=py_version,
         role=ROLE,
         sagemaker_session=sagemaker_session,
         train_instance_count=INSTANCE_COUNT,
@@ -112,14 +111,13 @@ def _chainer_estimator(
         num_processes=num_processes,
         process_slots_per_host=process_slots_per_host,
         additional_mpi_options=additional_mpi_options,
-        py_version=PYTHON_VERSION,
         **kwargs
     )
 
 
-def _create_train_job(version):
+def _create_train_job(version, py_version):
     return {
-        "image": _get_full_cpu_image_uri(version),
+        "image": _get_full_cpu_image_uri(version, py_version),
         "input_mode": "File",
         "input_config": [
             {
@@ -162,47 +160,7 @@ def _create_train_job(version):
     }
 
 
-def _create_train_job_with_additional_hyperparameters(version):
-    return {
-        "image": _get_full_cpu_image_uri(version),
-        "input_mode": "File",
-        "input_config": [
-            {
-                "ChannelName": "training",
-                "DataSource": {
-                    "S3DataSource": {
-                        "S3DataDistributionType": "FullyReplicated",
-                        "S3DataType": "S3Prefix",
-                    }
-                },
-            }
-        ],
-        "role": ROLE,
-        "job_name": JOB_NAME,
-        "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
-        "resource_config": {
-            "InstanceType": "ml.c4.4xlarge",
-            "InstanceCount": 1,
-            "VolumeSizeInGB": 30,
-        },
-        "hyperparameters": {
-            "sagemaker_program": json.dumps("dummy_script.py"),
-            "sagemaker_container_log_level": str(logging.INFO),
-            "sagemaker_job_name": json.dumps(JOB_NAME),
-            "sagemaker_submit_directory": json.dumps(
"s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) - ), - "sagemaker_region": '"us-west-2"', - "sagemaker_num_processes": "4", - "sagemaker_additional_mpi_options": '"-x MY_ENVIRONMENT_VARIABLE"', - "sagemaker_process_slots_per_host": "10", - "sagemaker_use_mpi": "true", - }, - "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, - } - - -def test_additional_hyperparameters(sagemaker_session, chainer_version): +def test_additional_hyperparameters(sagemaker_session, chainer_version, chainer_py_version): chainer = _chainer_estimator( sagemaker_session, use_mpi=True, @@ -210,6 +168,7 @@ def test_additional_hyperparameters(sagemaker_session, chainer_version): process_slots_per_host=10, additional_mpi_options="-x MY_ENVIRONMENT_VARIABLE", framework_version=chainer_version, + py_version=chainer_py_version, ) assert bool(strtobool(chainer.hyperparameters()["sagemaker_use_mpi"])) assert int(chainer.hyperparameters()["sagemaker_num_processes"]) == 4 @@ -220,9 +179,11 @@ def test_additional_hyperparameters(sagemaker_session, chainer_version): ) -def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_version): +def test_attach_with_additional_hyperparameters( + sagemaker_session, chainer_version, chainer_py_version +): training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}".format( - chainer_version, PYTHON_VERSION + chainer_version, chainer_py_version ) returned_job_description = { "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, @@ -270,7 +231,7 @@ def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_versi assert estimator.additional_mpi_options == "-x MY_ENVIRONMENT_VARIABLE" -def test_create_model(sagemaker_session, chainer_version): +def test_create_model(sagemaker_session, chainer_version, chainer_py_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" chainer = Chainer( @@ -281,7 +242,7 @@ def test_create_model(sagemaker_session, chainer_version): train_instance_type=INSTANCE_TYPE, framework_version=chainer_version, container_log_level=container_log_level, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, base_job_name="job", source_dir=source_dir, ) @@ -301,7 +262,7 @@ def test_create_model(sagemaker_session, chainer_version): assert model.vpc_config is None -def test_create_model_with_optional_params(sagemaker_session, chainer_version): +def test_create_model_with_optional_params(sagemaker_session, chainer_version, chainer_py_version): container_log_level = '"logging.INFO"' source_dir = "s3://mybucket/source" enable_cloudwatch_metrics = "true" @@ -313,7 +274,7 @@ def test_create_model_with_optional_params(sagemaker_session, chainer_version): train_instance_type=INSTANCE_TYPE, container_log_level=container_log_level, framework_version=chainer_version, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, base_job_name="job", source_dir=source_dir, enable_cloudwatch_metrics=enable_cloudwatch_metrics, @@ -354,7 +315,6 @@ def test_create_model_with_custom_image(sagemaker_session): train_instance_type=INSTANCE_TYPE, image_name=custom_image, container_log_level=container_log_level, - py_version=PYTHON_VERSION, base_job_name="job", source_dir=source_dir, ) @@ -367,7 +327,7 @@ def test_create_model_with_custom_image(sagemaker_session): @patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("time.strftime", return_value=TIMESTAMP) -def test_chainer(strftime, sagemaker_session, chainer_version): +def 
test_chainer(strftime, sagemaker_session, chainer_version, chainer_py_version): chainer = Chainer( entry_point=SCRIPT_PATH, role=ROLE, @@ -375,7 +335,7 @@ def test_chainer(strftime, sagemaker_session, chainer_version): train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, framework_version=chainer_version, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, ) inputs = "s3://mybucket/train" @@ -387,7 +347,7 @@ def test_chainer(strftime, sagemaker_session, chainer_version): boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] assert boto_call_names == ["resource"] - expected_train_args = _create_train_job(chainer_version) + expected_train_args = _create_train_job(chainer_version, chainer_py_version) expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs actual_train_args = sagemaker_session.method_calls[0][2] @@ -406,7 +366,7 @@ def test_chainer(strftime, sagemaker_session, chainer_version): "SAGEMAKER_REGION": "us-west-2", "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - "Image": expected_image_base.format(chainer_version, PYTHON_VERSION), + "Image": expected_image_base.format(chainer_version, chainer_py_version), "ModelDataUrl": "s3://m/m.tar.gz", } == model.prepare_container_def(GPU) @@ -416,40 +376,42 @@ def test_chainer(strftime, sagemaker_session, chainer_version): @patch("sagemaker.utils.create_tar_file", MagicMock()) -def test_model(sagemaker_session, chainer_version): +def test_model(sagemaker_session, chainer_version, chainer_py_version): model = ChainerModel( "s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session, framework_version=chainer_version, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, ) predictor = model.deploy(1, GPU) assert isinstance(predictor, ChainerPredictor) @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) -def test_model_prepare_container_def_accelerator_error(sagemaker_session, chainer_version): +def test_model_prepare_container_def_accelerator_error( + sagemaker_session, chainer_version, chainer_py_version +): model = ChainerModel( MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session, framework_version=chainer_version, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, ) with pytest.raises(ValueError): model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) -def test_model_prepare_container_def_no_instance_type_or_image(chainer_version): +def test_model_prepare_container_def_no_instance_type_or_image(chainer_version, chainer_py_version): model = ChainerModel( MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, framework_version=chainer_version, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, ) with pytest.raises(ValueError) as e: @@ -459,7 +421,7 @@ def test_model_prepare_container_def_no_instance_type_or_image(chainer_version): assert expected_msg in str(e) -def test_train_image_default(sagemaker_session, chainer_version): +def test_train_image_default(sagemaker_session, chainer_version, chainer_py_version): chainer = Chainer( entry_point=SCRIPT_PATH, role=ROLE, @@ -467,44 +429,59 @@ def test_train_image_default(sagemaker_session, chainer_version): train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, framework_version=chainer_version, - py_version=PYTHON_VERSION, + py_version=chainer_py_version, ) - assert _get_full_cpu_image_uri(chainer_version) in chainer.train_image() + assert _get_full_cpu_image_uri(chainer_version, 
chainer_py_version) == chainer.train_image() -def test_train_image_cpu_instances(sagemaker_session, chainer_version): +def test_train_image_cpu_instances(sagemaker_session, chainer_version, chainer_py_version): chainer = _chainer_estimator( - sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c2.2xlarge" + sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + train_instance_type="ml.c2.2xlarge", ) - assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version) + assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version, chainer_py_version) chainer = _chainer_estimator( - sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c4.2xlarge" + sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + train_instance_type="ml.c4.2xlarge", ) - assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version) + assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version, chainer_py_version) chainer = _chainer_estimator( - sagemaker_session, framework_version=chainer_version, train_instance_type="ml.m16" + sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + train_instance_type="ml.m16", ) - assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version) + assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version, chainer_py_version) -def test_train_image_gpu_instances(sagemaker_session, chainer_version): +def test_train_image_gpu_instances(sagemaker_session, chainer_version, chainer_py_version): chainer = _chainer_estimator( - sagemaker_session, framework_version=chainer_version, train_instance_type="ml.g2.2xlarge" + sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + train_instance_type="ml.g2.2xlarge", ) - assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version) + assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version, chainer_py_version) chainer = _chainer_estimator( - sagemaker_session, framework_version=chainer_version, train_instance_type="ml.p2.2xlarge" + sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + train_instance_type="ml.p2.2xlarge", ) - assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version) + assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version, chainer_py_version) -def test_attach(sagemaker_session, chainer_version): +def test_attach(sagemaker_session, chainer_version, chainer_py_version): training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-chainer:{}-cpu-{}".format( - chainer_version, PYTHON_VERSION + chainer_version, chainer_py_version ) returned_job_description = { "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, @@ -537,7 +514,7 @@ def test_attach(sagemaker_session, chainer_version): estimator = Chainer.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.latest_training_job.job_name == "neo" - assert estimator.py_version == PYTHON_VERSION + assert estimator.py_version == chainer_py_version assert estimator.framework_version == chainer_version assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 @@ -652,11 +629,3 @@ def test_model_py2_warning(warning, sagemaker_session, chainer_version): ) assert model.py_version == "py2" warning.assert_called_with(model.__framework_name__, 
defaults.LATEST_PY2_VERSION) - - -def test_custom_image_estimator_deploy(sagemaker_session, chainer_version): - custom_image = "mycustomimage:latest" - chainer = _chainer_estimator(sagemaker_session, framework_version=chainer_version) - chainer.fit(inputs="s3://mybucket/train", job_name="new_name") - model = chainer.create_model(image=custom_image) - assert model.image == custom_image diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5239aa0d8f..e3356587cf 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -16,8 +16,8 @@ import json import os import pytest -import sys from mock import ANY, MagicMock, Mock, patch +from packaging.version import Version from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch @@ -35,7 +35,6 @@ INSTANCE_COUNT = 1 INSTANCE_TYPE = "ml.c4.4xlarge" ACCELERATOR_TYPE = "ml.eia.medium" -PYTHON_VERSION = "py" + str(sys.version_info.major) IMAGE_NAME = "sagemaker-pytorch" JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP) IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}" @@ -80,25 +79,30 @@ def fixture_sagemaker_session(): return session -def _get_full_cpu_image_uri(version, py_version=PYTHON_VERSION): +def _get_full_cpu_image_uri(version, py_version): return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "cpu", py_version) -def _get_full_gpu_image_uri(version, py_version=PYTHON_VERSION): +def _get_full_gpu_image_uri(version, py_version): return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_NAME, version, "gpu", py_version) -def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION): +def _get_full_cpu_image_uri_with_ei(version, py_version): return _get_full_cpu_image_uri(version, py_version=py_version) + "-eia" def _pytorch_estimator( - sagemaker_session, framework_version, train_instance_type=None, base_job_name=None, **kwargs + sagemaker_session, + framework_version, + py_version, + train_instance_type=None, + base_job_name=None, + **kwargs ): return PyTorch( entry_point=SCRIPT_PATH, framework_version=framework_version, - py_version=PYTHON_VERSION, + py_version=py_version, role=ROLE, sagemaker_session=sagemaker_session, train_instance_count=INSTANCE_COUNT, @@ -108,9 +112,9 @@ def _pytorch_estimator( ) -def _create_train_job(version): +def _create_train_job(version, py_version): return { - "image": _get_full_cpu_image_uri(version), + "image": _get_full_cpu_image_uri(version, py_version), "input_mode": "File", "input_config": [ { @@ -256,7 +260,7 @@ def test_create_model_with_custom_image(sagemaker_session): @patch("sagemaker.utils.create_tar_file", MagicMock()) @patch("time.strftime", return_value=TIMESTAMP) -def test_pytorch(strftime, sagemaker_session, pytorch_version): +def test_pytorch(strftime, sagemaker_session, pytorch_version, pytorch_py_version): pytorch = PyTorch( entry_point=SCRIPT_PATH, role=ROLE, @@ -264,7 +268,7 @@ def test_pytorch(strftime, sagemaker_session, pytorch_version): train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, framework_version=pytorch_version, - py_version=PYTHON_VERSION, + py_version=pytorch_py_version, ) inputs = "s3://mybucket/train" @@ -276,7 +280,7 @@ def test_pytorch(strftime, sagemaker_session, pytorch_version): boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] assert boto_call_names == ["resource"] - expected_train_args = _create_train_job(pytorch_version) + expected_train_args = _create_train_job(pytorch_version, pytorch_py_version) 
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["experiment_config"] = EXPERIMENT_CONFIG @@ -296,7 +300,7 @@ def test_pytorch(strftime, sagemaker_session, pytorch_version): "SAGEMAKER_REGION": "us-west-2", "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", }, - "Image": expected_image_base.format(pytorch_version, PYTHON_VERSION), + "Image": expected_image_base.format(pytorch_version, pytorch_py_version), "ModelDataUrl": "s3://m/m.tar.gz", } == model.prepare_container_def(GPU) @@ -328,7 +332,7 @@ def test_mms_model(repack_model, sagemaker_session): entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session, framework_version="1.2", - py_version=PYTHON_VERSION, + py_version="py3", ).deploy(1, GPU) repack_model.assert_called_with( @@ -351,7 +355,7 @@ def test_non_mms_model(repack_model, sagemaker_session): entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session, framework_version="1.1", - py_version=PYTHON_VERSION, + py_version="py3", ).deploy(1, GPU) repack_model.assert_not_called() @@ -400,36 +404,38 @@ def test_train_image_default(sagemaker_session, pytorch_version, pytorch_py_vers assert _get_full_cpu_image_uri(pytorch_version, pytorch_py_version) in pytorch.train_image() -def test_train_image_cpu_instances(sagemaker_session, pytorch_version): +def test_train_image_cpu_instances(sagemaker_session, pytorch_version, pytorch_py_version): pytorch = _pytorch_estimator( - sagemaker_session, pytorch_version, train_instance_type="ml.c2.2xlarge" + sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.c2.2xlarge" ) - assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version) + assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version, pytorch_py_version) pytorch = _pytorch_estimator( - sagemaker_session, pytorch_version, train_instance_type="ml.c4.2xlarge" + sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.c4.2xlarge" ) - assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version) + assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version, pytorch_py_version) - pytorch = _pytorch_estimator(sagemaker_session, pytorch_version, train_instance_type="ml.m16") - assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version) + pytorch = _pytorch_estimator( + sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.m16" + ) + assert pytorch.train_image() == _get_full_cpu_image_uri(pytorch_version, pytorch_py_version) -def test_train_image_gpu_instances(sagemaker_session, pytorch_version): +def test_train_image_gpu_instances(sagemaker_session, pytorch_version, pytorch_py_version): pytorch = _pytorch_estimator( - sagemaker_session, pytorch_version, train_instance_type="ml.g2.2xlarge" + sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.g2.2xlarge" ) - assert pytorch.train_image() == _get_full_gpu_image_uri(pytorch_version) + assert pytorch.train_image() == _get_full_gpu_image_uri(pytorch_version, pytorch_py_version) pytorch = _pytorch_estimator( - sagemaker_session, pytorch_version, train_instance_type="ml.p2.2xlarge" + sagemaker_session, pytorch_version, pytorch_py_version, train_instance_type="ml.p2.2xlarge" ) - assert pytorch.train_image() == _get_full_gpu_image_uri(pytorch_version) + assert pytorch.train_image() == _get_full_gpu_image_uri(pytorch_version, pytorch_py_version) -def test_attach(sagemaker_session, pytorch_version): +def test_attach(sagemaker_session, pytorch_version, 
pytorch_py_version): training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-cpu-{}".format( - pytorch_version, PYTHON_VERSION + pytorch_version, pytorch_py_version ) returned_job_description = { "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, @@ -462,7 +468,7 @@ def test_attach(sagemaker_session, pytorch_version): estimator = PyTorch.attach(training_job_name="neo", sagemaker_session=sagemaker_session) assert estimator.latest_training_job.job_name == "neo" - assert estimator.py_version == PYTHON_VERSION + assert estimator.py_version == pytorch_py_version assert estimator.framework_version == pytorch_version assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" assert estimator.train_instance_count == 1 @@ -580,35 +586,41 @@ def test_model_py2_warning(warning, sagemaker_session, pytorch_version): warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION) -def test_pt_enable_sm_metrics(sagemaker_session, pytorch_full_version): +def test_pt_enable_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version): pytorch = _pytorch_estimator( - sagemaker_session, framework_version=pytorch_full_version, enable_sagemaker_metrics=True + sagemaker_session, + framework_version=pytorch_version, + py_version=pytorch_py_version, + enable_sagemaker_metrics=True, ) assert pytorch.enable_sagemaker_metrics -def test_pt_disable_sm_metrics(sagemaker_session, pytorch_full_version): +def test_pt_disable_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version): pytorch = _pytorch_estimator( - sagemaker_session, framework_version=pytorch_full_version, enable_sagemaker_metrics=False + sagemaker_session, + framework_version=pytorch_version, + py_version=pytorch_py_version, + enable_sagemaker_metrics=False, ) assert not pytorch.enable_sagemaker_metrics -def test_pt_disable_sm_metrics_if_pt_ver_is_less_than_1_15(sagemaker_session): - for fw_version in ["1.1", "1.2"]: - pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version) +def test_pt_default_sm_metrics(sagemaker_session, pytorch_version, pytorch_py_version): + pytorch = _pytorch_estimator( + sagemaker_session, framework_version=pytorch_version, py_version=pytorch_py_version + ) + if Version(pytorch_version) < Version("1.3"): assert pytorch.enable_sagemaker_metrics is None - - -def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session): - for fw_version in ["1.3", "1.4", "2.0", "2.1"]: - pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version) + else: assert pytorch.enable_sagemaker_metrics -def test_custom_image_estimator_deploy(sagemaker_session, pytorch_full_version): +def test_custom_image_estimator_deploy(sagemaker_session, pytorch_version, pytorch_py_version): custom_image = "mycustomimage:latest" - pytorch = _pytorch_estimator(sagemaker_session, framework_version=pytorch_full_version) + pytorch = _pytorch_estimator( + sagemaker_session, framework_version=pytorch_version, py_version=pytorch_py_version + ) pytorch.fit(inputs="s3://mybucket/train", job_name="new_name") model = pytorch.create_model(image=custom_image) assert model.image == custom_image