Skip to content

Commit 359ea1c

Browse files
committed
added more tests to cover some lines
1 parent d71b727 commit 359ea1c

File tree

2 files changed

+74
-2
lines changed

2 files changed

+74
-2
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ def set_manifest_file_s3_key(
164164
if not property_name:
165165
raise ValueError(
166166
f"Bad value when setting manifest '{file_type}': must be in"
167-
f"{JumpStartS3FileType.OPEN_SOURCE_MANIFEST}"
168-
f"{JumpStartS3FileType.PROPRIETARY_MANIFEST}"
167+
f" {JumpStartS3FileType.OPEN_SOURCE_MANIFEST}"
168+
f" {JumpStartS3FileType.PROPRIETARY_MANIFEST}"
169169
)
170170
if key != property_name:
171171
setattr(self, property_name, key)

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
JumpStartModelHeader,
3737
JumpStartModelSpecs,
3838
JumpStartVersionedModelId,
39+
JumpStartS3FileType,
3940
)
4041
from sagemaker.jumpstart.enums import JumpStartModelType
4142
from tests.unit.sagemaker.jumpstart.utils import (
@@ -358,6 +359,12 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
358359
cache.set_manifest_file_s3_key("some_key1")
359360
cache.clear.assert_called_once()
360361

362+
cache.clear.reset_mock()
363+
cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_SOURCE_MANIFEST)
364+
cache.clear.assert_called_once()
365+
with pytest.raises(ValueError):
366+
cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type")
367+
361368

362369
def test_jumpstart_cache_handles_boto3_client_errors():
363370
# Testing get_object
@@ -514,6 +521,71 @@ def test_jumpstart_cache_accepts_input_parameters():
514521
)
515522

516523

524+
def test_jumpstart_proprietary_cache_accepts_input_parameters():
525+
526+
region = "us-east-1"
527+
max_s3_cache_items = 1
528+
s3_cache_expiration_horizon = datetime.timedelta(weeks=2)
529+
max_semantic_version_cache_items = 3
530+
semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4)
531+
bucket = "my-amazing-bucket"
532+
manifest_file_key = "some_s3_key"
533+
proprietary_manifest_file_key = "some_proprietary_s3_key"
534+
535+
cache = JumpStartModelsCache(
536+
region=region,
537+
max_s3_cache_items=max_s3_cache_items,
538+
s3_cache_expiration_horizon=s3_cache_expiration_horizon,
539+
max_semantic_version_cache_items=max_semantic_version_cache_items,
540+
semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon,
541+
s3_bucket_name=bucket,
542+
manifest_file_s3_key=manifest_file_key,
543+
proprietary_manifest_s3_key=proprietary_manifest_file_key,
544+
)
545+
546+
assert (
547+
cache.get_manifest_file_s3_key(file_type=JumpStartS3FileType.PROPRIETARY_MANIFEST)
548+
== proprietary_manifest_file_key
549+
)
550+
assert cache.get_region() == region
551+
assert cache.get_bucket() == bucket
552+
assert cache._s3_cache._max_cache_items == max_s3_cache_items
553+
assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon
554+
assert (
555+
cache._proprietary_model_id_manifest_key_cache._max_cache_items
556+
== max_semantic_version_cache_items
557+
)
558+
assert (
559+
cache._proprietary_model_id_manifest_key_cache._expiration_horizon
560+
== semantic_version_cache_expiration_horizon
561+
)
562+
563+
564+
def test_jumpstart_cache_raise_unknown_file_type_exception():
565+
566+
region = "us-east-1"
567+
max_s3_cache_items = 1
568+
s3_cache_expiration_horizon = datetime.timedelta(weeks=2)
569+
max_semantic_version_cache_items = 3
570+
semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4)
571+
bucket = "my-amazing-bucket"
572+
manifest_file_key = "some_s3_key"
573+
proprietary_manifest_file_key = "some_proprietary_s3_key"
574+
575+
cache = JumpStartModelsCache(
576+
region=region,
577+
max_s3_cache_items=max_s3_cache_items,
578+
s3_cache_expiration_horizon=s3_cache_expiration_horizon,
579+
max_semantic_version_cache_items=max_semantic_version_cache_items,
580+
semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon,
581+
s3_bucket_name=bucket,
582+
manifest_file_s3_key=manifest_file_key,
583+
proprietary_manifest_s3_key=proprietary_manifest_file_key,
584+
)
585+
with pytest.raises(ValueError):
586+
cache.get_manifest_file_s3_key(file_type="unknown_type")
587+
588+
517589
@patch("boto3.client")
518590
def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
519591

0 commit comments

Comments
 (0)