Skip to content

Commit 99279ff

Browse files
committed
breaking: require framework_version, py_version for tensorflow
1 parent 9977206 commit 99279ff

15 files changed

+151
-82
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ the difference in code would be as follows:
104104
...
105105
source_dir="code",
106106
framework_version="1.10.0",
107+
py_version="py3",
107108
train_instance_type="ml.m4.xlarge",
108109
image_name="520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2",
109110
hyperparameters={

doc/frameworks/tensorflow/using_tf.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ To run training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your
319319
tf_estimator = TensorFlow(entry_point='tf-train-with-pipemodedataset.py', role='SageMakerRole',
320320
training_steps=10000, evaluation_steps=100,
321321
train_instance_count=1, train_instance_type='ml.p2.xlarge',
322-
framework_version='1.10.0', input_mode='Pipe')
322+
framework_version='1.10.0', py_version='py3', input_mode='Pipe')
323323
324324
tf_estimator.fit('s3://bucket/path/to/training/data')
325325
@@ -383,7 +383,8 @@ estimator object to create a SageMaker Endpoint:
383383
from sagemaker.tensorflow import TensorFlow
384384
385385
estimator = TensorFlow(entry_point='tf-train.py', ..., train_instance_count=1,
386-
train_instance_type='ml.c4.xlarge', framework_version='1.11')
386+
train_instance_type='ml.c4.xlarge', framework_version='1.11',
387+
py_version='py3')
387388
388389
estimator.fit(inputs)
389390

src/sagemaker/rl/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ def create_model(
254254
)
255255

