Skip to content

Commit 90c9fbe

Browse files
committed
Fix: Add Image URI overrides for transformers models
1 parent c9b55a4 commit 90c9fbe

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,25 @@ def _prepare_for_mode(self):
7878
"""Abstract method"""
7979

8080
def _create_transformers_model(self) -> Type[Model]:
81+
"""Initializes HF model with or without image_uri"""
82+
if self.image_uri is None:
83+
pysdk_model = self._get_hf_metadata_create_model()
84+
else:
85+
pysdk_model = HuggingFaceModel(
86+
image_uri=self.image_uri,
87+
vpc_config=self.vpc_config,
88+
env=self.env_vars,
89+
role=self.role_arn,
90+
sagemaker_session=self.sagemaker_session,
91+
)
92+
93+
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
94+
95+
self._original_deploy = pysdk_model.deploy
96+
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
97+
return pysdk_model
98+
99+
def _get_hf_metadata_create_model(self) -> Type[Model]:
81100
"""Initializes the model after fetching image
82101
83102
1. Get the metadata for deciding framework
@@ -132,22 +151,21 @@ def _create_transformers_model(self) -> Type[Model]:
132151
vpc_config=self.vpc_config,
133152
)
134153

135-
if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
154+
if self.mode == Mode.LOCAL_CONTAINER:
136155
self.image_uri = pysdk_model.serving_image_uri(
137156
self.sagemaker_session.boto_region_name, "local"
138157
)
139-
elif not self.image_uri:
158+
else:
140159
self.image_uri = pysdk_model.serving_image_uri(
141160
self.sagemaker_session.boto_region_name, self.instance_type
142161
)
143162

144-
logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
163+
if pysdk_model is None or self.image_uri is None:
164+
raise ValueError("PySDK model unable to be created, try overriding image_uri")
145165

146166
if not pysdk_model.image_uri:
147167
pysdk_model.image_uri = self.image_uri
148168

149-
self._original_deploy = pysdk_model.deploy
150-
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
151169
return pysdk_model
152170

153171
@_capture_telemetry("transformers.deploy")

0 commit comments

Comments
 (0)