Skip to content

Commit 0965e8d

Browse files
committed
fix: register jumpstart models on model registry
1 parent 58bb448 commit 0965e8d

File tree

4 files changed

+7
-8
lines changed

4 files changed

+7
-8
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,6 @@ def register(
761761
source_uri: Optional[Union[str, PipelineVariable]] = None,
762762
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
763763
accept_eula: Optional[bool] = None,
764-
765764
):
766765
"""Creates a model package for creating SageMaker models or listing on Marketplace.
767766

src/sagemaker/jumpstart/types.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2410,7 +2410,6 @@ class JumpStartModelRegisterKwargs(JumpStartKwargs):
24102410
"model_version",
24112411
"hub_arn",
24122412
"sagemaker_session",
2413-
"model_type",
24142413
}
24152414

24162415
def __init__(

src/sagemaker/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ def register(
450450
source_uri: Optional[Union[str, PipelineVariable]] = None,
451451
model_card: Optional[Union[ModelPackageModelCard, ModelCard]] = None,
452452
accept_eula: Optional[bool] = None,
453+
model_type: Optional[JumpStartModelType] = None,
453454
):
454455
"""Creates a model package for creating SageMaker models or listing on Marketplace.
455456
@@ -517,7 +518,7 @@ def register(
517518

518519
if image_uri is not None:
519520
self.image_uri = image_uri
520-
if self.model_type is not JumpStartModelType.PROPRIETARY:
521+
if model_type is not JumpStartModelType.PROPRIETARY:
521522
if model_package_group_name is None and model_package_name is None:
522523
# If model package group and model package name is not set
523524
# then register to auto-generated model package group
@@ -547,7 +548,7 @@ def register(
547548
if self.model_data is not None:
548549
container_def["ModelDataUrl"] = self.model_data
549550

550-
if self.model_type is JumpStartModelType.PROPRIETARY:
551+
if model_type is JumpStartModelType.PROPRIETARY:
551552
source_uri = self.model_package_arn
552553
model_package_group_name = self.model_id
553554

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def test_proprietary_jumpstart_model(setup):
294294

295295
assert response is not None
296296

297+
297298
@pytest.mark.skipif(
298299
True,
299300
reason="Only enable if test account is subscribed to the proprietary model",
@@ -309,7 +310,6 @@ def test_register_proprietary_jumpstart_model(setup):
309310
sagemaker_session=get_sm_session(),
310311
)
311312
model_package = model.register()
312-
313313

314314
predictor = model_package.deploy(
315315
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]
@@ -329,7 +329,7 @@ def test_register_proprietary_jumpstart_model(setup):
329329
)
330330
def test_register_gated_jumpstart_model(setup):
331331

332-
model_id="meta-textgenerationneuron-llama-2-7b"
332+
model_id = "meta-textgenerationneuron-llama-2-7b"
333333
model = JumpStartModel(
334334
model_id=model_id,
335335
model_version="1.1.0",
@@ -339,7 +339,8 @@ def test_register_gated_jumpstart_model(setup):
339339
model_package = model.register(accept_eula=True)
340340

341341
predictor = model_package.deploy(
342-
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], accept_eula=True
342+
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
343+
accept_eula=True,
343344
)
344345
payload = {"prompt": "To be, or", "maxTokens": 4, "temperature": 0, "numResults": 1}
345346

@@ -348,4 +349,3 @@ def test_register_gated_jumpstart_model(setup):
348349
predictor.delete_predictor()
349350

350351
assert response is not None
351-

0 commit comments

Comments
 (0)