Skip to content

Commit ca63a14

Browse files
author
Rui Wang Napieralski
committed
feature: add model and predictor for HuggingFace
1 parent 7b1e5c1 commit ca63a14

File tree

1 file changed

+305
-0
lines changed

1 file changed

+305
-0
lines changed

src/sagemaker/huggingface/model.py

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
18+
import sagemaker
19+
from sagemaker import image_uris
20+
from sagemaker.deserializers import NumpyDeserializer
21+
from sagemaker.fw_utils import (
22+
model_code_key_prefix,
23+
validate_version_or_image_args,
24+
)
25+
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
26+
from sagemaker.predictor import Predictor
27+
from sagemaker.serializers import NumpySerializer
28+
29+
logger = logging.getLogger("sagemaker")
30+
31+
32+
class HuggingFacePredictor(Predictor):
    """A Predictor for inference against HuggingFace Endpoints.

    This is able to serialize Python lists, dictionaries, and numpy arrays to
    multidimensional tensors for HuggingFace inference.
    """

    def __init__(
        self,
        endpoint_name,
        sagemaker_session=None,
        serializer=None,
        deserializer=None,
    ):
        """Initialize an ``HuggingFacePredictor``.

        Args:
            endpoint_name (str): The name of the endpoint to perform inference
                on.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
            serializer (sagemaker.serializers.BaseSerializer): Optional. Default
                serializes input data to .npy format. Handles lists and numpy
                arrays.
            deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
                Default parses the response from .npy format to numpy array.
        """
        # Avoid mutable default arguments: instantiating the (de)serializers in
        # the signature would create a single shared instance at class-definition
        # time, reused across every predictor that omits the argument.
        if serializer is None:
            serializer = NumpySerializer()
        if deserializer is None:
            deserializer = NumpyDeserializer()
        super(HuggingFacePredictor, self).__init__(
            endpoint_name,
            sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )
