Skip to content

Commit b73a795

Browse files
committed
fix: register jumpstart models on model registry
1 parent 58bb448 commit b73a795

File tree

5 files changed

+24
-21
lines changed

5 files changed

+24
-21
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,6 @@ def register(
761761
source_uri: Optional[Union[str, PipelineVariable]] = None,
762762
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
763763
accept_eula: Optional[bool] = None,
764-
765764
):
766765
"""Creates a model package for creating SageMaker models or listing on Marketplace.
767766

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2410,7 +2410,6 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
24102410
"model_version",
24112411
"hub_arn",
24122412
"sagemaker_session",
2413-
"model_type",
24142413
}
24152414

24162415
def __init__(

src/sagemaker/model.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ def register(
450450
source_uri: Optional[Union[str, PipelineVariable]] = None,
451451
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
452452
accept_eula: Optional[bool] = None,
453+
model_type: Optional[JumpStartModelType] = None,
453454
):
454455
"""Creates a model package for creating SageMaker models or listing on Marketplace.
455456
@@ -517,7 +518,10 @@ def register(
517518

518519
if image_uri is not None:
519520
self.image_uri = image_uri
520-
if self.model_type is not JumpStartModelType.PROPRIETARY:
521+
if model_type is JumpStartModelType.PROPRIETARY:
522+
source_uri = self.model_package_arn
523+
model_package_group_name = self.model_id
524+
else:
521525
if model_package_group_name is None and model_package_name is None:
522526
# If model package group and model package name is not set
523527
# then register to auto-generated model package group
@@ -533,23 +537,20 @@ def register(
533537
data_input_configuration=data_input_configuration,
534538
container_def=container_def,
535539
)
536-
else:
537-
container_def = {
538-
"Image": self.image_uri,
539-
}
540+
else:
541+
container_def = {
542+
"Image": self.image_uri,
543+
}
540544

541-
if isinstance(self.model_data, dict):
542-
raise ValueError(
543-
"Un-versioned SageMaker Model Package currently cannot be "
544-
"created with ModelDataSource."
545-
)
545+
if isinstance(self.model_data, dict):
546+
raise ValueError(
547+
"Un-versioned SageMaker Model Package currently cannot be "
548+
"created with ModelDataSource."
549+
)
546550

547-
if self.model_data is not None:
548-
container_def["ModelDataUrl"] = self.model_data
551+
if self.model_data is not None:
552+
container_def["ModelDataUrl"] = self.model_data
549553

550-
if self.model_type is JumpStartModelType.PROPRIETARY:
551-
source_uri = self.model_package_arn
552-
model_package_group_name = self.model_id
553554

554555
model_pkg_args = sagemaker.get_model_package_args(
555556
self.content_types,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def test_proprietary_jumpstart_model(setup):
294294

295295
assert response is not None
296296

297+
297298
@pytest.mark.skipif(
298299
True,
299300
reason="Only enable if test account is subscribed to the proprietary model",
@@ -309,7 +310,6 @@ def test_register_proprietary_jumpstart_model(setup):
309310
sagemaker_session=get_sm_session(),
310311
)
311312
model_package = model.register()
312-
313313

314314
predictor = model_package.deploy(
315315
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
@@ -329,7 +329,7 @@ def test_register_proprietary_jumpstart_model(setup):
329329
)
330330
def test_register_gated_jumpstart_model(setup):
331331

332-
model_id="meta-textgenerationneuron-llama-2-7b"
332+
model_id = "meta-textgenerationneuron-llama-2-7b"
333333
model = JumpStartModel(
334334
model_id=model_id,
335335
model_version="1.1.0",
@@ -339,7 +339,8 @@ def test_register_gated_jumpstart_model(setup):
339339
model_package = model.register(accept_eula=True)
340340

341341
predictor = model_package.deploy(
342-
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True
342+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
343+
accept_eula=True,
343344
)
344345
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
345346

@@ -348,4 +349,3 @@ def test_register_gated_jumpstart_model(setup):
348349
predictor.delete_predictor()
349350

350351
assert response is not None
351-

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,8 +513,11 @@ def test_proprietary_model_endpoint(
513513
model.deploy()
514514

515515
mock_model_register.assert_called_once_with(
516+
model_type=JumpStartModelType.PROPRIETARY,
516517
content_types=["application/json"],
517518
response_types=["application/json"],
519+
model_package_group_name=model_id,
520+
source_uri=model.model_package_arn
518521
)
519522

520523
mock_model_deploy.assert_called_once_with(
@@ -1416,6 +1419,7 @@ def test_model_registry_accept_and_response_types(
14161419
model.register()
14171420

14181421
mock_model_register.assert_called_once_with(
1422+
model_type=JumpStartModelType.OPEN_WEIGHTS,
14191423
content_types=["application/x-text"],
14201424
response_types=["application/json;verbose", "application/json"],
14211425
)

0 commit comments

Comments
 (0)