Skip to content

Commit 7f5ad9d

Browse files
authored
fix: include model channel for gated uncompressed models (#5181)
1 parent ddc54d2 commit 7f5ad9d

File tree

7 files changed

+328
-20
lines changed

7 files changed

+328
-20
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,18 @@ def _get_json_file(
372372
object and None when reading from the local file system.
373373
"""
374374
if self._is_local_metadata_mode():
375-
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
376-
else:
377-
file_content, etag = self._get_json_file_and_etag_from_s3(key)
378-
return file_content, etag
375+
if filetype in {
376+
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
377+
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
378+
}:
379+
return self._get_json_file_from_local_override(key, filetype), None
380+
else:
381+
JUMPSTART_LOGGER.warning(
382+
"Local metadata mode is enabled, but the file type %s is not supported "
383+
"for local override. Falling back to s3.",
384+
filetype,
385+
)
386+
return self._get_json_file_and_etag_from_s3(key)
379387

380388
def _get_json_md5_hash(self, key: str):
381389
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@
5454
from sagemaker.jumpstart.constants import (
5555
JUMPSTART_DEFAULT_REGION_NAME,
5656
JUMPSTART_LOGGER,
57+
JUMPSTART_MODEL_HUB_NAME,
5758
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5859
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
59-
JUMPSTART_MODEL_HUB_NAME,
6060
)
6161
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
6262
from sagemaker.jumpstart.factory import model
@@ -634,10 +634,10 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
634634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
635635
# hub_arn is by default None unless the user specifies the hub_name
636636
# If no hub_name is specified, it is assumed the public hub
637+
# Training platform enforces that private hub models must use model channel
637638
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
638-
if (
639-
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
640-
or is_private_hub
639+
if is_private_hub or _model_supports_training_model_uri(
640+
**get_model_info_default_kwargs(kwargs)
641641
):
642642
default_model_uri = model_uris.retrieve(
643643
model_scope=JumpStartScriptScope.TRAINING,

src/sagemaker/jumpstart/types.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:
19401940

19411941
def use_training_model_artifact(self) -> bool:
19421942
"""Returns True if the model should use a model uri when kicking off training job."""
1943-
# gated model never use training model artifact
1944-
if self.gated_bucket:
1943+
# old models with this environment variable present don't use model channel
1944+
if any(
1945+
self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value(
1946+
instance_type
1947+
)
1948+
for instance_type in self.supported_training_instance_types
1949+
):
1950+
return False
1951+
1952+
# even older models with training model package artifact uris present also don't use model channel
1953+
if len(self.training_model_package_artifact_uris or {}) > 0:
19451954
return False
19461955

1947-
# otherwise, return true is a training model package is not set
1948-
return len(self.training_model_package_artifact_uris or {}) == 0
1956+
return getattr(self, "training_artifact_key", None) is not None
19491957

19501958
def is_gated_model(self) -> bool:
19511959
"""Returns True if the model has a EULA key or the model bucket is gated."""

tests/unit/sagemaker/jumpstart/factory/__init__.py

Whitespace-only changes.
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from unittest.mock import patch
16+
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
17+
from sagemaker.jumpstart.factory.estimator import (
18+
_add_model_uri_to_kwargs,
19+
get_model_info_default_kwargs,
20+
)
21+
from sagemaker.jumpstart.types import JumpStartEstimatorInitKwargs
22+
from sagemaker.jumpstart.enums import JumpStartScriptScope
23+
24+
25+
class TestAddModelUriToKwargs:
26+
@pytest.fixture
27+
def mock_kwargs(self):
28+
return JumpStartEstimatorInitKwargs(
29+
model_id="test-model",
30+
model_version="1.0.0",
31+
instance_type="ml.m5.large",
32+
model_uri=None,
33+
)
34+
35+
@patch(
36+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
37+
return_value=True,
38+
)
39+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
40+
def test_add_model_uri_to_kwargs_default_uri(
41+
self, mock_retrieve, mock_supports_training, mock_kwargs
42+
):
43+
"""Test adding default model URI when none is provided."""
44+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
45+
mock_retrieve.return_value = default_uri
46+
47+
result = _add_model_uri_to_kwargs(mock_kwargs)
48+
49+
mock_supports_training.assert_called_once()
50+
mock_retrieve.assert_called_once_with(
51+
model_scope=JumpStartScriptScope.TRAINING,
52+
instance_type=mock_kwargs.instance_type,
53+
**get_model_info_default_kwargs(mock_kwargs),
54+
)
55+
assert result.model_uri == default_uri
56+
57+
@patch(
58+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
59+
return_value=True,
60+
)
61+
@patch(
62+
"sagemaker.jumpstart.factory.estimator._model_supports_incremental_training",
63+
return_value=True,
64+
)
65+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
66+
def test_add_model_uri_to_kwargs_custom_uri_with_incremental(
67+
self, mock_retrieve, mock_supports_incremental, mock_supports_training, mock_kwargs
68+
):
69+
"""Test using custom model URI with incremental training support."""
70+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
71+
custom_uri = "s3://custom-bucket/my-model"
72+
mock_retrieve.return_value = default_uri
73+
mock_kwargs.model_uri = custom_uri
74+
75+
result = _add_model_uri_to_kwargs(mock_kwargs)
76+
77+
mock_supports_training.assert_called_once()
78+
mock_supports_incremental.assert_called_once()
79+
assert result.model_uri == custom_uri
80+
81+
@patch(
82+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
83+
return_value=True,
84+
)
85+
@patch(
86+
"sagemaker.jumpstart.factory.estimator._model_supports_incremental_training",
87+
return_value=False,
88+
)
89+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
90+
@patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
91+
def test_add_model_uri_to_kwargs_custom_uri_without_incremental(
92+
self,
93+
mock_warning,
94+
mock_retrieve,
95+
mock_supports_incremental,
96+
mock_supports_training,
97+
mock_kwargs,
98+
):
99+
"""Test using custom model URI without incremental training support logs warning."""
100+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
101+
custom_uri = "s3://custom-bucket/my-model"
102+
mock_retrieve.return_value = default_uri
103+
mock_kwargs.model_uri = custom_uri
104+
105+
result = _add_model_uri_to_kwargs(mock_kwargs)
106+
107+
mock_supports_training.assert_called_once()
108+
mock_supports_incremental.assert_called_once()
109+
mock_warning.assert_called_once()
110+
assert "does not support incremental training" in mock_warning.call_args[0][0]
111+
assert result.model_uri == custom_uri
112+
113+
@patch(
114+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
115+
return_value=False,
116+
)
117+
def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_training, mock_kwargs):
118+
"""Test when model doesn't support training model URI."""
119+
result = _add_model_uri_to_kwargs(mock_kwargs)
120+
121+
mock_supports_training.assert_called_once()
122+
assert result.model_uri is None
123+
124+
@patch(
125+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
126+
return_value=False,
127+
)
128+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
129+
def test_add_model_uri_to_kwargs_private_hub(
130+
self, mock_retrieve, mock_supports_training, mock_kwargs
131+
):
132+
"""Test when model is from a private hub."""
133+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
134+
mock_retrieve.return_value = default_uri
135+
mock_kwargs.hub_arn = "arn:aws:sagemaker:us-west-2:123456789012:hub/private-hub"
136+
137+
result = _add_model_uri_to_kwargs(mock_kwargs)
138+
139+
# Should not check if model supports training model URI for private hub
140+
mock_supports_training.assert_not_called()
141+
mock_retrieve.assert_called_once()
142+
assert result.model_uri == default_uri
143+
144+
@patch(
145+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
146+
return_value=False,
147+
)
148+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
149+
def test_add_model_uri_to_kwargs_public_hub(
150+
self, mock_retrieve, mock_supports_training, mock_kwargs
151+
):
152+
"""Test when model is from the public hub."""
153+
mock_kwargs.hub_arn = (
154+
f"arn:aws:sagemaker:us-west-2:123456789012:hub/{JUMPSTART_MODEL_HUB_NAME}"
155+
)
156+
157+
result = _add_model_uri_to_kwargs(mock_kwargs)
158+
159+
# Should check if model supports training model URI for public hub
160+
mock_supports_training.assert_called_once()
161+
mock_retrieve.assert_not_called()
162+
assert result.model_uri is None

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,3 +1288,78 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_func
12881288
assert_key = JumpStartVersionedModelId("test-model", "abc")
12891289

12901290
assert result == assert_key
1291+
1292+
1293+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
1294+
@patch(
1295+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
1296+
)
1297+
def test_get_json_file_from_s3():
1298+
"""Test _get_json_file retrieves from S3 in normal mode."""
1299+
cache = JumpStartModelsCache()
1300+
test_key = "test/file/path.json"
1301+
test_json_data = {"key": "value"}
1302+
test_etag = "test-etag-123"
1303+
1304+
with patch.object(
1305+
JumpStartModelsCache,
1306+
"_get_json_file_and_etag_from_s3",
1307+
return_value=(test_json_data, test_etag),
1308+
) as mock_s3_get:
1309+
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
1310+
1311+
mock_s3_get.assert_called_once_with(test_key)
1312+
assert result == test_json_data
1313+
assert etag == test_etag
1314+
1315+
1316+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
1317+
@patch(
1318+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
1319+
)
1320+
def test_get_json_file_from_local_supported_type():
1321+
"""Test _get_json_file retrieves from local override for supported file types."""
1322+
cache = JumpStartModelsCache()
1323+
test_key = "test/file/path.json"
1324+
test_json_data = {"key": "value"}
1325+
1326+
with (
1327+
patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True),
1328+
patch.object(
1329+
JumpStartModelsCache, "_get_json_file_from_local_override", return_value=test_json_data
1330+
) as mock_local_get,
1331+
):
1332+
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
1333+
1334+
mock_local_get.assert_called_once_with(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
1335+
assert result == test_json_data
1336+
assert etag is None
1337+
1338+
1339+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
1340+
@patch(
1341+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
1342+
)
1343+
def test_get_json_file_local_mode_unsupported_type():
1344+
"""Test _get_json_file falls back to S3 for unsupported file types in local mode."""
1345+
cache = JumpStartModelsCache()
1346+
test_key = "test/file/path.json"
1347+
test_json_data = {"key": "value"}
1348+
test_etag = "test-etag-123"
1349+
1350+
with (
1351+
patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True),
1352+
patch.object(
1353+
JumpStartModelsCache,
1354+
"_get_json_file_and_etag_from_s3",
1355+
return_value=(test_json_data, test_etag),
1356+
) as mock_s3_get,
1357+
patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") as mock_warning,
1358+
):
1359+
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.PROPRIETARY_MANIFEST)
1360+
1361+
mock_s3_get.assert_called_once_with(test_key)
1362+
mock_warning.assert_called_once()
1363+
assert "not supported for local override" in mock_warning.call_args[0][0]
1364+
assert result == test_json_data
1365+
assert etag == test_etag

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
INIT_KWARGS,
4040
)
4141