67+
68+
69+
def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri):
70+
71+
if image_uri is not None:
72+
return
73+
74+
if tensorflow_version is not None and pytorch_version is not None:
75+
raise ValueError(
76+
"tensorflow_version and pytorch_version are both not None. "
77+
"Specify only tensorflow_version or pytorch_version."
78+
)
79+
if tensorflow_version is None and pytorch_version is None:
80+
raise ValueError(
81+
"tensorflow_version and pytorch_version are both None. "
82+
"Specify either tensorflow_version or pytorch_version."
83+
)
84+
85+
86+
class HuggingFaceModel(FrameworkModel):
    """A HuggingFace SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

    _framework_name = "huggingface"

    def __init__(
        self,
        model_data,
        role,
        entry_point,
        transformers_version=None,
        tensorflow_version=None,
        pytorch_version=None,
        py_version=None,
        image_uri=None,
        predictor_cls=HuggingFacePredictor,
        model_server_workers=None,
        **kwargs
    ):
        """Initialize a HuggingFaceModel.

        Args:
            model_data (str): The S3 location of a SageMaker model data
                ``.tar.gz`` file.
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs and APIs that create Amazon SageMaker
                endpoints use this role to access training data and model
                artifacts. After the endpoint is created, the inference code
                might use the IAM role, if it needs to access an AWS resource.
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to model
                hosting. If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
            transformers_version (str): transformers version you want to use for
                executing your model training code. Defaults to None. Required
                unless ``image_uri`` is provided.
            tensorflow_version (str): TensorFlow version you want to use for
                executing your inference code. Defaults to ``None``. Required unless
                ``pytorch_version`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            pytorch_version (str): PyTorch version you want to use for
                executing your inference code. Defaults to ``None``. Required unless
                ``tensorflow_version`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            py_version (str): Python version you want to use for executing your
                model training code. Defaults to ``None``. Required unless
                ``image_uri`` is provided.
            image_uri (str): A Docker image URI (default: None). If not specified, a
                default image for HuggingFace will be used. If ``transformers_version``
                or ``py_version`` are ``None``, then ``image_uri`` is required. If
                also ``None``, then a ``ValueError`` will be raised.
            predictor_cls (callable[str, sagemaker.session.Session]): A function
                to call to create a predictor with an endpoint name and
                SageMaker ``Session``. If specified, ``deploy()`` returns the
                result of invoking this function on the created endpoint name.
            model_server_workers (int): Optional. The number of worker processes
                used by the inference server. If None, server will use one
                worker per vCPU.
            **kwargs: Keyword arguments passed to the superclass
                :class:`~sagemaker.model.FrameworkModel` and, subsequently, its
                superclass :class:`~sagemaker.model.Model`.

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.model.FrameworkModel` and
            :class:`~sagemaker.model.Model`.
        """
        validate_version_or_image_args(transformers_version, py_version, image_uri)
        _validate_pt_tf_versions(
            pytorch_version=pytorch_version,
            tensorflow_version=tensorflow_version,
            image_uri=image_uri,
        )
        if py_version == "py2":
            raise ValueError("py2 is not supported with HuggingFace images")
        # ``framework_version`` is the transformers version; the base framework
        # (exactly one of PyTorch / TensorFlow) is tracked separately.
        self.framework_version = transformers_version
        self.pytorch_version = pytorch_version
        self.tensorflow_version = tensorflow_version
        self.py_version = py_version

        super(HuggingFaceModel, self).__init__(
            model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
        )

        self.model_server_workers = model_server_workers

    def register(
        self,
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=None,
        model_package_group_name=None,
        image_uri=None,
        model_metrics=None,
        metadata_properties=None,
        marketplace_cert=False,
        approval_status=None,
        description=None,
    ):
        """Creates a model package for creating SageMaker models or listing on Marketplace.

        Args:
            content_types (list): The supported MIME types for the input data.
            response_types (list): The supported MIME types for the output data.
            inference_instances (list): A list of the instance types that are used to
                generate inferences in real-time.
            transform_instances (list): A list of the instance types on which a transformation
                job can be run or on which an endpoint can be deployed.
            model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
                using `model_package_name` makes the Model Package un-versioned (default: None).
            model_package_group_name (str): Model Package Group name, exclusive to
                `model_package_name`, using `model_package_group_name` makes the Model Package
                versioned (default: None).
            image_uri (str): Inference image uri for the container. Model class' self.image will
                be used if it is None (default: None).
            model_metrics (ModelMetrics): ModelMetrics object (default: None).
            metadata_properties (MetadataProperties): MetadataProperties object (default: None).
            marketplace_cert (bool): A boolean value indicating if the Model Package is certified
                for AWS Marketplace (default: False).
            approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
                or "PendingManualApproval" (default: "PendingManualApproval").
            description (str): Model Package description (default: None).

        Returns:
            A `sagemaker.model.ModelPackage` instance.
        """
        # The first inference instance type decides CPU vs GPU image resolution.
        instance_type = inference_instances[0]
        self._init_sagemaker_session_if_does_not_exist(instance_type)

        if image_uri:
            self.image_uri = image_uri
        if not self.image_uri:
            self.image_uri = self.serving_image_uri(
                region_name=self.sagemaker_session.boto_session.region_name,
                instance_type=instance_type,
            )
        return super(HuggingFaceModel, self).register(
            content_types,
            response_types,
            inference_instances,
            transform_instances,
            model_package_name,
            model_package_group_name,
            image_uri,
            model_metrics,
            metadata_properties,
            marketplace_cert,
            approval_status,
            description,
        )

    def prepare_container_def(self, instance_type=None, accelerator_type=None):
        """A container definition with framework configuration set in model environment variables.

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to.
                For example, 'ml.p2.xlarge'.
            accelerator_type (str): The Elastic Inference accelerator type to
                deploy to the instance for loading and making inferences to the
                model.

        Returns:
            dict[str, str]: A container definition object usable with the
            CreateModel API.

        Raises:
            ValueError: If no ``image_uri`` was set and ``instance_type`` is
                ``None``, so no image can be resolved.
        """
        deploy_image = self.image_uri
        if not deploy_image:
            if instance_type is None:
                raise ValueError(
                    "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                )

            region_name = self.sagemaker_session.boto_session.region_name
            deploy_image = self.serving_image_uri(
                region_name, instance_type, accelerator_type=accelerator_type
            )

        deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
        # repack=True bundles the entry point script into the model archive.
        self._upload_code(deploy_key_prefix, repack=True)
        deploy_env = dict(self.env)
        deploy_env.update(self._framework_env_vars())

        if self.model_server_workers:
            deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
        return sagemaker.container_def(
            deploy_image, self.repacked_model_data or self.model_data, deploy_env
        )

    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
        """Create a URI for the serving image.

        Args:
            region_name (str): AWS region where the image is uploaded.
            instance_type (str): SageMaker instance type. Used to determine device type
                (cpu/gpu/family-specific optimized).
            accelerator_type (str): The Elastic Inference accelerator type to
                deploy to the instance for loading and making inferences to the
                model.

        Returns:
            str: The appropriate image URI based on the given parameters.

        """
        # __init__ guarantees exactly one of tensorflow_version / pytorch_version
        # is set (unless an explicit image_uri bypassed that check).
        if self.tensorflow_version is not None:  # pylint: disable=no-member
            base_framework_version = (
                f"tensorflow{self.tensorflow_version}"  # pylint: disable=no-member
            )
        else:
            base_framework_version = f"pytorch{self.pytorch_version}"  # pylint: disable=no-member
        return image_uris.retrieve(
            self._framework_name,
            region_name,
            version=self.framework_version,
            py_version=self.py_version,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            image_scope="inference",
            base_framework_version=base_framework_version,
        )

0 commit comments

Comments
 (0)