Skip to content

breaking: require framework_version, py_version for chainer #1588

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/frameworks/chainer/using_chainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ directories ('train' and 'test').
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='5.0.0',
py_version='py3',
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
chainer_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
'test': 's3://my-data-bucket/path/to/my/test/data'})
Expand Down Expand Up @@ -222,7 +223,8 @@ operation.
chainer_estimator = Chainer(entry_point='train_and_deploy.py',
train_instance_type='ml.p3.2xlarge',
train_instance_count=1,
framework_version='5.0.0')
framework_version='5.0.0',
py_version='py3')
chainer_estimator.fit('s3://my_bucket/my_training_data/')

# Deploy my estimator to a SageMaker Endpoint and get a Predictor
Expand Down
38 changes: 21 additions & 17 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
empty_framework_version_warning,
python_deprecation_warning,
validate_version_or_image_args,
)
from sagemaker.chainer import defaults
from sagemaker.chainer.model import ChainerModel
Expand Down Expand Up @@ -51,8 +51,8 @@ def __init__(
additional_mpi_options=None,
source_dir=None,
hyperparameters=None,
py_version="py3",
framework_version=None,
py_version=None,
image_name=None,
**kwargs
):
Expand Down Expand Up @@ -103,11 +103,12 @@ def __init__(
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code (default: 'py2'). One of 'py2' or 'py3'.
model training code. Defaults to ``None``. Required unless ``image_name``
is provided.
framework_version (str): Chainer version you want to use for
executing your model training code. List of supported versions
executing your model training code. Defaults to ``None``. Required unless
``image_name`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
If not specified, this will default to 4.1.
image_name (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
Expand All @@ -117,6 +118,9 @@ def __init__(
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` is ``None``, then
``image_name`` is required. If ``image_name`` is also ``None``,
then a ``ValueError`` will be raised.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.

Expand All @@ -126,22 +130,18 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
validate_version_or_image_args(framework_version, py_version, image_name)
if py_version == "py2":
logger.warning(
empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION)
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version or defaults.CHAINER_VERSION
self.framework_version = framework_version
self.py_version = py_version

super(Chainer, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
)

if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)

self.py_version = py_version
self.use_mpi = use_mpi
self.num_processes = num_processes
self.process_slots_per_host = process_slots_per_host
Expand Down Expand Up @@ -262,15 +262,19 @@ class constructor
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)

if tag is None:
framework_version = None
else:
framework_version = framework_version_from_tag(tag)
init_params["framework_version"] = framework_version
init_params["py_version"] = py_version

if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
return init_params

init_params["py_version"] = py_version
init_params["framework_version"] = framework_version_from_tag(tag)

training_job_name = init_params["base_job_name"]

if framework != cls.__framework_name__:
Expand Down
31 changes: 16 additions & 15 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
create_image_uri,
model_code_key_prefix,
python_deprecation_warning,
empty_framework_version_warning,
validate_version_or_image_args,
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.chainer import defaults
Expand Down Expand Up @@ -67,8 +67,8 @@ def __init__(
role,
entry_point,
image=None,
py_version="py3",
framework_version=None,
py_version=None,
predictor_cls=ChainerPredictor,
model_server_workers=None,
**kwargs
Expand All @@ -88,11 +88,15 @@ def __init__(
hosting. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
image (str): A Docker image URI (default: None). If not specified, a
default image for Chainer will be used.
py_version (str): Python version you want to use for executing your
model training code (default: 'py2').
Chainer image is selected based on ``framework_version`` and
``py_version``. If either of those is ``None``, then ``image`` is
required. If ``image`` is also ``None``, then a ``ValueError`` will be raised.
framework_version (str): Chainer version you want to use for
executing your model training code.
executing your model training code. Defaults to ``None``. Required
unless ``image`` is provided.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless
``image`` is provided.
predictor_cls (callable[str, sagemaker.session.Session]): A function
to call to create a predictor with an endpoint name and
SageMaker ``Session``. If specified, ``deploy()`` returns the
Expand All @@ -109,21 +113,18 @@ def __init__(
:class:`~sagemaker.model.FrameworkModel` and
:class:`~sagemaker.model.Model`.
"""
super(ChainerModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)
validate_version_or_image_args(framework_version, py_version, image)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
)
self.framework_version = framework_version
self.py_version = py_version

if framework_version is None:
logger.warning(
empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
)
super(ChainerModel, self).__init__(
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
)

self.py_version = py_version
self.framework_version = framework_version or defaults.CHAINER_VERSION
self.model_server_workers = model_server_workers

def prepare_container_def(self, instance_type, accelerator_type=None):
Expand Down
4 changes: 3 additions & 1 deletion tests/integ/test_chainer_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,16 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ


@pytest.mark.local_mode
def test_deploy_model(chainer_local_training_job, sagemaker_local_session):
def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chainer_full_version):
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")

model = ChainerModel(
chainer_local_training_job.model_data,
"SageMakerRole",
entry_point=script_path,
sagemaker_session=sagemaker_local_session,
framework_version=chainer_full_version,
py_version=PYTHON_VERSION,
)

predictor = model.deploy(1, "local")
Expand Down
3 changes: 2 additions & 1 deletion tests/integ/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,14 +689,15 @@ def test_tuning_tf_vpc_multi(sagemaker_session, cpu_instance_type):


@pytest.mark.canary_quick
def test_tuning_chainer(sagemaker_session, cpu_instance_type):
def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_type):
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
data_path = os.path.join(DATA_DIR, "chainer_mnist")

estimator = Chainer(
entry_point=script_path,
role="SageMakerRole",
framework_version=chainer_full_version,
py_version=PYTHON_VERSION,
train_instance_count=1,
train_instance_type=cpu_instance_type,
Expand Down
Loading