42+
from unittest.mock import Mock
43+
4244
INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants(
4345
{
4446
"regional_aliases": {
@@ -329,14 +331,67 @@ def test_jumpstart_model_header():
329331
assert header1 == header3
330332

331333

332-
def test_use_training_model_artifact():
333-
specs1 = JumpStartModelSpecs(BASE_SPEC)
334-
assert specs1.use_training_model_artifact()
335-
specs1.gated_bucket = True
336-
assert not specs1.use_training_model_artifact()
337-
specs1.gated_bucket = False
338-
specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"}
339-
assert not specs1.use_training_model_artifact()
334+
class TestUseTrainingModelArtifact:
335+
@pytest.fixture
336+
def mock_specs(self):
337+
specs = Mock(spec=JumpStartModelSpecs)
338+
specs.training_instance_type_variants = Mock()
339+
specs.supported_training_instance_types = ["ml.p3.2xlarge", "ml.g4dn.xlarge"]
340+
specs.training_model_package_artifact_uris = {}
341+
specs.training_artifact_key = None
342+
return specs
343+
344+
def test_use_training_model_artifact_with_env_var(self, mock_specs):
345+
"""Test when instance type variants have env var values."""
346+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.side_effect = [
347+
"some-value",
348+
None,
349+
]
350+
351+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
352+
353+
assert result is False
354+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.assert_any_call(
355+
"ml.p3.2xlarge"
356+
)
357+
358+
def test_use_training_model_artifact_with_package_uris(self, mock_specs):
359+
"""Test when model has training package artifact URIs."""
360+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
361+
None
362+
)
363+
mock_specs.training_model_package_artifact_uris = {
364+
"ml.p3.2xlarge": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/"
365+
"llama2-13b-e155a2e0347b323fb882f1875851c5d3"
366+
}
367+
368+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
369+
370+
assert result is False
371+
372+
def test_use_training_model_artifact_with_artifact_key(self, mock_specs):
373+
"""Test when model has training artifact key."""
374+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
375+
None
376+
)
377+
mock_specs.training_model_package_artifact_uris = {}
378+
mock_specs.training_artifact_key = "some-key"
379+
380+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
381+
382+
assert result is True
383+
384+
def test_use_training_model_artifact_without_artifact_key(self, mock_specs):
385+
"""Test when model has no training artifact key."""
386+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
387+
None
388+
)
389+
mock_specs.training_model_package_artifact_uris = {}
390+
mock_specs.training_artifact_key = None
391+
392+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
393+
394+
assert result is False
340395

341396

342397
def test_jumpstart_model_specs():

0 commit comments

Comments
 (0)