File tree 2 files changed +9
-1
lines changed
src/sagemaker/jumpstart/factory
tests/unit/sagemaker/jumpstart/estimator
2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change @@ -322,7 +322,12 @@ def get_deploy_kwargs(
322
322
model_id = model_id ,
323
323
model_from_estimator = True ,
324
324
model_version = model_version ,
325
- instance_type = model_deploy_kwargs .instance_type if training_instance_type is None else None ,
325
+ instance_type = (
326
+ model_deploy_kwargs .instance_type
327
+ if training_instance_type is None
328
+ or instance_type is not None # always use supplied inference instance type
329
+ else None
330
+ ),
326
331
region = region ,
327
332
image_uri = image_uri ,
328
333
source_dir = source_dir ,
Original file line number Diff line number Diff line change @@ -1532,6 +1532,9 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta
1532
1532
estimator .deploy (image_uri = "blah" )
1533
1533
assert mock_estimator_deploy .call_args [1 ]["instance_type" ] == "ml.p4de.24xlarge"
1534
1534
1535
+ estimator .deploy (image_uri = "blah" , instance_type = "ml.quantum.large" )
1536
+ assert mock_estimator_deploy .call_args [1 ]["instance_type" ] == "ml.quantum.large"
1537
+
1535
1538
@mock .patch ("sagemaker.utils.sagemaker_timestamp" )
1536
1539
@mock .patch ("sagemaker.jumpstart.estimator.validate_model_id_and_get_type" )
1537
1540
@mock .patch (
You can’t perform that action at this time.
0 commit comments