|
36 | 36 | JumpStartModelHeader,
|
37 | 37 | JumpStartModelSpecs,
|
38 | 38 | JumpStartVersionedModelId,
|
| 39 | + JumpStartS3FileType, |
39 | 40 | )
|
40 | 41 | from sagemaker.jumpstart.enums import JumpStartModelType
|
41 | 42 | from tests.unit.sagemaker.jumpstart.utils import (
|
@@ -358,6 +359,12 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
|
358 | 359 | cache.set_manifest_file_s3_key("some_key1")
|
359 | 360 | cache.clear.assert_called_once()
|
360 | 361 |
|
| 362 | + cache.clear.reset_mock() |
| 363 | + cache.set_manifest_file_s3_key("some_key1", file_type=JumpStartS3FileType.OPEN_SOURCE_MANIFEST) |
| 364 | + cache.clear.assert_called_once() |
| 365 | + with pytest.raises(ValueError): |
| 366 | + cache.set_manifest_file_s3_key("some_key1", file_type="unknown_type") |
| 367 | + |
361 | 368 |
|
362 | 369 | def test_jumpstart_cache_handles_boto3_client_errors():
|
363 | 370 | # Testing get_object
|
@@ -514,6 +521,71 @@ def test_jumpstart_cache_accepts_input_parameters():
|
514 | 521 | )
|
515 | 522 |
|
516 | 523 |
|
| 524 | +def test_jumpstart_proprietary_cache_accepts_input_parameters(): |
| 525 | + |
| 526 | + region = "us-east-1" |
| 527 | + max_s3_cache_items = 1 |
| 528 | + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) |
| 529 | + max_semantic_version_cache_items = 3 |
| 530 | + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) |
| 531 | + bucket = "my-amazing-bucket" |
| 532 | + manifest_file_key = "some_s3_key" |
| 533 | + proprietary_manifest_file_key = "some_proprietary_s3_key" |
| 534 | + |
| 535 | + cache = JumpStartModelsCache( |
| 536 | + region=region, |
| 537 | + max_s3_cache_items=max_s3_cache_items, |
| 538 | + s3_cache_expiration_horizon=s3_cache_expiration_horizon, |
| 539 | + max_semantic_version_cache_items=max_semantic_version_cache_items, |
| 540 | + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, |
| 541 | + s3_bucket_name=bucket, |
| 542 | + manifest_file_s3_key=manifest_file_key, |
| 543 | + proprietary_manifest_s3_key=proprietary_manifest_file_key, |
| 544 | + ) |
| 545 | + |
| 546 | + assert ( |
| 547 | + cache.get_manifest_file_s3_key(file_type=JumpStartS3FileType.PROPRIETARY_MANIFEST) |
| 548 | + == proprietary_manifest_file_key |
| 549 | + ) |
| 550 | + assert cache.get_region() == region |
| 551 | + assert cache.get_bucket() == bucket |
| 552 | + assert cache._s3_cache._max_cache_items == max_s3_cache_items |
| 553 | + assert cache._s3_cache._expiration_horizon == s3_cache_expiration_horizon |
| 554 | + assert ( |
| 555 | + cache._proprietary_model_id_manifest_key_cache._max_cache_items |
| 556 | + == max_semantic_version_cache_items |
| 557 | + ) |
| 558 | + assert ( |
| 559 | + cache._proprietary_model_id_manifest_key_cache._expiration_horizon |
| 560 | + == semantic_version_cache_expiration_horizon |
| 561 | + ) |
| 562 | + |
| 563 | + |
| 564 | +def test_jumpstart_cache_raise_unknown_file_type_exception(): |
| 565 | + |
| 566 | + region = "us-east-1" |
| 567 | + max_s3_cache_items = 1 |
| 568 | + s3_cache_expiration_horizon = datetime.timedelta(weeks=2) |
| 569 | + max_semantic_version_cache_items = 3 |
| 570 | + semantic_version_cache_expiration_horizon = datetime.timedelta(microseconds=4) |
| 571 | + bucket = "my-amazing-bucket" |
| 572 | + manifest_file_key = "some_s3_key" |
| 573 | + proprietary_manifest_file_key = "some_proprietary_s3_key" |
| 574 | + |
| 575 | + cache = JumpStartModelsCache( |
| 576 | + region=region, |
| 577 | + max_s3_cache_items=max_s3_cache_items, |
| 578 | + s3_cache_expiration_horizon=s3_cache_expiration_horizon, |
| 579 | + max_semantic_version_cache_items=max_semantic_version_cache_items, |
| 580 | + semantic_version_cache_expiration_horizon=semantic_version_cache_expiration_horizon, |
| 581 | + s3_bucket_name=bucket, |
| 582 | + manifest_file_s3_key=manifest_file_key, |
| 583 | + proprietary_manifest_s3_key=proprietary_manifest_file_key, |
| 584 | + ) |
| 585 | + with pytest.raises(ValueError): |
| 586 | + cache.get_manifest_file_s3_key(file_type="unknown_type") |
| 587 | + |
| 588 | + |
517 | 589 | @patch("boto3.client")
|
518 | 590 | def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
|
519 | 591 |
|
|
0 commit comments