Skip to content

Commit f57b918

Browse files
authored
Merge branch 'zwei' into deprecate-version-comparison-functions
2 parents f73db16 + 23af3b1 commit f57b918

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+1531
-550
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,6 @@ For example, if you want to use JSON serialization and deserialization:
248248
from sagemaker.deserializers import JSONDeserializer
249249
from sagemaker.serializers import JSONSerializer
250250
251-
predictor.serializer = JSONSerializer()
252-
predictor.deserializer = JSONDeserializer()
251+
predictor = model.deploy(..., serializer=JSONSerializer(), deserializer=JSONDeserializer())
253252
254253
predictor.predict(data)

doc/frameworks/xgboost/using_xgboost.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,14 @@ inference against your model.
192192

193193
.. code::
194194
195+
serializer = StringSerializer()
196+
serializer.CONTENT_TYPE = "text/libsvm"
197+
195198
predictor = estimator.deploy(
196199
initial_instance_count=1,
197-
instance_type="ml.m5.xlarge"
200+
instance_type="ml.m5.xlarge",
201+
serializer=serializer
198202
)
199-
predictor.serializer = str
200-
predictor.content_type = "text/libsvm"
201203
202204
with open("abalone") as f:
203205
payload = f.read()

doc/v2.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ Please instantiate the objects instead.
8181
The ``update_endpoint`` argument in ``deploy()`` methods for estimators and models has been deprecated.
8282
Please use :func:`sagemaker.predictor.Predictor.update_endpoint` instead.
8383

84+
``serializer`` and ``deserializer`` in ``create_model()``
85+
---------------------------------------------------------
86+
87+
The ``serializer`` and ``deserializer`` arguments in
88+
:func:`sagemaker.estimator.Estimator.create_model` have been deprecated. Please
89+
specify serializers and deserializers in ``deploy()`` methods instead.
90+
8491
``content_type`` and ``accept`` in the Predictor Constructor
8592
------------------------------------------------------------
8693

