Skip to content

Commit d18e41b

Browse files
malav-shastrimalavhs
and
malavhs
authored
tests: Implement integration tests covering JumpStart PrivateHub workflows (#4883)
* tests: Implement integration tests covering JumpStart PrivateHub workflows * linting * formating * Only delete the pytest session specific test * change scope to session * address nits * Address test failures * address typo * address comments * resolve flake8 errors * implement throttle handling * flake8 * flake8 * Adding more assertions --------- Co-authored-by: malavhs <[email protected]>
1 parent f7c6cd3 commit d18e41b

File tree

11 files changed

+362
-6
lines changed

11 files changed

+362
-6
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,6 @@
222222

223223
JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub"
224224

225-
JUMPSTART_MODEL_HUB_NAME = "SageMakerPublicHub"
226-
227225
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
228226
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"
229227

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ def retrieve_example_payload(self) -> JumpStartSerializablePayload:
447447
return payloads.retrieve_example(
448448
model_id=self.model_id,
449449
model_version=self.model_version,
450+
hub_arn=self.hub_arn,
450451
model_type=self.model_type,
451452
region=self.region,
452453
tolerate_deprecated_model=self.tolerate_deprecated_model,
@@ -1036,13 +1037,15 @@ def _get_deployment_configs(
10361037
image_uri=image_uri,
10371038
region=self.region,
10381039
model_version=self.model_version,
1040+
hub_arn=self.hub_arn,
10391041
)
10401042
deploy_kwargs = get_deploy_kwargs(
10411043
model_id=self.model_id,
10421044
instance_type=instance_type_to_use,
10431045
sagemaker_session=self.sagemaker_session,
10441046
region=self.region,
10451047
model_version=self.model_version,
1048+
hub_arn=self.hub_arn,
10461049
)
10471050

10481051
deployment_config_metadata = DeploymentConfigMetadata(

tests/integ/sagemaker/jumpstart/conftest.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,43 @@
1616
import boto3
1717
import pytest
1818
from botocore.config import Config
19+
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
20+
from sagemaker.jumpstart.hub.hub import Hub
1921
from sagemaker.session import Session
2022
from tests.integ.sagemaker.jumpstart.constants import (
2123
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
24+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
25+
HUB_NAME_PREFIX,
2226
JUMPSTART_TAG,
2327
)
2428

29+
from sagemaker.jumpstart.types import (
30+
HubContentType,
31+
)
32+
2533

2634
from tests.integ.sagemaker.jumpstart.utils import (
2735
get_test_artifact_bucket,
2836
get_test_suite_id,
37+
get_sm_session,
38+
with_exponential_backoff,
2939
)
3040

31-
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
32-
3341

3442
def _setup():
3543
print("Setting up...")
36-
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: get_test_suite_id()})
44+
test_suite_id = get_test_suite_id()
45+
test_hub_name = f"{HUB_NAME_PREFIX}{test_suite_id}"
46+
test_hub_description = "PySDK Integ Test Private Hub"
47+
48+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID: test_suite_id})
49+
os.environ.update({ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME: test_hub_name})
50+
51+
# Create a private hub to use for the test session
52+
hub = Hub(
53+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
54+
)
55+
hub.create(description=test_hub_description)
3756

3857

3958
def _teardown():
@@ -43,6 +62,8 @@ def _teardown():
4362

4463
test_suite_id = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]
4564

65+
test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
66+
4667
boto3_session = boto3.Session(region_name=JUMPSTART_DEFAULT_REGION_NAME)
4768

