Skip to content

Commit f0eff66

Browse files
author
Rui Wang Napieralski
committed
reformat
1 parent ca63a14 commit f0eff66

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

src/sagemaker/huggingface/model.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ class HuggingFacePredictor(Predictor):
3737
"""
3838

3939
def __init__(
40-
self,
41-
endpoint_name,
42-
sagemaker_session=None,
43-
serializer=NumpySerializer(),
44-
deserializer=NumpyDeserializer(),
40+
self,
41+
endpoint_name,
42+
sagemaker_session=None,
43+
serializer=NumpySerializer(),
44+
deserializer=NumpyDeserializer(),
4545
):
4646
"""Initialize an ``HuggingFacePredictor``.
4747
@@ -89,18 +89,18 @@ class HuggingFaceModel(FrameworkModel):
8989
_framework_name = "huggingface"
9090

9191
def __init__(
92-
self,
93-
model_data,
94-
role,
95-
entry_point,
96-
transformers_version=None,
97-
tensorflow_version=None,
98-
pytorch_version=None,
99-
py_version=None,
100-
image_uri=None,
101-
predictor_cls=HuggingFacePredictor,
102-
model_server_workers=None,
103-
**kwargs
92+
self,
93+
model_data,
94+
role,
95+
entry_point,
96+
transformers_version=None,
97+
tensorflow_version=None,
98+
pytorch_version=None,
99+
py_version=None,
100+
image_uri=None,
101+
predictor_cls=HuggingFacePredictor,
102+
model_server_workers=None,
103+
**kwargs,
104104
):
105105
"""Initialize a PyTorchModel.
106106
@@ -152,7 +152,11 @@ def __init__(
152152
:class:`~sagemaker.model.Model`.
153153
"""
154154
validate_version_or_image_args(transformers_version, py_version, image_uri)
155-
_validate_pt_tf_versions(pytorch_version=pytorch_version,tensorflow_version=tensorflow_version,image_uri=image_uri)
155+
_validate_pt_tf_versions(
156+
pytorch_version=pytorch_version,
157+
tensorflow_version=tensorflow_version,
158+
image_uri=image_uri,
159+
)
156160
if py_version == "py2":
157161
raise ValueError("py2 is not supported with HuggingFace images")
158162
self.framework_version = transformers_version
@@ -167,19 +171,19 @@ def __init__(
167171
self.model_server_workers = model_server_workers
168172

169173
def register(
170-
self,
171-
content_types,
172-
response_types,
173-
inference_instances,
174-
transform_instances,
175-
model_package_name=None,
176-
model_package_group_name=None,
177-
image_uri=None,
178-
model_metrics=None,
179-
metadata_properties=None,
180-
marketplace_cert=False,
181-
approval_status=None,
182-
description=None,
174+
self,
175+
content_types,
176+
response_types,
177+
inference_instances,
178+
transform_instances,
179+
model_package_name=None,
180+
model_package_group_name=None,
181+
image_uri=None,
182+
model_metrics=None,
183+
metadata_properties=None,
184+
marketplace_cert=False,
185+
approval_status=None,
186+
description=None,
183187
):
184188
"""Creates a model package for creating SageMaker models or listing on Marketplace.
185189
@@ -290,9 +294,7 @@ def serving_image_uri(self, region_name, instance_type, accelerator_type=None):
290294
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
291295
)
292296
else:
293-
base_framework_version = (
294-
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
295-
)
297+
base_framework_version = f"pytorch{self.pytorch_version}" # pylint: disable=no-member
296298
return image_uris.retrieve(
297299
self._framework_name,
298300
region_name,

0 commit comments

Comments
 (0)