|
15 | 15 | import datetime
|
16 | 16 | from difflib import get_close_matches
|
17 | 17 | import os
|
18 |
| -from typing import List, Optional, Tuple, Union |
| 18 | +from typing import Any, Dict, List, Optional, Tuple, Union |
19 | 19 | import json
|
20 | 20 | import boto3
|
21 | 21 | import botocore
|
|
42 | 42 | JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON,
|
43 | 43 | )
|
44 | 44 | from sagemaker.jumpstart.types import (
|
45 |
| - JumpStartCachedS3ContentKey, |
46 |
| - JumpStartCachedS3ContentValue, |
| 45 | + JumpStartCachedContentKey, |
| 46 | + JumpStartCachedContentValue, |
47 | 47 | JumpStartModelHeader,
|
48 | 48 | JumpStartModelSpecs,
|
49 | 49 | JumpStartS3FileType,
|
50 | 50 | JumpStartVersionedModelId,
|
| 51 | + HubType, |
| 52 | + HubContentType |
| 53 | +) |
| 54 | +from sagemaker.jumpstart.hub import utils as hub_utils |
| 55 | +from sagemaker.jumpstart.hub.interfaces import ( |
| 56 | + DescribeHubResponse, |
| 57 | + DescribeHubContentResponse, |
51 | 58 | )
|
52 | 59 | from sagemaker.jumpstart.enums import JumpStartModelType
|
53 | 60 | from sagemaker.jumpstart import utils
|
@@ -104,7 +111,7 @@ def __init__(
|
104 | 111 | s3_bucket_name=s3_bucket_name, s3_client=s3_client
|
105 | 112 | )
|
106 | 113 |
|
107 |
| - self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( |
| 114 | + self._content_cache = LRUCache[JumpStartCachedContentKey, JumpStartCachedContentValue]( |
108 | 115 | max_cache_items=max_s3_cache_items,
|
109 | 116 | expiration_horizon=s3_cache_expiration_horizon,
|
110 | 117 | retrieval_function=self._retrieval_function,
|
@@ -230,8 +237,8 @@ def _model_id_retrieval_function(
|
230 | 237 |
|
231 | 238 | model_id, version = key.model_id, key.version
|
232 | 239 | sm_version = utils.get_sagemaker_version()
|
233 |
| - manifest = self._s3_cache.get( |
234 |
| - JumpStartCachedS3ContentKey( |
| 240 | + manifest = self._content_cache.get( |
| 241 | + JumpStartCachedContentKey( |
235 | 242 | MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
|
236 | 243 | )
|
237 | 244 | )[0].formatted_content
|
@@ -392,53 +399,87 @@ def _get_json_file_from_local_override(
|
392 | 399 |
|
def _retrieval_function(
    self,
    key: JumpStartCachedContentKey,
    value: Optional[JumpStartCachedContentValue],
) -> JumpStartCachedContentValue:
    """Fetch the content identified by ``key`` for storage in the content cache.

    ``key.data_type`` selects how ``key.id_info`` is interpreted:

    * Manifest file types (open-weight / proprietary): ``id_info`` is passed to
      ``_get_json_md5_hash`` / ``_get_json_file``. When a cached ``value`` exists
      and the cache is not in local metadata mode, the current md5 hash is
      compared against ``value.md5_hash`` first and the cached value is returned
      unchanged on a match, so the full manifest is only fetched when it changed.
    * Specs file types (open-weight / proprietary): the JSON file is loaded and
      wrapped in ``JumpStartModelSpecs``.
    * ``HubContentType.NOTEBOOK``: ``id_info`` is parsed as a hub resource ARN and
      the hub content is described through the SageMaker session.
    * ``HubContentType.MODEL`` / ``HubContentType.MODEL_REFERENCE``: same as for
      notebooks, and the response is additionally converted to model specs.

    Args:
        key (JumpStartCachedContentKey): key for which to fetch content.
        value (Optional[JumpStartCachedContentValue]): Currently cached value for
            ``key``, if any. Only consulted for manifest files, so that the
            manifest is not re-fetched when its hash is unchanged.

    Returns:
        JumpStartCachedContentValue: the fetched (or still-valid cached) content.

    Raises:
        ValueError: If ``key.data_type`` is not one of the supported types.
    """

    data_type, id_info = key.data_type, key.id_info

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
        JumpStartS3FileType.PROPRIETARY_MANIFEST,
    }:
        if value is not None and not self._is_local_metadata_mode():
            # Cheap freshness check: reuse the cached manifest if its hash matches.
            etag = self._get_json_md5_hash(id_info)
            if etag == value.md5_hash:
                return value
        formatted_body, etag = self._get_json_file(id_info, data_type)
        return JumpStartCachedContentValue(
            formatted_content=utils.get_formatted_manifest(formatted_body),
            md5_hash=etag,
        )

    if data_type in {
        JumpStartS3FileType.OPEN_WEIGHT_SPECS,
        JumpStartS3FileType.PROPRIETARY_SPECS,
    }:
        formatted_body, _ = self._get_json_file(id_info, data_type)
        model_specs = JumpStartModelSpecs(formatted_body)
        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    if data_type == HubContentType.NOTEBOOK:
        # Parenthesised call instead of a backslash continuation (PEP 8).
        hub_name, _, notebook_name, notebook_version = hub_utils.get_info_from_hub_resource_arn(
            id_info
        )
        response: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=hub_name,
            hub_content_name=notebook_name,
            hub_content_version=notebook_version,
            hub_content_type=data_type,
        )
        hub_notebook_description = DescribeHubContentResponse(response)
        return JumpStartCachedContentValue(formatted_content=hub_notebook_description)

    # Set literal for consistency with the membership checks above; ``data_type``
    # has already been hashed by those checks, so this adds no new failure mode.
    if data_type in {HubContentType.MODEL, HubContentType.MODEL_REFERENCE}:
        hub_name, _, model_name, model_version = hub_utils.get_info_from_hub_resource_arn(
            id_info
        )
        hub_model_description: Dict[str, Any] = self._sagemaker_session.describe_hub_content(
            hub_name=hub_name,
            hub_content_name=model_name,
            hub_content_version=model_version,
            hub_content_type=data_type,
        )

        model_specs = make_model_specs_from_describe_hub_content_response(
            DescribeHubContentResponse(hub_model_description),
        )

        utils.emit_logs_based_on_model_specs(model_specs, self.get_region(), self._s3_client)
        return JumpStartCachedContentValue(formatted_content=model_specs)

    raise ValueError(self._file_type_error_msg(data_type))
434 | 475 |
|
435 | 476 | def get_manifest(
|
436 | 477 | self,
|
437 | 478 | model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
|
438 | 479 | ) -> List[JumpStartModelHeader]:
|
439 | 480 | """Return entire JumpStart models manifest."""
|
440 |
| - manifest_dict = self._s3_cache.get( |
441 |
| - JumpStartCachedS3ContentKey( |
| 481 | + manifest_dict = self._content_cache.get( |
| 482 | + JumpStartCachedContentKey( |
442 | 483 | MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
|
443 | 484 | )
|
444 | 485 | )[0].formatted_content
|
@@ -525,8 +566,8 @@ def _get_header_impl(
|
525 | 566 | JumpStartVersionedModelId(model_id, semantic_version_str)
|
526 | 567 | )[0]
|
527 | 568 |
|
528 |
| - manifest = self._s3_cache.get( |
529 |
| - JumpStartCachedS3ContentKey( |
| 569 | + manifest = self._content_cache.get( |
| 570 | + JumpStartCachedContentKey( |
530 | 571 | MODEL_TYPE_TO_MANIFEST_MAP[model_type], self._manifest_file_s3_map[model_type]
|
531 | 572 | )
|
532 | 573 | )[0].formatted_content
|
@@ -556,18 +597,44 @@ def get_specs(
|
556 | 597 | """
|
557 | 598 | header = self.get_header(model_id, version_str, model_type)
|
558 | 599 | spec_key = header.spec_key
|
559 |
| - specs, cache_hit = self._s3_cache.get( |
560 |
| - JumpStartCachedS3ContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) |
| 600 | + specs, cache_hit = self._content_cache.get( |
| 601 | + JumpStartCachedContentKey(MODEL_TYPE_TO_SPECS_MAP[model_type], spec_key) |
561 | 602 | )
|
562 | 603 |
|
563 | 604 | if not cache_hit and "*" in version_str:
|
564 | 605 | JUMPSTART_LOGGER.warning(
|
565 | 606 | get_wildcard_model_version_msg(header.model_id, version_str, header.version)
|
566 | 607 | )
|
567 | 608 | return specs.formatted_content
|
| 609 | + |
def get_hub_model(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Look up the JumpStart-compatible specs for a Hub model.

    The lookup goes through the content cache, keyed by
    ``HubContentType.MODEL`` together with the given ARN.

    Args:
        hub_model_arn (str): ARN of the Hub model whose specs are requested.

    Returns:
        JumpStartModelSpecs: the ``formatted_content`` of the cache entry.
    """
    cache_key = JumpStartCachedContentKey(HubContentType.MODEL, hub_model_arn)
    # The cache returns a pair; only the first element (the value) is needed here.
    cached_value, _ = self._content_cache.get(cache_key)
    return cached_value.formatted_content
| 622 | + |
def get_hub_model_reference(self, hub_model_arn: str) -> JumpStartModelSpecs:
    """Look up the JumpStart-compatible specs for a Hub model reference.

    The lookup goes through the content cache, keyed by
    ``HubContentType.MODEL_REFERENCE`` together with the given ARN.

    Args:
        hub_model_arn (str): ARN of the Hub model reference whose specs are
            requested.

    Returns:
        JumpStartModelSpecs: the ``formatted_content`` of the cache entry.
    """
    cache_key = JumpStartCachedContentKey(HubContentType.MODEL_REFERENCE, hub_model_arn)
    # The cache returns a pair; only the first element (the value) is needed here.
    cached_value, _ = self._content_cache.get(cache_key)
    return cached_value.formatted_content
568 | 635 |
|
def clear(self) -> None:
    """Empty the content cache and both model ID/version manifest-key caches."""
    # Content cache first, then the per-model-type manifest-key caches.
    self._content_cache.clear()
    self._open_weight_model_id_manifest_key_cache.clear()
    self._proprietary_model_id_manifest_key_cache.clear()
|
0 commit comments