@@ -196,6 +203,12 @@ To view logs after attaching a training job to an estimator, use :func:`sagemake
196203
until the completion of the Hyperparameter Tuning Job or Batch Transform Job, respectively.
197204
To make the function non-blocking, use ``wait=False``.
198205

206+
XGBoost Predictor
207+
-----------------
208+
209+
The default serializer of ``sagemaker.xgboost.model.XGBoostPredictor`` has been changed from ``NumpySerializer`` to ``LibSVMSerializer``.
210+
211+
199212
Parameter and Class Name Changes
200213
================================
201214

@@ -256,6 +269,8 @@ The follow serializer/deserializer classes have been renamed and/or moved:
256269
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
257270
+--------------------------------------------------------+-------------------------------------------------------+
258271

272+
``sagemaker.serializers.LibSVMSerializer`` has been added in v2.0.
273+
259274
``distributions``
260275
~~~~~~~~~~~~~~~~~
261276

src/sagemaker/algorithm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,13 +229,13 @@ def hyperparameters(self):
229229
"""
230230
return self.hyperparam_dict
231231

232-
def train_image(self):
232+
def training_image_uri(self):
233233
"""Returns the docker image to use for training.
234234
235235
The fit() method, that does the model training, calls this method to
236236
find the image to use for model training.
237237
"""
238-
raise RuntimeError("train_image is never meant to be called on Algorithm Estimators")
238+
raise RuntimeError("training_image_uri is never meant to be called on Algorithm Estimators")
239239

240240
def enable_network_isolation(self):
241241
"""Return True if this Estimator will need network isolation to run.

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def __init__(
9191
)
9292
self._data_location = data_location
9393

94-
def train_image(self):
94+
def training_image_uri(self):
9595
"""Placeholder docstring"""
9696
return image_uris.retrieve(
9797
self.repo_name, self.sagemaker_session.boto_region_name, version=self.repo_version,

src/sagemaker/automl/automl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,8 @@ def deploy(
337337
self,
338338
initial_instance_count,
339339
instance_type,
340+
serializer=None,
341+
deserializer=None,
340342
candidate=None,
341343
sagemaker_session=None,
342344
name=None,
@@ -356,6 +358,16 @@ def deploy(
356358
in the ``Endpoint`` created from this ``Model``.
357359
instance_type (str): The EC2 instance type to deploy this Model to.
358360
For example, 'ml.p2.xlarge'.
361+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
362+
serializer object, used to encode data for an inference endpoint
363+
(default: None). If ``serializer`` is not None, then
364+
``serializer`` will override the default serializer. The
365+
default serializer is set by the ``predictor_cls``.
366+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
367+
deserializer object, used to decode data from an inference
368+
endpoint (default: None). If ``deserializer`` is not None, then
369+
``deserializer`` will override the default deserializer. The
370+
default deserializer is set by the ``predictor_cls``.
359371
candidate (CandidateEstimator or dict): a CandidateEstimator used for deploying
360372
to a SageMaker Inference Pipeline. If None, the best candidate will
361373
be used. If the candidate input is a dict, a CandidateEstimator will be
@@ -405,6 +417,8 @@ def deploy(
405417
return model.deploy(
406418
initial_instance_count=initial_instance_count,
407419
instance_type=instance_type,
420+
serializer=serializer,
421+
deserializer=deserializer,
408422
endpoint_name=endpoint_name,
409423
tags=tags,
410424
wait=wait,

src/sagemaker/cli/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying athis file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools for automating code updates"""
14+
from __future__ import absolute_import

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
38+
modifiers.training_input.ShuffleConfigModuleRenamer(),
3839
modifiers.serde.SerdeConstructorRenamer(),
3940
]
4041

@@ -51,6 +52,7 @@
5152
modifiers.predictors.PredictorImportFromRenamer(),
5253
modifiers.tfs.TensorFlowServingImportFromRenamer(),
5354
modifiers.training_input.TrainingInputImportFromRenamer(),
55+
modifiers.training_input.ShuffleConfigImportFromRenamer(),
5456
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
5557
modifiers.serde.SerdeImportFromPredictorRenamer(),
5658
]

src/sagemaker/cli/compatibility/v2/modifiers/training_input.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,73 @@ def modify_node(self, node):
100100
if node.module == "sagemaker.session":
101101
node.module = "sagemaker.inputs"
102102
return node
103+
104+
105+
class ShuffleConfigModuleRenamer(Modifier):
106+
"""A class to change ``ShuffleConfig`` usage to use ``sagemaker.inputs.ShuffleConfig``."""
107+
108+
def node_should_be_modified(self, node):
109+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
110+
111+
This looks for the following calls:
112+
113+
- ``sagemaker.session.ShuffleConfig``
114+
- ``session.ShuffleConfig``
115+
116+
Args:
117+
node (ast.Call): a node that represents a function call. For more,
118+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
119+
120+
Returns:
121+
bool: If the ``ast.Call`` instantiates a class of interest.
122+
"""
123+
if isinstance(node.func, ast.Name):
124+
return False
125+
126+
return matching.matches_name_or_namespaces(
127+
node, "ShuffleConfig", ("sagemaker.session", "session")
128+
)
129+
130+
def modify_node(self, node):
131+
"""Modifies the ``ast.Call`` node to call ``sagemaker.inputs.ShuffleConfig``.
132+
133+
Args:
134+
node (ast.Call): a node that represents a ``sagemaker.session.ShuffleConfig``
135+
constructor.
136+
137+
Returns:
138+
ast.Call: the original node, with its namespace changed to use the ``inputs`` module.
139+
"""
140+
_rename_namespace(node, "session")
141+
return node
142+
143+
144+
class ShuffleConfigImportFromRenamer(Modifier):
145+
"""A class to update import statements of ``ShuffleConfig``."""
146+
147+
def node_should_be_modified(self, node):
148+
"""Checks if the import statement imports ``sagemaker.session.ShuffleConfig``.
149+
150+
Args:
151+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
152+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
153+
154+
Returns:
155+
bool: If the import statement imports ``sagemaker.session.ShuffleConfig``.
156+
"""
157+
return node.module == "sagemaker.session" and any(
158+
name.name == "ShuffleConfig" for name in node.names
159+
)
160+
161+
def modify_node(self, node):
162+
"""Changes the ``ast.ImportFrom`` node's namespace to ``sagemaker.inputs``.
163+
164+
Args:
165+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
166+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
167+
168+
Returns:
169+
ast.ImportFrom: the original node, with its module modified to ``"sagemaker.inputs"``.
170+
"""
171+
node.module = "sagemaker.inputs"
172+
return node

0 commit comments

Comments
 (0)