@@ -37,11 +37,11 @@ class HuggingFacePredictor(Predictor):
37
37
"""
38
38
39
39
def __init__ (
40
- self ,
41
- endpoint_name ,
42
- sagemaker_session = None ,
43
- serializer = NumpySerializer (),
44
- deserializer = NumpyDeserializer (),
40
+ self ,
41
+ endpoint_name ,
42
+ sagemaker_session = None ,
43
+ serializer = NumpySerializer (),
44
+ deserializer = NumpyDeserializer (),
45
45
):
46
46
"""Initialize an ``HuggingFacePredictor``.
47
47
@@ -89,18 +89,18 @@ class HuggingFaceModel(FrameworkModel):
89
89
_framework_name = "huggingface"
90
90
91
91
def __init__ (
92
- self ,
93
- model_data ,
94
- role ,
95
- entry_point ,
96
- transformers_version = None ,
97
- tensorflow_version = None ,
98
- pytorch_version = None ,
99
- py_version = None ,
100
- image_uri = None ,
101
- predictor_cls = HuggingFacePredictor ,
102
- model_server_workers = None ,
103
- ** kwargs
92
+ self ,
93
+ model_data ,
94
+ role ,
95
+ entry_point ,
96
+ transformers_version = None ,
97
+ tensorflow_version = None ,
98
+ pytorch_version = None ,
99
+ py_version = None ,
100
+ image_uri = None ,
101
+ predictor_cls = HuggingFacePredictor ,
102
+ model_server_workers = None ,
103
+ ** kwargs ,
104
104
):
105
105
"""Initialize a PyTorchModel.
106
106
@@ -152,7 +152,11 @@ def __init__(
152
152
:class:`~sagemaker.model.Model`.
153
153
"""
154
154
validate_version_or_image_args (transformers_version , py_version , image_uri )
155
- _validate_pt_tf_versions (pytorch_version = pytorch_version ,tensorflow_version = tensorflow_version ,image_uri = image_uri )
155
+ _validate_pt_tf_versions (
156
+ pytorch_version = pytorch_version ,
157
+ tensorflow_version = tensorflow_version ,
158
+ image_uri = image_uri ,
159
+ )
156
160
if py_version == "py2" :
157
161
raise ValueError ("py2 is not supported with HuggingFace images" )
158
162
self .framework_version = transformers_version
@@ -167,19 +171,19 @@ def __init__(
167
171
self .model_server_workers = model_server_workers
168
172
169
173
def register (
170
- self ,
171
- content_types ,
172
- response_types ,
173
- inference_instances ,
174
- transform_instances ,
175
- model_package_name = None ,
176
- model_package_group_name = None ,
177
- image_uri = None ,
178
- model_metrics = None ,
179
- metadata_properties = None ,
180
- marketplace_cert = False ,
181
- approval_status = None ,
182
- description = None ,
174
+ self ,
175
+ content_types ,
176
+ response_types ,
177
+ inference_instances ,
178
+ transform_instances ,
179
+ model_package_name = None ,
180
+ model_package_group_name = None ,
181
+ image_uri = None ,
182
+ model_metrics = None ,
183
+ metadata_properties = None ,
184
+ marketplace_cert = False ,
185
+ approval_status = None ,
186
+ description = None ,
183
187
):
184
188
"""Creates a model package for creating SageMaker models or listing on Marketplace.
185
189
@@ -290,9 +294,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
290
294
f"tensorflow{ self .tensorflow_version } " # pylint: disable=no-member
291
295
)
292
296
else :
293
- base_framework_version = (
294
- f"pytorch{ self .pytorch_version } " # pylint: disable=no-member
295
- )
297
+ base_framework_version = f"pytorch{ self .pytorch_version } " # pylint: disable=no-member
296
298
return image_uris .retrieve (
297
299
self ._framework_name ,
298
300
region_name ,
0 commit comments