256256
if self.framework == RLFramework.TENSORFLOW.value:
257-
return TensorFlowModel(framework_version=self.framework_version, **base_args)
257+
return TensorFlowModel(
258+
framework_version=self.framework_version, py_version=PYTHON_VERSION, **base_args
259+
)
258260
if self.framework == RLFramework.MXNET.value:
259261
return MXNetModel(
260262
framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args

src/sagemaker/tensorflow/estimator.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,12 @@ def __init__(
5656
5757
Args:
5858
py_version (str): Python version you want to use for executing your model training
59-
code (default: 'py2').
59+
code. One of 'py2', 'py3', or 'py37'. Defaults to ``None``. Required unless
60+
``image_name`` is provided.
6061
framework_version (str): TensorFlow version you want to use for executing your model
61-
training code. List of supported versions
62+
training code. Defaults to ``None``. Required unless ``image_name`` is provided.
63+
List of supported versions:
6264
https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators.
63-
If not specified, this will default to 1.11.
6465
model_dir (str): S3 location where the checkpoint data and models can be exported to
6566
during training (default: None). It will be passed in the training script as one of
6667
the command line arguments. If not specified, one is provided based on
@@ -81,6 +82,10 @@ def __init__(
8182
Examples:
8283
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
8384
custom-image:latest.
85+
86+
If ``framework_version`` or ``py_version`` are ``None``, then
87+
``image_name`` is required. If also ``None``, then a ``ValueError``
88+
will be raised.
8489
distributions (dict): A dictionary with information on how to run distributed training
8590
(default: None). Currently we support distributed training with parameter servers
8691
and MPI.
@@ -114,18 +119,13 @@ def __init__(
114119
:class:`~sagemaker.estimator.Framework` and
115120
:class:`~sagemaker.estimator.EstimatorBase`.
116121
"""
117-
if framework_version is None:
118-
logger.warning(
119-
fw.empty_framework_version_warning(defaults.TF_VERSION, self.LATEST_VERSION)
120-
)
121-
self.framework_version = framework_version or defaults.TF_VERSION
122-
123-
if not py_version:
124-
py_version = "py3" if self._only_python_3_supported() else "py2"
122+
fw.validate_version_or_image_args(framework_version, py_version, image_name)
125123
if py_version == "py2":
126124
logger.warning(
127125
fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
128126
)
127+
self.framework_version = framework_version
128+
self.py_version = py_version
129129

130130
if distributions is not None:
131131
logger.warning(fw.parameter_v2_rename_warning("distribution", distributions))
@@ -136,12 +136,11 @@ def __init__(
136136

137137
if "enable_sagemaker_metrics" not in kwargs:
138138
# enable sagemaker metrics for TF v1.15 or greater:
139-
if fw.is_version_equal_or_higher([1, 15], self.framework_version):
139+
if framework_version and fw.is_version_equal_or_higher([1, 15], framework_version):
140140
kwargs["enable_sagemaker_metrics"] = True
141141

142142
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
143143

144-
self.py_version = py_version
145144
self.model_dir = model_dir
146145
self.distributions = distributions or {}
147146

@@ -150,7 +149,7 @@ def __init__(
150149
def _validate_args(self, py_version, framework_version):
151150
"""Placeholder docstring"""
152151

153-
if py_version == "py3":
152+
if py_version:
154153
if framework_version is None:
155154
raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)
156155

@@ -161,7 +160,7 @@ def _validate_args(self, py_version, framework_version):
161160
)
162161
raise AttributeError(msg)
163162

164-
if self._only_legacy_mode_supported() and self.image_name is None:
163+
if self.image_name is None and self._only_legacy_mode_supported():
165164
legacy_image_uri = fw.create_image_uri(
166165
self.sagemaker_session.boto_region_name,
167166
"tensorflow",
@@ -294,6 +293,7 @@ def create_model(
294293
role=role or self.role,
295294
container_log_level=self.container_log_level,
296295
framework_version=self.framework_version,
296+
py_version=self.py_version,
297297
sagemaker_session=self.sagemaker_session,
298298
vpc_config=self.get_vpc_config(vpc_config_override),
299299
entry_point=entry_point,

src/sagemaker/tensorflow/model.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
import sagemaker
1919
from sagemaker.content_types import CONTENT_TYPE_JSON
20-
from sagemaker.fw_utils import create_image_uri
20+
from sagemaker.fw_utils import create_image_uri, validate_version_or_image_args
2121
from sagemaker.predictor import json_serializer, json_deserializer
22-
from sagemaker.tensorflow.defaults import TF_VERSION
2322

2423

2524
class TensorFlowPredictor(sagemaker.RealTimePredictor):
@@ -138,7 +137,8 @@ def __init__(
138137
role,
139138
entry_point=None,
140139
image=None,
141-
framework_version=TF_VERSION,
140+
framework_version=None,
141+
py_version=None,
142142
container_log_level=None,
143143
predictor_cls=TensorFlowPredictor,
144144
**kwargs
@@ -158,9 +158,16 @@ def __init__(
158158
hosting. If ``source_dir`` is specified, then ``entry_point``
159159
must point to a file located at the root of ``source_dir``.
160160
image (str): A Docker image URI (default: None). If not specified, a
161-
default image for TensorFlow Serving will be used.
161+
default image for TensorFlow Serving will be used. If
162+
``framework_version`` or ``py_version`` are ``None``, then
163+
``image`` is required. If also ``None``, then a ``ValueError``
164+
will be raised.
162165
framework_version (str): Optional. TensorFlow Serving version you
163-
want to use.
166+
want to use. Defaults to ``None``. Required unless ``image`` is
167+
provided.
168+
py_version (str): Python version you want to use for executing your
169+
model training code. One of 'py2', 'py3', or 'py37'. Defaults to
170+
``None``. Required unless ``image`` is provided.
164171
container_log_level (int): Log level to use within the container
165172
(default: logging.ERROR). Valid values are defined in the Python
166173
logging module.
@@ -176,6 +183,10 @@ def __init__(
176183
:class:`~sagemaker.model.FrameworkModel` and
177184
:class:`~sagemaker.model.Model`.
178185
"""
186+
validate_version_or_image_args(framework_version, py_version, image)
187+
self.framework_version = framework_version
188+
self.py_version = py_version
189+
179190
super(TensorFlowModel, self).__init__(
180191
model_data=model_data,
181192
role=role,
@@ -184,7 +195,6 @@ def __init__(
184195
entry_point=entry_point,
185196
**kwargs
186197
)
187-
self.framework_version = framework_version
188198
self._container_log_level = container_log_level
189199

190200
def deploy(

tests/integ/test_data_capture_config.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import os
16+
import pytest
1617

1718
import sagemaker
1819
import tests.integ
@@ -40,8 +41,13 @@
4041
CUSTOM_JSON_CONTENT_TYPES = ["application/jsontype1", "application/jsontype2"]
4142

4243

44+
@pytest.fixture(scope="module")
45+
def py_version(tf_full_version, tf_serving_version):
46+
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
47+
48+
4349
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
44-
sagemaker_session, tf_serving_version
50+
sagemaker_session, tf_serving_version, py_version
4551
):
4652
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
4753
model_data = sagemaker_session.upload_data(
@@ -53,6 +59,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5359
model_data=model_data,
5460
role=ROLE,
5561
framework_version=tf_serving_version,
62+
py_version=py_version,
5663
sagemaker_session=sagemaker_session,
5764
)
5865
predictor = model.deploy(
@@ -98,7 +105,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
98105

99106

100107
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
101-
sagemaker_session, tf_serving_version
108+
sagemaker_session, tf_serving_version, py_version
102109
):
103110
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
104111
model_data = sagemaker_session.upload_data(
@@ -110,6 +117,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
110117
model_data=model_data,
111118
role=ROLE,
112119
framework_version=tf_serving_version,
120+
py_version=py_version,
113121
sagemaker_session=sagemaker_session,
114122
)
115123
destination_s3_uri = os.path.join(
@@ -184,7 +192,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
184192

185193

186194
def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
187-
sagemaker_session, tf_serving_version
195+
sagemaker_session, tf_serving_version, py_version
188196
):
189197
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
190198
model_data = sagemaker_session.upload_data(
@@ -196,6 +204,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
196204
model_data=model_data,
197205
role=ROLE,
198206
framework_version=tf_serving_version,
207+
py_version=py_version,
199208
sagemaker_session=sagemaker_session,
200209
)
201210
destination_s3_uri = os.path.join(

tests/integ/test_model_monitor.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,12 @@
8888

8989

9090
@pytest.fixture(scope="module")
91-
def predictor(sagemaker_session, tf_serving_version):
91+
def py_version(tf_full_version, tf_serving_version):
92+
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
93+
94+
95+
@pytest.fixture(scope="module")
96+
def predictor(sagemaker_session, tf_serving_version, py_version):
9297
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
9398
model_data = sagemaker_session.upload_data(
9499
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -101,6 +106,7 @@ def predictor(sagemaker_session, tf_serving_version):
101106
model_data=model_data,
102107
role=ROLE,
103108
framework_version=tf_serving_version,
109+
py_version=py_version,
104110
sagemaker_session=sagemaker_session,
105111
)
106112
predictor = model.deploy(

tests/integ/test_tf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_mnist_with_checkpoint_config(
5959
train_instance_type=instance_type,
6060
sagemaker_session=sagemaker_session,
6161
framework_version=tf_full_version,
62-
py_version="py37",
62+
py_version=py_version,
6363
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6464
checkpoint_s3_uri=checkpoint_s3_uri,
6565
checkpoint_local_path=checkpoint_local_path,
@@ -137,8 +137,8 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
137137
train_instance_count=2,
138138
train_instance_type=instance_type,
139139
sagemaker_session=sagemaker_session,
140-
py_version="py37",
141140
framework_version=tf_full_version,
141+
py_version=py_version,
142142
distributions=PARAMETER_SERVER_DISTRIBUTION,
143143
)
144144
inputs = estimator.sagemaker_session.upload_data(

tests/integ/test_tf_efs_fsx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
FSX_DIR_PATH = "/fsx/tensorflow"
3737
MAX_JOBS = 2
3838
MAX_PARALLEL_JOBS = 2
39-
PY_VERSION = "py3"
39+
PY_VERSION = "py37"
4040

4141

4242
@pytest.fixture(scope="module")
@@ -139,8 +139,8 @@ def test_tuning_tf_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
139139
train_instance_count=1,
140140
train_instance_type=cpu_instance_type,
141141
sagemaker_session=sagemaker_session,
142-
py_version=PY_VERSION,
143142
framework_version=TensorFlow.LATEST_VERSION,
143+
py_version=PY_VERSION,
144144
subnets=subnets,
145145
security_group_ids=security_group_ids,
146146
)
@@ -186,8 +186,8 @@ def test_tuning_tf_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
186186
train_instance_count=1,
187187
train_instance_type=cpu_instance_type,
188188
sagemaker_session=sagemaker_session,
189-
py_version=PY_VERSION,
190189
framework_version=TensorFlow.LATEST_VERSION,
190+
py_version=PY_VERSION,
191191
subnets=subnets,
192192
security_group_ids=security_group_ids,
193193
)

tests/integ/test_tfs.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727

2828

2929
@pytest.fixture(scope="module")
30-
def tfs_predictor(sagemaker_session, tf_serving_version):
30+
def py_version(tf_full_version, tf_serving_version):
31+
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
32+
33+
34+
@pytest.fixture(scope="module")
35+
def tfs_predictor(sagemaker_session, tf_serving_version, py_version):
3136
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
3237
model_data = sagemaker_session.upload_data(
3338
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -38,6 +43,7 @@ def tfs_predictor(sagemaker_session, tf_serving_version):
3843
model_data=model_data,
3944
role="SageMakerRole",
4045
framework_version=tf_serving_version,
46+
py_version=py_version,
4147
sagemaker_session=sagemaker_session,
4248
)
4349
predictor = model.deploy(1, "ml.c5.xlarge", endpoint_name=endpoint_name)
@@ -54,7 +60,7 @@ def tar_dir(directory, tmpdir):
5460

5561
@pytest.fixture
5662
def tfs_predictor_with_model_and_entry_point_same_tar(
57-
sagemaker_local_session, tf_serving_version, tmpdir
63+
sagemaker_local_session, tf_serving_version, py_version, tmpdir
5864
):
5965
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
6066

@@ -66,6 +72,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
6672
model_data="file://" + model_tar,
6773
role="SageMakerRole",
6874
framework_version=tf_serving_version,
75+
py_version=py_version,
6976
sagemaker_session=sagemaker_local_session,
7077
)
7178
predictor = model.deploy(1, "local", endpoint_name=endpoint_name)
@@ -78,7 +85,7 @@ def tfs_predictor_with_model_and_entry_point_same_tar(
7885

7986
@pytest.fixture(scope="module")
8087
def tfs_predictor_with_model_and_entry_point_and_dependencies(
81-
sagemaker_local_session, tf_serving_version
88+
sagemaker_local_session, tf_serving_version, py_version
8289
):
8390
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
8491

@@ -99,6 +106,7 @@ def tfs_predictor_with_model_and_entry_point_and_dependencies(
99106
role="SageMakerRole",
100107
dependencies=dependencies,
101108
framework_version=tf_serving_version,
109+
py_version=py_version,
102110
sagemaker_session=sagemaker_local_session,
103111
)
104112

@@ -122,6 +130,7 @@ def tfs_predictor_with_accelerator(sagemaker_session, ei_tf_full_version, cpu_in
122130
model_data=model_data,
123131
role="SageMakerRole",
124132
framework_version=ei_tf_full_version,
133+
py_version=tests.integ.PYTHON_VERSION,
125134
sagemaker_session=sagemaker_session,
126135
)
127136
predictor = model.deploy(

0 commit comments

Comments
 (0)