Skip to content

Commit d320b88

Browse files
authored
Merge branch 'zwei' into require-framework-version-chainer
2 parents 5fc1125 + dbdaf50 commit d320b88

File tree

12 files changed

+173
-164
lines changed

12 files changed

+173
-164
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ the difference in code would be as follows:
104104
...
105105
source_dir="code",
106106
framework_version="1.10.0",
107+
py_version="py2",
107108
train_instance_type="ml.m4.xlarge",
108109
image_name="520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.10.0-cpu-py2",
109110
hyperparameters={

doc/frameworks/tensorflow/using_tf.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ To run training job with Pipe input mode, pass in ``input_mode='Pipe'`` to your
319319
tf_estimator = TensorFlow(entry_point='tf-train-with-pipemodedataset.py', role='SageMakerRole',
320320
training_steps=10000, evaluation_steps=100,
321321
train_instance_count=1, train_instance_type='ml.p2.xlarge',
322-
framework_version='1.10.0', input_mode='Pipe')
322+
framework_version='1.10.0', py_version='py3', input_mode='Pipe')
323323
324324
tf_estimator.fit('s3://bucket/path/to/training/data')
325325
@@ -383,7 +383,8 @@ estimator object to create a SageMaker Endpoint:
383383
from sagemaker.tensorflow import TensorFlow
384384
385385
estimator = TensorFlow(entry_point='tf-train.py', ..., train_instance_count=1,
386-
train_instance_type='ml.c4.xlarge', framework_version='1.11')
386+
train_instance_type='ml.c4.xlarge', framework_version='1.11',
387+
py_version='py3')
387388
388389
estimator.fit(inputs)
389390