4869
sagemaker_client = boto3_session.client(
@@ -113,6 +134,29 @@ def _teardown():
113134
bucket = s3_resource.Bucket(test_cache_bucket)
114135
bucket.objects.filter(Prefix=test_suite_id + "/").delete()
115136

137+
# delete private hubs
138+
_delete_hubs(sagemaker_session, test_hub_name)
139+
140+
141+
def _delete_hubs(sagemaker_session, hub_name):
    """Remove every model reference from ``hub_name`` and then delete the hub.

    Args:
        sagemaker_session: Session object used to issue the hub API calls.
        hub_name: Name of the private hub to empty and delete.
    """
    # Model references are listed and removed one by one before the hub
    # itself is deleted.
    reference_type = HubContentType.MODEL_REFERENCE.value
    listing = sagemaker_session.list_hub_contents(
        hub_name=hub_name, hub_content_type=reference_type
    )
    for content_summary in listing["HubContentSummaries"]:
        _delete_hub_contents(sagemaker_session, hub_name, content_summary)

    sagemaker_session.delete_hub(hub_name)
150+
151+
152+
@with_exponential_backoff()
def _delete_hub_contents(sagemaker_session, hub_name, model):
    """Delete a single model reference from ``hub_name``.

    Args:
        sagemaker_session: Session object used to issue the delete call.
        hub_name: Name of the hub that holds the reference.
        model: Hub content summary mapping; only ``"HubContentName"`` is read.
    """
    reference_type = HubContentType.MODEL_REFERENCE.value
    reference_name = model["HubContentName"]
    sagemaker_session.delete_hub_content_reference(
        hub_name=hub_name,
        hub_content_type=reference_type,
        hub_content_name=reference_name,
    )
159+
116160

117161
@pytest.fixture(scope="session", autouse=True)
118162
def setup(request):

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,11 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
3737

3838
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID = "JUMPSTART_SDK_TEST_SUITE_ID"
3939

40+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME = "JUMPSTART_SDK_TEST_HUB_NAME"
41+
4042
JUMPSTART_TAG = "JumpStart-SDK-Integ-Test-Suite-Id"
4143

44+
HUB_NAME_PREFIX = "PySDK-HubTest-"
4245

4346
TRAINING_DATASET_MODEL_DICT = {
4447
("huggingface-spc-bert-base-cased", "1.0.0"): ("training-datasets/QNLI-tiny/"),

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_jumpstart_gated_model_inference_component_enabled(setup):
229229

230230

231231
@mock.patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning")
232-
def test_instatiating_model(mock_warning_logger, setup):
232+
def test_instantiating_model(mock_warning_logger, setup):
233233

234234
model_id = "catboost-regression-model"
235235

tests/integ/sagemaker/jumpstart/private_hub/__init__.py

Whitespace-only changes.

tests/integ/sagemaker/jumpstart/private_hub/model/__init__.py

Whitespace-only changes.
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import time
17+
18+
import pytest
19+
from sagemaker.enums import EndpointType
20+
from sagemaker.jumpstart.hub.hub import Hub
21+
from sagemaker.jumpstart.hub.utils import generate_hub_arn_for_init_kwargs
22+
from sagemaker.predictor import retrieve_default
23+
24+
import tests.integ
25+
26+
from sagemaker.jumpstart.model import JumpStartModel
27+
from tests.integ.sagemaker.jumpstart.constants import (
28+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
29+
ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
30+
JUMPSTART_TAG,
31+
)
32+
from tests.integ.sagemaker.jumpstart.utils import (
33+
get_public_hub_model_arn,
34+
get_sm_session,
35+
with_exponential_backoff,
36+
)
37+
38+
# Upper bound, in seconds, that test_instantiating_model allows for its timed section.
MAX_INIT_TIME_SECONDS = 5

# Model IDs that the add_model_references fixture adds to the test hub as references.
TEST_MODEL_IDS = {
    "catboost-classification-model",
    "catboost-regression-model",
    "huggingface-txt2img-conflictx-complex-lineart",
    "meta-textgeneration-llama-2-7b",
    "meta-textgeneration-llama-3-2-1b",
}
47+
48+
49+
@with_exponential_backoff()
def create_model_reference(hub_instance, model_arn):
    """Add a reference to ``model_arn`` in the hub represented by ``hub_instance``.

    The call is wrapped by the ``with_exponential_backoff`` test helper.
    """
    hub_instance.create_model_reference(model_arn=model_arn)
52+
53+
54+
@pytest.fixture(scope="session")
def add_model_references():
    """Session-scoped fixture that references every TEST_MODEL_IDS entry in the test hub.

    The hub name is read from the ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME
    environment variable.
    """
    test_hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]
    private_hub = Hub(hub_name=test_hub_name, sagemaker_session=get_sm_session())

    for model_id in TEST_MODEL_IDS:
        public_model_arn = get_public_hub_model_arn(private_hub, model_id)
        create_model_reference(private_hub, public_model_arn)
63+
64+
65+
def test_jumpstart_hub_model(setup, add_model_references):
    """Deploy a model through the private test hub and check the endpoint is in service."""
    session = get_sm_session()
    suite_tag = {"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}

    hub_model = JumpStartModel(
        model_id="catboost-classification-model",
        role=session.get_caller_identity_arn(),
        sagemaker_session=session,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    # The endpoint is tagged with the test suite ID when it is deployed.
    deployed = hub_model.deploy(tags=[suite_tag])

    assert session.endpoint_in_service_or_not(deployed.endpoint_name)
83+
84+
85+
def test_jumpstart_hub_gated_model(setup, add_model_references):
    """Deploy a gated model from the private test hub and run one prediction.

    The EULA is accepted at deploy time, the endpoint is tagged with the test
    suite ID, and the model's example payload is sent to the endpoint.
    """
    model_id = "meta-textgeneration-llama-3-2-1b"

    # Bind the session once and reuse it for both the role lookup and the
    # model, instead of constructing a second session for the same purpose.
    sagemaker_session = get_sm_session()

    model = JumpStartModel(
        model_id=model_id,
        role=sagemaker_session.get_caller_identity_arn(),
        sagemaker_session=sagemaker_session,
        hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME],
    )

    predictor = model.deploy(
        accept_eula=True,
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
    )

    payload = model.retrieve_example_payload()

    response = predictor.predict(payload)

    assert response is not None
106+
107+
108+
def test_jumpstart_gated_model_inference_component_enabled(setup, add_model_references):
    """Deploy a gated hub model on an inference-component endpoint and re-attach to it.

    Steps: deploy with ``EndpointType.INFERENCE_COMPONENT_BASED``, retrieve a
    default predictor for the endpoint using the hub ARN, send the example
    payload, then rebuild the model with ``JumpStartModel.attach`` and compare
    its identifiers against the predictor's.
    """
    model_id = "meta-textgeneration-llama-2-7b"

    hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

    region = tests.integ.test_region()

    sagemaker_session = get_sm_session()

    hub_arn = generate_hub_arn_for_init_kwargs(
        hub_name=hub_name, region=region, session=sagemaker_session
    )

    # Reuse the session and hub name bound above rather than creating another
    # session and re-reading the environment variable.
    model = JumpStartModel(
        model_id=model_id,
        role=sagemaker_session.get_caller_identity_arn(),
        sagemaker_session=sagemaker_session,
        hub_name=hub_name,
    )

    model.deploy(
        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
        accept_eula=True,
        endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
    )

    predictor = retrieve_default(
        endpoint_name=model.endpoint_name,
        sagemaker_session=sagemaker_session,
        tolerate_vulnerable_model=True,
        hub_arn=hub_arn,
    )

    payload = model.retrieve_example_payload()

    response = predictor.predict(payload)

    assert response is not None

    # Rebuild the model from the live endpoint and verify it matches.
    model = JumpStartModel.attach(
        predictor.endpoint_name, sagemaker_session=sagemaker_session, hub_name=hub_name
    )
    assert model.model_id == model_id
    assert model.endpoint_name == predictor.endpoint_name
    assert model.inference_component_name == predictor.component_name
154+
155+
156+
def test_instantiating_model(setup, add_model_references):
    """Check that building a hub-backed JumpStartModel takes at most MAX_INIT_TIME_SECONDS."""
    model_id = "catboost-regression-model"

    # Resolve the session, role and hub name before starting the clock so the
    # measurement covers only the JumpStartModel constructor, not session
    # creation or the caller-identity lookup.
    sagemaker_session = get_sm_session()
    role = sagemaker_session.get_caller_identity_arn()
    hub_name = os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME]

    start_time = time.perf_counter()

    JumpStartModel(
        model_id=model_id,
        role=role,
        sagemaker_session=sagemaker_session,
        hub_name=hub_name,
    )

    elapsed_time = time.perf_counter() - start_time

    assert elapsed_time <= MAX_INIT_TIME_SECONDS
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from sagemaker.jumpstart.hub.hub import Hub
16+
17+
from tests.integ.sagemaker.jumpstart.utils import (
18+
get_sm_session,
19+
)
20+
from tests.integ.sagemaker.jumpstart.utils import (
21+
get_test_suite_id,
22+
)
23+
from tests.integ.sagemaker.jumpstart.constants import (
24+
HUB_NAME_PREFIX,
25+
)
26+
27+
28+
@pytest.fixture
def hub_instance():
    """Yield a ``Hub`` handle named from HUB_NAME_PREFIX and the test suite ID.

    The fixture only builds the handle; it performs no teardown after the yield.
    """
    # NOTE(review): HUB_NAME_PREFIX already ends with "-", so this name contains
    # a double hyphen -- confirm that is intended.
    suite_id = get_test_suite_id()
    private_hub_name = f"{HUB_NAME_PREFIX}-{suite_id}"
    yield Hub(private_hub_name, sagemaker_session=get_sm_session())
33+
34+
35+
def test_private_hub(setup, hub_instance):
    """Exercise the create / describe / delete lifecycle of a private hub.

    Once the hub has been created, deletion runs in a ``finally`` clause so a
    failing create or describe check does not leave the hub behind.
    """
    # Create Hub
    create_hub_response = hub_instance.create(
        description="This is a Test Private Hub.",
        display_name="PySDK integration tests Hub",
        search_keywords=["jumpstart-sdk-integ-test"],
    )

    try:
        # Create Hub Verifications
        assert create_hub_response is not None

        # Describe Hub
        hub_description = hub_instance.describe()
        assert hub_description is not None
    finally:
        # Delete Hub (always attempted once creation has succeeded)
        delete_hub_response = hub_instance.delete()

    assert delete_hub_response is not None

0 commit comments

Comments
 (0)