|
| 1 | +# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Placeholder docstring""" |
| 14 | +from __future__ import absolute_import |
| 15 | + |
| 16 | +import logging |
| 17 | + |
| 18 | +import sagemaker |
| 19 | +from sagemaker import image_uris |
| 20 | +from sagemaker.deserializers import NumpyDeserializer |
| 21 | +from sagemaker.fw_utils import ( |
| 22 | + model_code_key_prefix, |
| 23 | + validate_version_or_image_args, |
| 24 | +) |
| 25 | +from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME |
| 26 | +from sagemaker.predictor import Predictor |
| 27 | +from sagemaker.serializers import NumpySerializer |
| 28 | + |
| 29 | +logger = logging.getLogger("sagemaker") |
| 30 | + |
| 31 | + |
class HuggingFacePredictor(Predictor):
    """A Predictor for inference against HuggingFace Endpoints.

    This is able to serialize Python lists, dictionaries, and numpy arrays to
    multidimensional tensors for HuggingFace inference.
    """

    def __init__(
        self,
        endpoint_name,
        sagemaker_session=None,
        serializer=None,
        deserializer=None,
    ):
        """Initialize an ``HuggingFacePredictor``.

        Args:
            endpoint_name (str): The name of the endpoint to perform inference
                on.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, the estimator creates one
                using the default AWS configuration chain.
            serializer (sagemaker.serializers.BaseSerializer): Optional. Default
                serializes input data to .npy format. Handles lists and numpy
                arrays.
            deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
                Default parses the response from .npy format to numpy array.
        """
        # Instantiate the default (de)serializers lazily rather than in the
        # signature, so each predictor gets its own instances instead of
        # module-level singletons shared by every caller (flake8-bugbear B008).
        if serializer is None:
            serializer = NumpySerializer()
        if deserializer is None:
            deserializer = NumpyDeserializer()
        super(HuggingFacePredictor, self).__init__(
            endpoint_name,
            sagemaker_session,
            serializer=serializer,
            deserializer=deserializer,
        )