src/sagemaker/tensorflow/estimator.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def __init__(
5656
5757
Args:
5858
py_version (str): Python version you want to use for executing your model training
59-
code (default: 'py2').
59+
code. Defaults to ``None``. Required unless ``image_name`` is provided.
6060
framework_version (str): TensorFlow version you want to use for executing your model
61-
training code. List of supported versions
61+
training code. Defaults to ``None``. Required unless ``image_name`` is provided.
62+
List of supported versions:
6263
https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators.
63-
If not specified, this will default to 1.11.
6464
model_dir (str): S3 location where the checkpoint data and models can be exported to
6565
during training (default: None). It will be passed in the training script as one of
6666
the command line arguments. If not specified, one is provided based on
@@ -81,6 +81,10 @@ def __init__(
8181
Examples:
8282
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
8383
custom-image:latest.
84+
85+
If ``framework_version`` or ``py_version`` are ``None``, then
86+
``image_name`` is required. If also ``None``, then a ``ValueError``
87+
will be raised.
8488
distributions (dict): A dictionary with information on how to run distributed training
8589
(default: None). Currently we support distributed training with parameter servers
8690
and MPI.
@@ -114,18 +118,13 @@ def __init__(
114118
:class:`~sagemaker.estimator.Framework` and
115119
:class:`~sagemaker.estimator.EstimatorBase`.
116120
"""
117-
if framework_version is None:
118-
logger.warning(
119-
fw.empty_framework_version_warning(defaults.TF_VERSION, self.LATEST_VERSION)
120-
)
121-
self.framework_version = framework_version or defaults.TF_VERSION
122-
123-
if not py_version:
124-
py_version = "py3" if self._only_python_3_supported() else "py2"
121+
fw.validate_version_or_image_args(framework_version, py_version, image_name)
125122
if py_version == "py2":
126123
logger.warning(
127124
fw.python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
128125
)
126+
self.framework_version = framework_version
127+
self.py_version = py_version
129128

130129
if distributions is not None:
131130
logger.warning(fw.parameter_v2_rename_warning("distribution", distributions))
@@ -136,32 +135,27 @@ def __init__(
136135

137136
if "enable_sagemaker_metrics" not in kwargs:
138137
# enable sagemaker metrics for TF v1.15 or greater:
139-
if fw.is_version_equal_or_higher([1, 15], self.framework_version):
138+
if framework_version and fw.is_version_equal_or_higher([1, 15], framework_version):
140139
kwargs["enable_sagemaker_metrics"] = True
141140

142141
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
143142

144-
self.py_version = py_version
145143
self.model_dir = model_dir
146144
self.distributions = distributions or {}
147145

148-
self._validate_args(py_version=py_version, framework_version=self.framework_version)
146+
self._validate_args(py_version=py_version)
149147

150-
def _validate_args(self, py_version, framework_version):
148+
def _validate_args(self, py_version):
151149
"""Placeholder docstring"""
152150

153-
if py_version == "py3":
154-
if framework_version is None:
155-
raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)
156-
157151
if py_version == "py2" and self._only_python_3_supported():
158152
msg = (
159153
"Python 2 containers are only available with {} and lower versions. "
160154
"Please use a Python 3 container.".format(defaults.LATEST_PY2_VERSION)
161155
)
162156
raise AttributeError(msg)
163157

164-
if self._only_legacy_mode_supported() and self.image_name is None:
158+
if self.image_name is None and self._only_legacy_mode_supported():
165159
legacy_image_uri = fw.create_image_uri(
166160
self.sagemaker_session.boto_region_name,
167161
"tensorflow",

src/sagemaker/tensorflow/model.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from sagemaker.content_types import CONTENT_TYPE_JSON
2020
from sagemaker.fw_utils import create_image_uri
2121
from sagemaker.predictor import json_serializer, json_deserializer
22-
from sagemaker.tensorflow.defaults import TF_VERSION
2322

2423

2524
class TensorFlowPredictor(sagemaker.RealTimePredictor):
@@ -138,7 +137,7 @@ def __init__(
138137
role,
139138
entry_point=None,
140139
image=None,
141-
framework_version=TF_VERSION,
140+
framework_version=None,
142141
container_log_level=None,
143142
predictor_cls=TensorFlowPredictor,
144143
**kwargs
@@ -158,9 +157,12 @@ def __init__(
158157
hosting. If ``source_dir`` is specified, then ``entry_point``
159158
must point to a file located at the root of ``source_dir``.
160159
image (str): A Docker image URI (default: None). If not specified, a
161-
default image for TensorFlow Serving will be used.
160+
default image for TensorFlow Serving will be used. If
161+
``framework_version`` is ``None``, then ``image`` is required.
162+
If also ``None``, then a ``ValueError`` will be raised.
162163
framework_version (str): Optional. TensorFlow Serving version you
163-
want to use.
164+
want to use. Defaults to ``None``. Required unless ``image`` is
165+
provided.
164166
container_log_level (int): Log level to use within the container
165167
(default: logging.ERROR). Valid values are defined in the Python
166168
logging module.
@@ -176,6 +178,13 @@ def __init__(
176178
:class:`~sagemaker.model.FrameworkModel` and
177179
:class:`~sagemaker.model.Model`.
178180
"""
181+
if framework_version is None and image is None:
182+
raise ValueError(
183+
"Both framework_version and image were None. "
184+
"Either specify framework_version or specify image_name."
185+
)
186+
self.framework_version = framework_version
187+
179188
super(TensorFlowModel, self).__init__(
180189
model_data=model_data,
181190
role=role,
@@ -184,7 +193,6 @@ def __init__(
184193
entry_point=entry_point,
185194
**kwargs
186195
)
187-
self.framework_version = framework_version
188196
self._container_log_level = container_log_level
189197

190198
def deploy(

tests/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,16 @@ def tf_version(request):
226226
return request.param
227227

228228

229+
@pytest.fixture(scope="module", params=["py2", "py3"])
230+
def tf_py_version(tf_version, request):
231+
version = [int(val) for val in tf_version.split(".")]
232+
if version < [1, 11]:
233+
return "py2"
234+
if version < [2, 2]:
235+
return request.param
236+
return "py37"
237+
238+
229239
@pytest.fixture(scope="module", params=["0.10.1", "0.10.1", "0.11", "0.11.0", "0.11.1"])
230240
def rl_coach_tf_version(request):
231241
return request.param
@@ -290,6 +300,23 @@ def tf_full_version(request):
290300
return tf_version
291301

292302

303+
@pytest.fixture(scope="module")
304+
def tf_full_py_version(tf_full_version, request):
305+
"""fixture to match tf_full_version
306+
307+
Fixture exists as such, since tf_full_version may be overridden --tf-full-version.
308+
Otherwise, this would simply be py37 to match the latest version support.
309+
310+
TODO: Evaluate use of --tf-full-version with possible eye to remove and simplify code.
311+
"""
312+
version = [int(val) for val in tf_full_version.split(".")]
313+
if version < [1, 11]:
314+
return "py2"
315+
if tf_full_version in [TensorFlow._LATEST_1X_VERSION, LATEST_VERSION]:
316+
return "py37"
317+
return "py3"
318+
319+
293320
@pytest.fixture(scope="module", params=["1.15.0", "2.0.0"])
294321
def ei_tf_full_version(request):
295322
tf_ei_version = request.config.getoption("--ei-tf-full-version")

tests/integ/test_airflow_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def test_tf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
562562
train_instance_type=cpu_instance_type,
563563
sagemaker_session=sagemaker_session,
564564
framework_version=TensorFlow.LATEST_VERSION,
565-
py_version=PYTHON_VERSION,
565+
py_version="py37", # only version available with 2.2.0
566566
metric_definitions=[
567567
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
568568
],

tests/integ/test_tf.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2424

2525
import tests.integ
26-
from tests.integ import timeout
27-
from tests.integ import kms_utils
26+
from tests.integ import kms_utils, timeout, PYTHON_VERSION
2827
from tests.integ.retry import retries
2928
from tests.integ.s3_utils import assert_s3_files_exist
3029

@@ -40,13 +39,8 @@
4039
TAGS = [{"Key": "some-key", "Value": "some-value"}]
4140

4241

43-
@pytest.fixture(scope="module")
44-
def py_version(tf_full_version, tf_serving_version):
45-
return "py37" if tf_full_version == tf_serving_version else tests.integ.PYTHON_VERSION
46-
47-
4842
def test_mnist_with_checkpoint_config(
49-
sagemaker_session, instance_type, tf_full_version, py_version
43+
sagemaker_session, instance_type, tf_full_version, tf_full_py_version
5044
):
5145
checkpoint_s3_uri = "s3://{}/checkpoints/tf-{}".format(
5246
sagemaker_session.default_bucket(), sagemaker_timestamp()
@@ -59,7 +53,7 @@ def test_mnist_with_checkpoint_config(
5953
train_instance_type=instance_type,
6054
sagemaker_session=sagemaker_session,
6155
framework_version=tf_full_version,
62-
py_version="py37",
56+
py_version=tf_full_py_version,
6357
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6458
checkpoint_s3_uri=checkpoint_s3_uri,
6559
checkpoint_local_path=checkpoint_local_path,
@@ -89,7 +83,7 @@ def test_mnist_with_checkpoint_config(
8983
assert actual_training_checkpoint_config == expected_training_checkpoint_config
9084

9185

92-
def test_server_side_encryption(sagemaker_session, tf_serving_version, py_version):
86+
def test_server_side_encryption(sagemaker_session, tf_serving_version):
9387
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
9488
output_path = os.path.join(
9589
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
@@ -103,7 +97,7 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
10397
train_instance_type="ml.c5.xlarge",
10498
sagemaker_session=sagemaker_session,
10599
framework_version=tf_serving_version,
106-
py_version=py_version,
100+
py_version=PYTHON_VERSION,
107101
code_location=output_path,
108102
output_path=output_path,
109103
model_dir="/opt/ml/model",
@@ -130,15 +124,15 @@ def test_server_side_encryption(sagemaker_session, tf_serving_version, py_versio
130124

131125

132126
@pytest.mark.canary_quick
133-
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py_version):
127+
def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, tf_full_py_version):
134128
estimator = TensorFlow(
135129
entry_point=SCRIPT,
136130
role=ROLE,
137131
train_instance_count=2,
138132
train_instance_type=instance_type,
139133
sagemaker_session=sagemaker_session,
140-
py_version="py37",
141134
framework_version=tf_full_version,
135+
py_version=tf_full_py_version,
142136
distributions=PARAMETER_SERVER_DISTRIBUTION,
143137
)
144138
inputs = estimator.sagemaker_session.upload_data(
@@ -154,13 +148,13 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
154148
)
155149

156150

157-
def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_version):
151+
def test_mnist_async(sagemaker_session, cpu_instance_type):
158152
estimator = TensorFlow(
159153
entry_point=SCRIPT,
160154
role=ROLE,
161155
train_instance_count=1,
162156
train_instance_type="ml.c5.4xlarge",
163-
py_version=tests.integ.PYTHON_VERSION,
157+
py_version=PYTHON_VERSION,
164158
sagemaker_session=sagemaker_session,
165159
# testing py-sdk functionality, no need to run against all TF versions
166160
framework_version=LATEST_SERVING_VERSION,
@@ -195,18 +189,16 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
195189
_assert_model_name_match(sagemaker_session.sagemaker_client, endpoint_name, model_name)
196190

197191

198-
def test_deploy_with_input_handlers(
199-
sagemaker_session, instance_type, tf_serving_version, py_version
200-
):
192+
def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_serving_version):
201193
estimator = TensorFlow(
202194
entry_point="training.py",
203195
source_dir=TFS_RESOURCE_PATH,
204196
role=ROLE,
205197
train_instance_count=1,
206198
train_instance_type=instance_type,
207-
py_version=py_version,
208-
sagemaker_session=sagemaker_session,
209199
framework_version=tf_serving_version,
200+
py_version=PYTHON_VERSION,
201+
sagemaker_session=sagemaker_session,
210202
tags=TAGS,
211203
)
212204

tests/integ/test_tf_efs_fsx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
FSX_DIR_PATH = "/fsx/tensorflow"
3737
MAX_JOBS = 2
3838
MAX_PARALLEL_JOBS = 2
39-
PY_VERSION = "py3"
39+
PY_VERSION = "py37"
4040

4141

4242
@pytest.fixture(scope="module")
@@ -139,8 +139,8 @@ def test_tuning_tf_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
139139
train_instance_count=1,
140140
train_instance_type=cpu_instance_type,
141141
sagemaker_session=sagemaker_session,
142-
py_version=PY_VERSION,
143142
framework_version=TensorFlow.LATEST_VERSION,
143+
py_version=PY_VERSION,
144144
subnets=subnets,
145145
security_group_ids=security_group_ids,
146146
)
@@ -186,8 +186,8 @@ def test_tuning_tf_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
186186
train_instance_count=1,
187187
train_instance_type=cpu_instance_type,
188188
sagemaker_session=sagemaker_session,
189-
py_version=PY_VERSION,
190189
framework_version=TensorFlow.LATEST_VERSION,
190+
py_version=PY_VERSION,
191191
subnets=subnets,
192192
security_group_ids=security_group_ids,
193193
)

0 commit comments

Comments
 (0)