Skip to content

Commit 74bbb09

Browse files
authored
feat: Adding scan and tagging utility (aws#4499)
* fix: Adding new utils * feat: Adding Curated Hub scanning feature * fix: Refactoring * fix: Removing test values * fix: Refactoring * fix: removing lru cache temporarily * fix: Adding unit tests * fix: renaming * fix: Initial refactorings * fix: Adding more alterations * fix: Addressing unit tests * fix: Adding more unittests * fix: Add tests * fix: Adding list to scan input * fix: typo * fix: Addressing naming comments * fix: changing from string * fix: formatter * fix: linter * fix: linter * fix: linters * fix: linter * fix: linting * fix: more linting :/ * fix: linting * fix: more linting :/ * fix: linting * fix: tests * fix: Adding __init__.py * fix: linting * fix: linting
1 parent a632795 commit 74bbb09

21 files changed

+744
-246
lines changed

src/sagemaker/jumpstart/artifacts/environment_variables.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def _retrieve_default_environment_variables(
133133
)
134134

135135
if gated_model_env_var is None and model_specs.is_gated_model():
136-
137136
possible_env_vars: Set[str] = {
138137
retrieve_gated_env_var_for_instance_type(instance_type)
139138
for instance_type in model_specs.supported_training_instance_types

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def _retrieve_model_package_arn(
8484
)
8585

8686
if scope == JumpStartScriptScope.INFERENCE:
87-
8887
instance_specific_arn: Optional[str] = (
8988
model_specs.hosting_instance_type_variants.get_model_package_arn(
9089
region=region, instance_type=instance_type
@@ -155,7 +154,6 @@ def _retrieve_model_package_model_artifact_s3_uri(
155154
"""
156155

157156
if scope == JumpStartScriptScope.TRAINING:
158-
159157
if region is None:
160158
region = JUMPSTART_DEFAULT_REGION_NAME
161159

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def _retrieve_model_uri(
149149
model_artifact_key: str
150150

151151
if model_scope == JumpStartScriptScope.INFERENCE:
152-
153152
is_prepacked = not model_specs.use_inference_script_uri()
154153

155154
model_artifact_key = (
@@ -159,7 +158,6 @@ def _retrieve_model_uri(
159158
)
160159

161160
elif model_scope == JumpStartScriptScope.TRAINING:
162-
163161
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
164162

165163
default_jumpstart_bucket: str = (

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def _retrieve_default_resources(
134134
}
135135

136136
if is_dynamic_container_deployment_supported:
137-
138137
all_resource_requirement_kwargs = {}
139138

140139
for (

src/sagemaker/jumpstart/cache.py

Lines changed: 37 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,7 @@ def set_manifest_file_s3_key(
177177
}
178178
property_name = file_mapping.get(file_type)
179179
if not property_name:
180-
raise ValueError(
181-
self._file_type_error_msg(file_type, manifest_only=True)
182-
)
180+
raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
183181
if key != property_name:
184182
setattr(self, property_name, key)
185183
self.clear()
@@ -192,9 +190,7 @@ def get_manifest_file_s3_key(
192190
return self._manifest_file_s3_key
193191
if file_type == JumpStartS3FileType.PROPRIETARY_MANIFEST:
194192
return self._proprietary_manifest_s3_key
195-
raise ValueError(
196-
self._file_type_error_msg(file_type, manifest_only=True)
197-
)
193+
raise ValueError(self._file_type_error_msg(file_type, manifest_only=True))
198194

199195
def set_s3_bucket_name(self, s3_bucket_name: str) -> None:
200196
"""Set s3 bucket used for cache."""
@@ -247,7 +243,8 @@ def _model_id_retrieval_function(
247243
sm_version = utils.get_sagemaker_version()
248244
manifest = self._content_cache.get(
249245
JumpStartCachedContentKey(
250-
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
246+
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
247+
)
251248
)[0].formatted_content
252249

253250
versions_compatible_with_sagemaker = [
@@ -264,7 +261,8 @@ def _model_id_retrieval_function(
264261
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
265262

266263
versions_incompatible_with_sagemaker = [
267-
Version(header.version) for header in manifest.values() # type: ignore
264+
Version(header.version)
265+
for header in manifest.values() # type: ignore
268266
if header.model_id == model_id
269267
]
270268
sm_incompatible_model_version = self._select_version(
@@ -294,9 +292,7 @@ def _model_id_retrieval_function(
294292
raise KeyError(error_msg)
295293

296294
error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. "
297-
error_msg += (
298-
f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
299-
)
295+
error_msg += f"Visit {MODEL_ID_LIST_WEB_URL} for updated list of models. "
300296

301297
other_model_id_version = None
302298
if model_type == JumpStartModelType.OPEN_WEIGHTS:
@@ -305,19 +301,17 @@ def _model_id_retrieval_function(
305301
) # all versions here are incompatible with sagemaker
306302
elif model_type == JumpStartModelType.PROPRIETARY:
307303
all_possible_model_id_version = [
308-
header.version for header in manifest.values() # type: ignore
304+
header.version
305+
for header in manifest.values() # type: ignore
309306
if header.model_id == model_id
310307
]
311308
other_model_id_version = (
312-
None
313-
if not all_possible_model_id_version
314-
else all_possible_model_id_version[0]
309+
None if not all_possible_model_id_version else all_possible_model_id_version[0]
315310
)
316311

317312
if other_model_id_version is not None:
318313
error_msg += (
319-
f"Consider using model ID '{model_id}' with version "
320-
f"'{other_model_id_version}'."
314+
f"Consider using model ID '{model_id}' with version " f"'{other_model_id_version}'."
321315
)
322316
else:
323317
possible_model_ids = [header.model_id for header in manifest.values()] # type: ignore
@@ -359,15 +353,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],
359353

360354
def _is_local_metadata_mode(self) -> bool:
361355
"""Returns True if the cache should use local metadata mode, based off env variables."""
362-
return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
363-
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
364-
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
365-
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]))
356+
return (
357+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ
358+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE])
359+
and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ
360+
and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])
361+
)
366362

367363
def _get_json_file(
368-
self,
369-
key: str,
370-
filetype: JumpStartS3FileType
364+
self, key: str, filetype: JumpStartS3FileType
371365
) -> Tuple[Union[dict, list], Optional[str]]:
372366
"""Returns json file either from s3 or local file system.
373367
@@ -391,21 +385,19 @@ def _get_json_md5_hash(self, key: str):
391385
return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"]
392386

393387
def _get_json_file_from_local_override(
394-
self,
395-
key: str,
396-
filetype: JumpStartS3FileType
388+
self, key: str, filetype: JumpStartS3FileType
397389
) -> Union[dict, list]:
398390
"""Reads json file from local filesystem and returns data."""
399391
if filetype == JumpStartS3FileType.OPEN_WEIGHT_MANIFEST:
400-
metadata_local_root = (
401-
os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]
402-
)
392+
metadata_local_root = os.environ[
393+
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
394+
]
403395
elif filetype == JumpStartS3FileType.OPEN_WEIGHT_SPECS:
404396
metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE]
405397
else:
406398
raise ValueError(f"Unsupported file type for local override: {filetype}")
407399
file_path = os.path.join(metadata_local_root, key)
408-
with open(file_path, 'r') as f:
400+
with open(file_path, "r") as f:
409401
data = json.load(f)
410402
return data
411403

@@ -450,9 +442,7 @@ def _retrieval_function(
450442
formatted_body, _ = self._get_json_file(id_info, data_type)
451443
model_specs = JumpStartModelSpecs(formatted_body)
452444
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
453-
return JumpStartCachedContentValue(
454-
formatted_content=model_specs
455-
)
445+
return JumpStartCachedContentValue(formatted_content=model_specs)
456446

457447
if data_type == HubContentType.MODEL:
458448
hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
@@ -462,21 +452,15 @@ def _retrieval_function(
462452
hub_name=hub_name,
463453
hub_content_name=model_name,
464454
hub_content_version=model_version,
465-
hub_content_type=data_type
455+
hub_content_type=data_type,
466456
)
467457

468458
model_specs = JumpStartModelSpecs(
469459
DescribeHubContentsResponse(hub_model_description), is_hub_content=True
470460
)
471461

472-
utils.emit_logs_based_on_model_specs(
473-
model_specs,
474-
self.get_region(),
475-
self._s3_client
476-
)
477-
return JumpStartCachedContentValue(
478-
formatted_content=model_specs
479-
)
462+
utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
463+
return JumpStartCachedContentValue(formatted_content=model_specs)
480464

481465
if data_type == HubType.HUB:
482466
hub_name, _, _, _ = hub_utils.get_info_from_hub_resource_arn(id_info)
@@ -486,9 +470,7 @@ def _retrieval_function(
486470
formatted_content=DescribeHubResponse(hub_description)
487471
)
488472

489-
raise ValueError(
490-
self._file_type_error_msg(data_type)
491-
)
473+
raise ValueError(self._file_type_error_msg(data_type))
492474

493475
def get_manifest(
494476
self,
@@ -497,7 +479,8 @@ def get_manifest(
497479
"""Return entire JumpStart models manifest."""
498480
manifest_dict = self._content_cache.get(
499481
JumpStartCachedContentKey(
500-
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
482+
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
483+
)
501484
)[0].formatted_content
502485
manifest = list(manifest_dict.values()) # type: ignore
503486
return manifest
@@ -554,16 +537,14 @@ def _select_version(
554537
except InvalidSpecifier:
555538
raise KeyError(f"Bad semantic version: {version_str}")
556539
available_versions_filtered = list(spec.filter(available_versions))
557-
return (
558-
str(max(available_versions_filtered)) if available_versions_filtered != [] else None
559-
)
540+
return str(max(available_versions_filtered)) if available_versions_filtered != [] else None
560541

561542
def _get_header_impl(
562543
self,
563544
model_id: str,
564545
semantic_version_str: str,
565546
attempt: int = 0,
566-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS
547+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
567548
) -> JumpStartModelHeader:
568549
"""Lower-level function to return header.
569550
@@ -586,7 +567,8 @@ def _get_header_impl(
586567

587568
manifest = self._content_cache.get(
588569
JumpStartCachedContentKey(
589-
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type])
570+
MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
571+
)
590572
)[0].formatted_content
591573

592574
try:
@@ -602,7 +584,7 @@ def get_specs(
602584
self,
603585
model_id: str,
604586
version_str: str,
605-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS
587+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
606588
) -> JumpStartModelSpecs:
607589
"""Return specs for a given JumpStart model ID and semantic version.
608590
@@ -615,16 +597,12 @@ def get_specs(
615597
header = self.get_header(model_id, version_str, model_type)
616598
spec_key = header.spec_key
617599
specs, cache_hit = self._content_cache.get(
618-
JumpStartCachedContentKey(
619-
MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key
620-
)
600+
JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key)
621601
)
622602

623603
if not cache_hit and "*" in version_str:
624604
JUMPSTART_LOGGER.warning(
625-
get_wildcard_model_version_msg(
626-
header.model_id, version_str, header.version
627-
)
605+
get_wildcard_model_version_msg(header.model_id, version_str, header.version)
628606
)
629607
return specs.formatted_content
630608

src/sagemaker/jumpstart/curated_hub/accessors/file_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
HubContentDependencyType,
2222
S3ObjectLocation,
2323
)
24-
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import PublicModelDataAccessor
24+
from sagemaker.jumpstart.curated_hub.accessors.public_model_data import (
25+
PublicModelDataAccessor,
26+
)
2527
from sagemaker.jumpstart.curated_hub.utils import is_gated_bucket
2628
from sagemaker.jumpstart.types import JumpStartModelSpecs
2729

src/sagemaker/jumpstart/curated_hub/accessors/public_model_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
model_specs: JumpStartModelSpecs,
3737
studio_specs: Dict[str, Dict[str, Any]],
3838
):
39+
"""Creates a PublicModelDataAccessor."""
3940
self._region = region
4041
self._bucket = (
4142
get_jumpstart_gated_content_bucket(region)

0 commit comments

Comments
 (0)