| 67 | + |
| 68 | + |
| 69 | +def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri): |
| 70 | + |
| 71 | + if image_uri is not None: |
| 72 | + return |
| 73 | + |
| 74 | + if tensorflow_version is not None and pytorch_version is not None: |
| 75 | + raise ValueError( |
| 76 | + "tensorflow_version and pytorch_version are both not None. " |
| 77 | + "Specify only tensorflow_version or pytorch_version." |
| 78 | + ) |
| 79 | + if tensorflow_version is None and pytorch_version is None: |
| 80 | + raise ValueError( |
| 81 | + "tensorflow_version and pytorch_version are both None. " |
| 82 | + "Specify either tensorflow_version or pytorch_version." |
| 83 | + ) |
| 84 | + |
| 85 | + |
class HuggingFaceModel(FrameworkModel):
    """A HuggingFace SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""

    _framework_name = "huggingface"

    def __init__(
        self,
        model_data,
        role,
        entry_point,
        transformers_version=None,
        tensorflow_version=None,
        pytorch_version=None,
        py_version=None,
        image_uri=None,
        predictor_cls=HuggingFacePredictor,
        model_server_workers=None,
        **kwargs
    ):
        """Initialize a HuggingFaceModel.

        Args:
            model_data (str): The S3 location of a SageMaker model data
                ``.tar.gz`` file.
            role (str): An AWS IAM role (either name or full ARN). The Amazon
                SageMaker training jobs and APIs that create Amazon SageMaker
                endpoints use this role to access training data and model
                artifacts. After the endpoint is created, the inference code
                might use the IAM role, if it needs to access an AWS resource.
            entry_point (str): Path (absolute or relative) to the Python source
                file which should be executed as the entry point to model
                hosting. If ``source_dir`` is specified, then ``entry_point``
                must point to a file located at the root of ``source_dir``.
            transformers_version (str): transformers version you want to use for
                executing your model training code. Defaults to None. Required
                unless ``image_uri`` is provided.
            tensorflow_version (str): TensorFlow version you want to use for
                executing your inference code. Defaults to ``None``. Required unless
                ``pytorch_version`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            pytorch_version (str): PyTorch version you want to use for
                executing your inference code. Defaults to ``None``. Required unless
                ``tensorflow_version`` is provided. List of supported versions:
                https://github.com/aws/sagemaker-python-sdk#huggingface-sagemaker-estimators.
            py_version (str): Python version you want to use for executing your
                model training code. Defaults to ``None``. Required unless
                ``image_uri`` is provided.
            image_uri (str): A Docker image URI (default: None). If not specified, a
                default image for HuggingFace will be used. If ``transformers_version``
                or ``py_version`` are ``None``, then ``image_uri`` is required. If
                also ``None``, then a ``ValueError`` will be raised.
            predictor_cls (callable[str, sagemaker.session.Session]): A function
                to call to create a predictor with an endpoint name and
                SageMaker ``Session``. If specified, ``deploy()`` returns the
                result of invoking this function on the created endpoint name.
            model_server_workers (int): Optional. The number of worker processes
                used by the inference server. If None, server will use one
                worker per vCPU.
            **kwargs: Keyword arguments passed to the superclass
                :class:`~sagemaker.model.FrameworkModel` and, subsequently, its
                superclass :class:`~sagemaker.model.Model`.

        Raises:
            ValueError: If the version arguments are inconsistent (see
                ``validate_version_or_image_args`` and
                ``_validate_pt_tf_versions``), or if ``py_version`` is "py2".

        .. tip::

            You can find additional parameters for initializing this class at
            :class:`~sagemaker.model.FrameworkModel` and
            :class:`~sagemaker.model.Model`.
        """
        validate_version_or_image_args(transformers_version, py_version, image_uri)
        _validate_pt_tf_versions(
            pytorch_version=pytorch_version,
            tensorflow_version=tensorflow_version,
            image_uri=image_uri,
        )
        if py_version == "py2":
            raise ValueError("py2 is not supported with HuggingFace images")
        # The transformers version plays the role of the generic
        # "framework_version" expected by the base classes and image_uris.
        self.framework_version = transformers_version
        self.pytorch_version = pytorch_version
        self.tensorflow_version = tensorflow_version
        self.py_version = py_version

        super(HuggingFaceModel, self).__init__(
            model_data, image_uri, role, entry_point, predictor_cls=predictor_cls, **kwargs
        )

        self.model_server_workers = model_server_workers

    def register(
        self,
        content_types,
        response_types,
        inference_instances,
        transform_instances,
        model_package_name=None,
        model_package_group_name=None,
        image_uri=None,
        model_metrics=None,
        metadata_properties=None,
        marketplace_cert=False,
        approval_status=None,
        description=None,
    ):
        """Creates a model package for creating SageMaker models or listing on Marketplace.

        Args:
            content_types (list): The supported MIME types for the input data.
            response_types (list): The supported MIME types for the output data.
            inference_instances (list): A list of the instance types that are used to
                generate inferences in real-time.
            transform_instances (list): A list of the instance types on which a transformation
                job can be run or on which an endpoint can be deployed.
            model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
                using `model_package_name` makes the Model Package un-versioned (default: None).
            model_package_group_name (str): Model Package Group name, exclusive to
                `model_package_name`, using `model_package_group_name` makes the Model Package
                versioned (default: None).
            image_uri (str): Inference image uri for the container. Model class' self.image will
                be used if it is None (default: None).
            model_metrics (ModelMetrics): ModelMetrics object (default: None).
            metadata_properties (MetadataProperties): MetadataProperties object (default: None).
            marketplace_cert (bool): A boolean value indicating if the Model Package is certified
                for AWS Marketplace (default: False).
            approval_status (str): Model Approval Status, values can be "Approved", "Rejected",
                or "PendingManualApproval" (default: "PendingManualApproval").
            description (str): Model Package description (default: None).

        Returns:
            A `sagemaker.model.ModelPackage` instance.
        """
        # The first inference instance type is used both to lazily create a
        # session and to resolve a default serving image when none is set.
        instance_type = inference_instances[0]
        self._init_sagemaker_session_if_does_not_exist(instance_type)

        if image_uri:
            self.image_uri = image_uri
        if not self.image_uri:
            self.image_uri = self.serving_image_uri(
                region_name=self.sagemaker_session.boto_session.region_name,
                instance_type=instance_type,
            )
        return super(HuggingFaceModel, self).register(
            content_types,
            response_types,
            inference_instances,
            transform_instances,
            model_package_name,
            model_package_group_name,
            image_uri,
            model_metrics,
            metadata_properties,
            marketplace_cert,
            approval_status,
            description,
        )

    def prepare_container_def(self, instance_type=None, accelerator_type=None):
        """A container definition with framework configuration set in model environment variables.

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to.
                For example, 'ml.p2.xlarge'.
            accelerator_type (str): The Elastic Inference accelerator type to
                deploy to the instance for loading and making inferences to the
                model.

        Returns:
            dict[str, str]: A container definition object usable with the
            CreateModel API.

        Raises:
            ValueError: If neither ``image_uri`` nor ``instance_type`` is
                available to choose a serving image.
        """
        deploy_image = self.image_uri
        if not deploy_image:
            if instance_type is None:
                raise ValueError(
                    "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
                )

            region_name = self.sagemaker_session.boto_session.region_name
            deploy_image = self.serving_image_uri(
                region_name, instance_type, accelerator_type=accelerator_type
            )

        deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
        # repack=True bundles the inference entry point with the model data.
        self._upload_code(deploy_key_prefix, repack=True)
        deploy_env = dict(self.env)
        deploy_env.update(self._framework_env_vars())

        if self.model_server_workers:
            deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
        return sagemaker.container_def(
            deploy_image, self.repacked_model_data or self.model_data, deploy_env
        )

    def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
        """Create a URI for the serving image.

        Args:
            region_name (str): AWS region where the image is uploaded.
            instance_type (str): SageMaker instance type. Used to determine device type
                (cpu/gpu/family-specific optimized).
            accelerator_type (str): The Elastic Inference accelerator type to
                deploy to the instance for loading and making inferences to the
                model.

        Returns:
            str: The appropriate image URI based on the given parameters.

        """
        # Exactly one of tensorflow_version/pytorch_version is set (enforced
        # in __init__), so this picks the matching base framework tag.
        if self.tensorflow_version is not None:  # pylint: disable=no-member
            base_framework_version = (
                f"tensorflow{self.tensorflow_version}"  # pylint: disable=no-member
            )
        else:
            base_framework_version = (
                f"pytorch{self.pytorch_version}"  # pylint: disable=no-member
            )
        return image_uris.retrieve(
            self._framework_name,
            region_name,
            version=self.framework_version,
            py_version=self.py_version,
            instance_type=instance_type,
            accelerator_type=accelerator_type,
            image_scope="inference",
            base_framework_version=base_framework_version,
        )
0 commit comments