Skip to content

Commit 66178b4

Browse files
authored
breaking: require framework_version, py_version for chainer (#1588)
1 parent dbdaf50 commit 66178b4

File tree

7 files changed

+78
-82
lines changed

7 files changed

+78
-82
lines changed

doc/frameworks/chainer/using_chainer.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ directories ('train' and 'test').
141141
train_instance_type='ml.p3.2xlarge',
142142
train_instance_count=1,
143143
framework_version='5.0.0',
144+
py_version='py3',
144145
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
145146
chainer_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
146147
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -222,7 +223,8 @@ operation.
222223
chainer_estimator = Chainer(entry_point='train_and_deploy.py',
223224
train_instance_type='ml.p3.2xlarge',
224225
train_instance_count=1,
225-
framework_version='5.0.0')
226+
framework_version='5.0.0',
227+
py_version='py3')
226228
chainer_estimator.fit('s3://my_bucket/my_training_data/')
227229
228230
# Deploy my estimator to a SageMaker Endpoint and get a Predictor

src/sagemaker/chainer/estimator.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
2322
python_deprecation_warning,
23+
validate_version_or_image_args,
2424
)
2525
from sagemaker.chainer import defaults
2626
from sagemaker.chainer.model import ChainerModel
@@ -51,8 +51,8 @@ def __init__(
5151
additional_mpi_options=None,
5252
source_dir=None,
5353
hyperparameters=None,
54-
py_version="py3",
5554
framework_version=None,
55+
py_version=None,
5656
image_name=None,
5757
**kwargs
5858
):
@@ -103,11 +103,12 @@ def __init__(
103103
and values, but ``str()`` will be called to convert them before
104104
training.
105105
py_version (str): Python version you want to use for executing your
106-
model training code (default: 'py2'). One of 'py2' or 'py3'.
106+
model training code. Defaults to ``None``. Required unless ``image_name``
107+
is provided.
107108
framework_version (str): Chainer version you want to use for
108-
executing your model training code. List of supported versions
109+
executing your model training code. Defaults to ``None``. Required unless
110+
``image_name`` is provided. List of supported versions:
109111
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
110-
If not specified, this will default to 4.1.
111112
image_name (str): If specified, the estimator will use this image
112113
for training and hosting, instead of selecting the appropriate
113114
SageMaker official image based on framework_version and
@@ -117,6 +118,9 @@ def __init__(
117118
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
118119
* ``custom-image:latest``
119120
121+
If ``framework_version`` or ``py_version`` are ``None``, then
122+
``image_name`` is required. If also ``None``, then a ``ValueError``
123+
will be raised.
120124
**kwargs: Additional kwargs passed to the
121125
:class:`~sagemaker.estimator.Framework` constructor.
122126
@@ -126,22 +130,18 @@ def __init__(
126130
:class:`~sagemaker.estimator.Framework` and
127131
:class:`~sagemaker.estimator.EstimatorBase`.
128132
"""
129-
if framework_version is None:
133+
validate_version_or_image_args(framework_version, py_version, image_name)
134+
if py_version == "py2":
130135
logger.warning(
131-
empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION)
136+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
132137
)
133-
self.framework_version = framework_version or defaults.CHAINER_VERSION
138+
self.framework_version = framework_version
139+
self.py_version = py_version
134140

135141
super(Chainer, self).__init__(
136142
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
137143
)
138144

139-
if py_version == "py2":
140-
logger.warning(
141-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
142-
)
143-
144-
self.py_version = py_version
145145
self.use_mpi = use_mpi
146146
self.num_processes = num_processes
147147
self.process_slots_per_host = process_slots_per_host
@@ -262,15 +262,19 @@ class constructor
262262
image_name = init_params.pop("image")
263263
framework, py_version, tag, _ = framework_name_from_image(image_name)
264264

265+
if tag is None:
266+
framework_version = None
267+
else:
268+
framework_version = framework_version_from_tag(tag)
269+
init_params["framework_version"] = framework_version
270+
init_params["py_version"] = py_version
271+
265272
if not framework:
266273
# If we were unable to parse the framework name from the image it is not one of our
267274
# officially supported images, in this case just add the image to the init params.
268275
init_params["image_name"] = image_name
269276
return init_params
270277

271-
init_params["py_version"] = py_version
272-
init_params["framework_version"] = framework_version_from_tag(tag)
273-
274278
training_job_name = init_params["base_job_name"]
275279

276280
if framework != cls.__framework_name__:

src/sagemaker/chainer/model.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
create_image_uri,
2323
model_code_key_prefix,
2424
python_deprecation_warning,
25-
empty_framework_version_warning,
25+
validate_version_or_image_args,
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2828
from sagemaker.chainer import defaults
@@ -67,8 +67,8 @@ def __init__(
6767
role,
6868
entry_point,
6969
image=None,
70-
py_version="py3",
7170
framework_version=None,
71+
py_version=None,
7272
predictor_cls=ChainerPredictor,
7373
model_server_workers=None,
7474
**kwargs
@@ -88,11 +88,15 @@ def __init__(
8888
hosting. If ``source_dir`` is specified, then ``entry_point``
8989
must point to a file located at the root of ``source_dir``.
9090
image (str): A Docker image URI (default: None). If not specified, a
91-
default image for Chainer will be used.
92-
py_version (str): Python version you want to use for executing your
93-
model training code (default: 'py2').
91+
default image for Chainer will be used. If ``framework_version``
92+
or ``py_version`` are ``None``, then ``image`` is required. If
93+
also ``None``, then a ``ValueError`` will be raised.
9494
framework_version (str): Chainer version you want to use for
95-
executing your model training code.
95+
executing your model training code. Defaults to ``None``. Required
96+
unless ``image`` is provided.
97+
py_version (str): Python version you want to use for executing your
98+
model training code. Defaults to ``None``. Required unless
99+
``image`` is provided.
96100
predictor_cls (callable[str, sagemaker.session.Session]): A function
97101
to call to create a predictor with an endpoint name and
98102
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -109,21 +113,18 @@ def __init__(
109113
:class:`~sagemaker.model.FrameworkModel` and
110114
:class:`~sagemaker.model.Model`.
111115
"""
112-
super(ChainerModel, self).__init__(
113-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
114-
)
116+
validate_version_or_image_args(framework_version, py_version, image)
115117
if py_version == "py2":
116118
logger.warning(
117119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118120
)
121+
self.framework_version = framework_version
122+
self.py_version = py_version
119123

120-
if framework_version is None:
121-
logger.warning(
122-
empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
123-
)
124+
super(ChainerModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
124127

125-
self.py_version = py_version
126-
self.framework_version = framework_version or defaults.CHAINER_VERSION
127128
self.model_server_workers = model_server_workers
128129

129130
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/integ/test_chainer_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,16 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
100100

101101

102102
@pytest.mark.local_mode
103-
def test_deploy_model(chainer_local_training_job, sagemaker_local_session):
103+
def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chainer_full_version):
104104
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
105105

106106
model = ChainerModel(
107107
chainer_local_training_job.model_data,
108108
"SageMakerRole",
109109
entry_point=script_path,
110110
sagemaker_session=sagemaker_local_session,
111+
framework_version=chainer_full_version,
112+
py_version=PYTHON_VERSION,
111113
)
112114

113115
predictor = model.deploy(1, "local")

tests/integ/test_tuner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,14 +689,15 @@ def test_tuning_tf_vpc_multi(sagemaker_session, cpu_instance_type):
689689

690690

691691
@pytest.mark.canary_quick
692-
def test_tuning_chainer(sagemaker_session, cpu_instance_type):
692+
def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_type):
693693
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
694694
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
695695
data_path = os.path.join(DATA_DIR, "chainer_mnist")
696696

697697
estimator = Chainer(
698698
entry_point=script_path,
699699
role="SageMakerRole",
700+
framework_version=chainer_full_version,
700701
py_version=PYTHON_VERSION,
701702
train_instance_count=1,
702703
train_instance_type=cpu_instance_type,

0 commit comments

Comments
 (0)