15
15
import datetime
16
16
from difflib import get_close_matches
17
17
import os
18
- from typing import List , Optional , Tuple , Union
18
+ from typing import Any , Dict , List , Optional , Tuple , Union
19
19
import json
20
20
import boto3
21
21
import botocore
29
29
JUMPSTART_LOGGER ,
30
30
MODEL_ID_LIST_WEB_URL ,
31
31
)
32
+ from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
32
33
from sagemaker .jumpstart .exceptions import get_wildcard_model_version_msg
33
34
from sagemaker .jumpstart .parameters import (
34
35
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS ,
37
38
JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON ,
38
39
)
39
40
from sagemaker .jumpstart .types import (
40
- JumpStartCachedS3ContentKey ,
41
- JumpStartCachedS3ContentValue ,
41
+ JumpStartCachedContentKey ,
42
+ JumpStartCachedContentValue ,
42
43
JumpStartModelHeader ,
43
44
JumpStartModelSpecs ,
44
45
JumpStartS3FileType ,
45
46
JumpStartVersionedModelId ,
47
+ HubDataType ,
46
48
)
47
49
from sagemaker .jumpstart import utils
48
50
from sagemaker .utilities .cache import LRUCache
@@ -95,7 +97,7 @@ def __init__(
95
97
"""
96
98
97
99
self ._region = region
98
- self ._s3_cache = LRUCache [JumpStartCachedS3ContentKey , JumpStartCachedS3ContentValue ](
100
+ self ._content_cache = LRUCache [JumpStartCachedContentKey , JumpStartCachedContentValue ](
99
101
max_cache_items = max_s3_cache_items ,
100
102
expiration_horizon = s3_cache_expiration_horizon ,
101
103
retrieval_function = self ._retrieval_function ,
@@ -172,8 +174,8 @@ def _get_manifest_key_from_model_id_semantic_version(
172
174
173
175
model_id , version = key .model_id , key .version
174
176
175
- manifest = self ._s3_cache .get (
176
- JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
177
+ manifest = self ._content_cache .get (
178
+ JumpStartCachedContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
177
179
)[0 ].formatted_content
178
180
179
181
sm_version = utils .get_sagemaker_version ()
@@ -301,50 +303,71 @@ def _get_json_file_from_local_override(
301
303
302
304
def _retrieval_function (
303
305
self ,
304
- key : JumpStartCachedS3ContentKey ,
305
- value : Optional [JumpStartCachedS3ContentValue ],
306
- ) -> JumpStartCachedS3ContentValue :
307
- """Return s3 content given a file type and s3_key in ``JumpStartCachedS3ContentKey ``.
306
+ key : JumpStartCachedContentKey ,
307
+ value : Optional [JumpStartCachedContentValue ],
308
+ ) -> JumpStartCachedContentValue :
309
+ """Return s3 content given a data type and s3_key in ``JumpStartCachedContentKey ``.
308
310
309
311
If a manifest file is being fetched, we only download the object if the md5 hash in
310
312
``head_object`` does not match the current md5 hash for the stored value. This prevents
311
313
unnecessarily downloading the full manifest when it hasn't changed.
312
314
313
315
Args:
314
- key (JumpStartCachedS3ContentKey ): key for which to fetch s3 content.
316
+ key (JumpStartCachedContentKey ): key for which to fetch JumpStart content.
315
317
value (Optional[JumpStartVersionedModelId]): Current value of old cached
316
318
s3 content. This is used for the manifest file, so that it is only
317
319
downloaded when its content changes.
318
320
"""
319
321
320
- file_type , s3_key = key .file_type , key .s3_key
322
+ data_type , id_info = key .data_type , key .id_info
321
323
322
- if file_type == JumpStartS3FileType .MANIFEST :
324
+ if data_type == JumpStartS3FileType .MANIFEST :
323
325
if value is not None and not self ._is_local_metadata_mode ():
324
- etag = self ._get_json_md5_hash (s3_key )
326
+ etag = self ._get_json_md5_hash (id_info )
325
327
if etag == value .md5_hash :
326
328
return value
327
- formatted_body , etag = self ._get_json_file (s3_key , file_type )
328
- return JumpStartCachedS3ContentValue (
329
+ formatted_body , etag = self ._get_json_file (id_info , data_type )
330
+ return JumpStartCachedContentValue (
329
331
formatted_content = utils .get_formatted_manifest (formatted_body ),
330
332
md5_hash = etag ,
331
333
)
332
- if file_type == JumpStartS3FileType .SPECS :
333
- formatted_body , _ = self ._get_json_file (s3_key , file_type )
334
+ if data_type == JumpStartS3FileType .SPECS :
335
+ formatted_body , _ = self ._get_json_file (id_info , data_type )
334
336
model_specs = JumpStartModelSpecs (formatted_body )
335
337
utils .emit_logs_based_on_model_specs (model_specs , self .get_region (), self ._s3_client )
336
- return JumpStartCachedS3ContentValue (
338
+ return JumpStartCachedContentValue (
337
339
formatted_content = model_specs
338
340
)
341
+ if data_type == HubDataType .MODEL :
342
+ hub_name , region , model_name , model_version = utils .extract_info_from_hub_content_arn (
343
+ id_info
344
+ )
345
+ hub = CuratedHub (hub_name = hub_name , region = region )
346
+ hub_content = hub .describe_model (model_name = model_name , model_version = model_version )
347
+ utils .emit_logs_based_on_model_specs (
348
+ hub_content .content_document ,
349
+ self .get_region (),
350
+ self ._s3_client
351
+ )
352
+ model_specs = JumpStartModelSpecs (hub_content .content_document , is_hub_content = True )
353
+ return JumpStartCachedContentValue (
354
+ formatted_content = model_specs
355
+ )
356
+ if data_type == HubDataType .HUB :
357
+ hub_name , region , _ , _ = utils .extract_info_from_hub_content_arn (id_info )
358
+ hub = CuratedHub (hub_name = hub_name , region = region )
359
+ hub_info = hub .describe ()
360
+ return JumpStartCachedContentValue (formatted_content = hub_info )
339
361
raise ValueError (
340
- f"Bad value for key '{ key } ': must be in { [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS ]} "
362
+ f"Bad value for key '{ key } ': must be in" ,
363
+ f"{ [JumpStartS3FileType .MANIFEST , JumpStartS3FileType .SPECS , HubDataType .HUB , HubDataType .MODEL ]} "
341
364
)
342
365
343
366
def get_manifest (self ) -> List [JumpStartModelHeader ]:
344
367
"""Return entire JumpStart models manifest."""
345
368
346
- manifest_dict = self ._s3_cache .get (
347
- JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
369
+ manifest_dict = self ._content_cache .get (
370
+ JumpStartCachedContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
348
371
)[0 ].formatted_content
349
372
manifest = list (manifest_dict .values ()) # type: ignore
350
373
return manifest
@@ -407,8 +430,8 @@ def _get_header_impl(
407
430
JumpStartVersionedModelId (model_id , semantic_version_str )
408
431
)[0 ]
409
432
410
- manifest = self ._s3_cache .get (
411
- JumpStartCachedS3ContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
433
+ manifest = self ._content_cache .get (
434
+ JumpStartCachedContentKey (JumpStartS3FileType .MANIFEST , self ._manifest_file_s3_key )
412
435
)[0 ].formatted_content
413
436
try :
414
437
header = manifest [versioned_model_id ] # type: ignore
@@ -430,8 +453,8 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
430
453
431
454
header = self .get_header (model_id , semantic_version_str )
432
455
spec_key = header .spec_key
433
- specs , cache_hit = self ._s3_cache .get (
434
- JumpStartCachedS3ContentKey (JumpStartS3FileType .SPECS , spec_key )
456
+ specs , cache_hit = self ._content_cache .get (
457
+ JumpStartCachedContentKey (JumpStartS3FileType .SPECS , spec_key )
435
458
)
436
459
if not cache_hit and "*" in semantic_version_str :
437
460
JUMPSTART_LOGGER .warning (
@@ -443,7 +466,29 @@ def get_specs(self, model_id: str, semantic_version_str: str) -> JumpStartModelS
443
466
)
444
467
return specs .formatted_content
445
468
469
+ def get_hub_model (self , hub_model_arn : str ) -> JumpStartModelSpecs :
470
+ """Return JumpStart-compatible specs for a given Hub model
471
+
472
+ Args:
473
+ hub_model_arn (str): Arn for the Hub model to get specs for
474
+ """
475
+
476
+ details , _ = self ._content_cache .get (
477
+ JumpStartCachedContentKey (HubDataType .MODEL , hub_model_arn )
478
+ )
479
+ return details .formatted_content
480
+
481
+ def get_hub (self , hub_arn : str ) -> Dict [str , Any ]:
482
+ """Return descriptive info for a given Hub
483
+
484
+ Args:
485
+ hub_arn (str): Arn for the Hub to get info for
486
+ """
487
+
488
+ details , _ = self ._content_cache .get (JumpStartCachedContentKey (HubDataType .HUB , hub_arn ))
489
+ return details .formatted_content
490
+
446
491
def clear (self ) -> None :
447
492
"""Clears the model ID/version and s3 cache."""
448
- self ._s3_cache .clear ()
493
+ self ._content_cache .clear ()
449
494
self ._model_id_semantic_version_manifest_key_cache .clear ()
0 commit comments