@@ -177,9 +177,7 @@ def set_manifest_file_s3_key(
177
177
}
178
178
property_name = file_mapping .get (file_type )
179
179
if not property_name :
180
- raise ValueError (
181
- self ._file_type_error_msg (file_type , manifest_only = True )
182
- )
180
+ raise ValueError (self ._file_type_error_msg (file_type , manifest_only = True ))
183
181
if key != property_name :
184
182
setattr (self , property_name , key )
185
183
self .clear ()
@@ -192,9 +190,7 @@ def get_manifest_file_s3_key(
192
190
return self ._manifest_file_s3_key
193
191
if file_type == JumpStartS3FileType .PROPRIETARY_MANIFEST :
194
192
return self ._proprietary_manifest_s3_key
195
- raise ValueError (
196
- self ._file_type_error_msg (file_type , manifest_only = True )
197
- )
193
+ raise ValueError (self ._file_type_error_msg (file_type , manifest_only = True ))
198
194
199
195
def set_s3_bucket_name (self , s3_bucket_name : str ) -> None :
200
196
"""Set s3 bucket used for cache."""
@@ -247,7 +243,8 @@ def _model_id_retrieval_function(
247
243
sm_version = utils .get_sagemaker_version ()
248
244
manifest = self ._content_cache .get (
249
245
JumpStartCachedContentKey (
250
- MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
246
+ MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ]
247
+ )
251
248
)[0 ].formatted_content
252
249
253
250
versions_compatible_with_sagemaker = [
@@ -264,7 +261,8 @@ def _model_id_retrieval_function(
264
261
return JumpStartVersionedModelId (model_id , sm_compatible_model_version )
265
262
266
263
versions_incompatible_with_sagemaker = [
267
- Version (header .version ) for header in manifest .values () # type: ignore
264
+ Version (header .version )
265
+ for header in manifest .values () # type: ignore
268
266
if header .model_id == model_id
269
267
]
270
268
sm_incompatible_model_version = self ._select_version (
@@ -294,9 +292,7 @@ def _model_id_retrieval_function(
294
292
raise KeyError (error_msg )
295
293
296
294
error_msg = f"Unable to find model manifest for '{ model_id } ' with version '{ version } '. "
297
- error_msg += (
298
- f"Visit { MODEL_ID_LIST_WEB_URL } for updated list of models. "
299
- )
295
+ error_msg += f"Visit { MODEL_ID_LIST_WEB_URL } for updated list of models. "
300
296
301
297
other_model_id_version = None
302
298
if model_type == JumpStartModelType .OPEN_WEIGHTS :
@@ -305,19 +301,17 @@ def _model_id_retrieval_function(
305
301
) # all versions here are incompatible with sagemaker
306
302
elif model_type == JumpStartModelType .PROPRIETARY :
307
303
all_possible_model_id_version = [
308
- header .version for header in manifest .values () # type: ignore
304
+ header .version
305
+ for header in manifest .values () # type: ignore
309
306
if header .model_id == model_id
310
307
]
311
308
other_model_id_version = (
312
- None
313
- if not all_possible_model_id_version
314
- else all_possible_model_id_version [0 ]
309
+ None if not all_possible_model_id_version else all_possible_model_id_version [0 ]
315
310
)
316
311
317
312
if other_model_id_version is not None :
318
313
error_msg += (
319
- f"Consider using model ID '{ model_id } ' with version "
320
- f"'{ other_model_id_version } '."
314
+ f"Consider using model ID '{ model_id } ' with version " f"'{ other_model_id_version } '."
321
315
)
322
316
else :
323
317
possible_model_ids = [header .model_id for header in manifest .values ()] # type: ignore
@@ -359,15 +353,15 @@ def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list],
359
353
360
354
def _is_local_metadata_mode (self ) -> bool :
361
355
"""Returns True if the cache should use local metadata mode, based off env variables."""
362
- return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os .environ
363
- and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ])
364
- and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os .environ
365
- and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ]))
356
+ return (
357
+ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os .environ
358
+ and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ])
359
+ and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os .environ
360
+ and os .path .isdir (os .environ [ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ])
361
+ )
366
362
367
363
def _get_json_file (
368
- self ,
369
- key : str ,
370
- filetype : JumpStartS3FileType
364
+ self , key : str , filetype : JumpStartS3FileType
371
365
) -> Tuple [Union [dict , list ], Optional [str ]]:
372
366
"""Returns json file either from s3 or local file system.
373
367
@@ -391,21 +385,19 @@ def _get_json_md5_hash(self, key: str):
391
385
return self ._s3_client .head_object (Bucket = self .s3_bucket_name , Key = key )["ETag" ]
392
386
393
387
def _get_json_file_from_local_override (
394
- self ,
395
- key : str ,
396
- filetype : JumpStartS3FileType
388
+ self , key : str , filetype : JumpStartS3FileType
397
389
) -> Union [dict , list ]:
398
390
"""Reads json file from local filesystem and returns data."""
399
391
if filetype == JumpStartS3FileType .OPEN_WEIGHT_MANIFEST :
400
- metadata_local_root = (
401
- os . environ [ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ]
402
- )
392
+ metadata_local_root = os . environ [
393
+ ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE
394
+ ]
403
395
elif filetype == JumpStartS3FileType .OPEN_WEIGHT_SPECS :
404
396
metadata_local_root = os .environ [ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ]
405
397
else :
406
398
raise ValueError (f"Unsupported file type for local override: { filetype } " )
407
399
file_path = os .path .join (metadata_local_root , key )
408
- with open (file_path , 'r' ) as f :
400
+ with open (file_path , "r" ) as f :
409
401
data = json .load (f )
410
402
return data
411
403
@@ -450,9 +442,7 @@ def _retrieval_function(
450
442
formatted_body , _ = self ._get_json_file (id_info , data_type )
451
443
model_specs = JumpStartModelSpecs (formatted_body )
452
444
utils .emit_logs_based_on_model_specs (model_specs , self .get_region (), self ._s3_client )
453
- return JumpStartCachedContentValue (
454
- formatted_content = model_specs
455
- )
445
+ return JumpStartCachedContentValue (formatted_content = model_specs )
456
446
457
447
if data_type == HubContentType .MODEL :
458
448
hub_name , _ , model_name , model_version = hub_utils .get_info_from_hub_resource_arn (
@@ -462,21 +452,15 @@ def _retrieval_function(
462
452
hub_name = hub_name ,
463
453
hub_content_name = model_name ,
464
454
hub_content_version = model_version ,
465
- hub_content_type = data_type
455
+ hub_content_type = data_type ,
466
456
)
467
457
468
458
model_specs = JumpStartModelSpecs (
469
459
DescribeHubContentsResponse (hub_model_description ), is_hub_content = True
470
460
)
471
461
472
- utils .emit_logs_based_on_model_specs (
473
- model_specs ,
474
- self .get_region (),
475
- self ._s3_client
476
- )
477
- return JumpStartCachedContentValue (
478
- formatted_content = model_specs
479
- )
462
+ utils .emit_logs_based_on_model_specs (model_specs , self .get_region (), self ._s3_client )
463
+ return JumpStartCachedContentValue (formatted_content = model_specs )
480
464
481
465
if data_type == HubType .HUB :
482
466
hub_name , _ , _ , _ = hub_utils .get_info_from_hub_resource_arn (id_info )
@@ -486,9 +470,7 @@ def _retrieval_function(
486
470
formatted_content = DescribeHubResponse (hub_description )
487
471
)
488
472
489
- raise ValueError (
490
- self ._file_type_error_msg (data_type )
491
- )
473
+ raise ValueError (self ._file_type_error_msg (data_type ))
492
474
493
475
def get_manifest (
494
476
self ,
@@ -497,7 +479,8 @@ def get_manifest(
497
479
"""Return entire JumpStart models manifest."""
498
480
manifest_dict = self ._content_cache .get (
499
481
JumpStartCachedContentKey (
500
- MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
482
+ MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ]
483
+ )
501
484
)[0 ].formatted_content
502
485
manifest = list (manifest_dict .values ()) # type: ignore
503
486
return manifest
@@ -554,16 +537,14 @@ def _select_version(
554
537
except InvalidSpecifier :
555
538
raise KeyError (f"Bad semantic version: { version_str } " )
556
539
available_versions_filtered = list (spec .filter (available_versions ))
557
- return (
558
- str (max (available_versions_filtered )) if available_versions_filtered != [] else None
559
- )
540
+ return str (max (available_versions_filtered )) if available_versions_filtered != [] else None
560
541
561
542
def _get_header_impl (
562
543
self ,
563
544
model_id : str ,
564
545
semantic_version_str : str ,
565
546
attempt : int = 0 ,
566
- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS
547
+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
567
548
) -> JumpStartModelHeader :
568
549
"""Lower-level function to return header.
569
550
@@ -586,7 +567,8 @@ def _get_header_impl(
586
567
587
568
manifest = self ._content_cache .get (
588
569
JumpStartCachedContentKey (
589
- MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ])
570
+ MODEL_TYPE_TO_MANIFEST_MAP [model_type ], self ._manifest_file_s3_map [model_type ]
571
+ )
590
572
)[0 ].formatted_content
591
573
592
574
try :
@@ -602,7 +584,7 @@ def get_specs(
602
584
self ,
603
585
model_id : str ,
604
586
version_str : str ,
605
- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS
587
+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
606
588
) -> JumpStartModelSpecs :
607
589
"""Return specs for a given JumpStart model ID and semantic version.
608
590
@@ -615,16 +597,12 @@ def get_specs(
615
597
header = self .get_header (model_id , version_str , model_type )
616
598
spec_key = header .spec_key
617
599
specs , cache_hit = self ._content_cache .get (
618
- JumpStartCachedContentKey (
619
- MODEL_TYPE_TO_SPECS_MAP [model_type ], spec_key
620
- )
600
+ JumpStartCachedContentKey (MODEL_TYPE_TO_SPECS_MAP [model_type ], spec_key )
621
601
)
622
602
623
603
if not cache_hit and "*" in version_str :
624
604
JUMPSTART_LOGGER .warning (
625
- get_wildcard_model_version_msg (
626
- header .model_id , version_str , header .version
627
- )
605
+ get_wildcard_model_version_msg (header .model_id , version_str , header .version )
628
606
)
629
607
return specs .formatted_content
630
608
0 commit comments