Skip to content

Commit 00c3a6c

Browse files
authored
fix: estimator.deploy not respecting instance type (#4724)
* fix: estimator.deploy not respecting instance type
* chore: add inline comment about using user supplied instance type
1 parent 91ebd74 commit 00c3a6c

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -322,7 +322,12 @@ def get_deploy_kwargs(
322322
model_id=model_id,
323323
model_from_estimator=True,
324324
model_version=model_version,
325-
instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None,
325+
instance_type=(
326+
model_deploy_kwargs.instance_type
327+
if training_instance_type is None
328+
or instance_type is not None # always use supplied inference instance type
329+
else None
330+
),
326331
region=region,
327332
image_uri=image_uri,
328333
source_dir=source_dir,

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1532,6 +1532,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta
15321532
estimator.deploy(image_uri="blah")
15331533
assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.p4de.24xlarge"
15341534

1535+
estimator.deploy(image_uri="blah", instance_type="ml.quantum.large")
1536+
assert mock_estimator_deploy.call_args[1]["instance_type"] == "ml.quantum.large"
1537+
15351538
@mock.patch("sagemaker.utils.sagemaker_timestamp")
15361539
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
15371540
@mock.patch(

0 commit comments

Comments
 (0)