52
52
53
53
@override_pipeline_parameter_var
54
54
def retrieve (
55
- framework ,
56
- region ,
57
- version = None ,
58
- py_version = None ,
59
- instance_type = None ,
60
- accelerator_type = None ,
61
- image_scope = None ,
62
- container_version = None ,
63
- distribution = None ,
64
- base_framework_version = None ,
65
- training_compiler_config = None ,
66
- model_id = None ,
67
- model_version = None ,
68
- hub_arn = None ,
69
- tolerate_vulnerable_model = False ,
70
- tolerate_deprecated_model = False ,
71
- sdk_version = None ,
72
- inference_tool = None ,
73
- serverless_inference_config = None ,
74
- sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75
- config_name = None ,
76
- model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
55
+ framework ,
56
+ region ,
57
+ version = None ,
58
+ py_version = None ,
59
+ instance_type = None ,
60
+ accelerator_type = None ,
61
+ image_scope = None ,
62
+ container_version = None ,
63
+ distribution = None ,
64
+ base_framework_version = None ,
65
+ training_compiler_config = None ,
66
+ model_id = None ,
67
+ model_version = None ,
68
+ hub_arn = None ,
69
+ tolerate_vulnerable_model = False ,
70
+ tolerate_deprecated_model = False ,
71
+ sdk_version = None ,
72
+ inference_tool = None ,
73
+ serverless_inference_config = None ,
74
+ sagemaker_session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
75
+ config_name = None ,
76
+ model_type : JumpStartModelType = JumpStartModelType .OPEN_WEIGHTS ,
77
77
) -> str :
78
78
"""Retrieves the ECR URI for the Docker image matching the given arguments.
79
79
@@ -250,10 +250,10 @@ def retrieve(
250
250
if config .get ("version_aliases" ).get (original_version ):
251
251
_version = config .get ("version_aliases" )[original_version ]
252
252
if (
253
- config .get ("versions" , {})
254
- .get (_version , {})
255
- .get ("version_aliases" , {})
256
- .get (base_framework_version , {})
253
+ config .get ("versions" , {})
254
+ .get (_version , {})
255
+ .get ("version_aliases" , {})
256
+ .get (base_framework_version , {})
257
257
):
258
258
_base_framework_version = config .get ("versions" )[_version ]["version_aliases" ][
259
259
base_framework_version
@@ -290,16 +290,16 @@ def retrieve(
290
290
291
291
292
292
def _get_image_tag (
293
- container_version ,
294
- distribution ,
295
- final_image_scope ,
296
- framework ,
297
- inference_tool ,
298
- instance_type ,
299
- processor ,
300
- py_version ,
301
- tag_prefix ,
302
- version ,
293
+ container_version ,
294
+ distribution ,
295
+ final_image_scope ,
296
+ framework ,
297
+ inference_tool ,
298
+ instance_type ,
299
+ processor ,
300
+ py_version ,
301
+ tag_prefix ,
302
+ version ,
303
303
):
304
304
"""Return image tag based on framework, container, and compute configuration(s)."""
305
305
instance_type_family = utils .get_instance_type_family (instance_type )
@@ -311,8 +311,8 @@ def _get_image_tag(
311
311
"instance type" ,
312
312
)
313
313
if (
314
- instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315
- or final_image_scope == INFERENCE_GRAVITON
314
+ instance_type_family in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
315
+ or final_image_scope == INFERENCE_GRAVITON
316
316
):
317
317
version_to_arm64_tag_mapping = {
318
318
"xgboost" : {
@@ -330,7 +330,7 @@ def _get_image_tag(
330
330
tag = _format_tag (tag_prefix , processor , py_version , container_version , inference_tool )
331
331
332
332
if instance_type is not None and _should_auto_select_container_version (
333
- instance_type , distribution
333
+ instance_type , distribution
334
334
):
335
335
container_versions = {
336
336
"tensorflow-2.3-gpu-py37" : "cu110-ubuntu18.04-v3" ,
@@ -398,7 +398,7 @@ def _validate_instance_deprecation(framework, instance_type, version):
398
398
"""Check if instance type is deprecated for a certain framework with a certain version"""
399
399
if utils .get_instance_type_family (instance_type ) == "p2" :
400
400
if (framework == "pytorch" and Version (version ) >= Version ("1.13" )) or (
401
- framework == "tensorflow" and Version (version ) >= Version ("2.12" )
401
+ framework == "tensorflow" and Version (version ) >= Version ("2.12" )
402
402
):
403
403
raise ValueError (
404
404
"P2 instances have been deprecated for sagemaker jobs starting PyTorch 1.13 and TensorFlow 2.12"
@@ -411,17 +411,17 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
411
411
"""Validate if framework is supported for the instance_type"""
412
412
# Validate for Trainium allowed frameworks
413
413
if (
414
- instance_type is not None
415
- and "trn" in instance_type
416
- and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
414
+ instance_type is not None
415
+ and "trn" in instance_type
416
+ and framework not in TRAINIUM_ALLOWED_FRAMEWORKS
417
417
):
418
418
_validate_framework (framework , TRAINIUM_ALLOWED_FRAMEWORKS , "framework" , "Trainium" )
419
419
420
420
# Validate for Graviton allowed frameworks
421
421
if (
422
- instance_type is not None
423
- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424
- and framework not in GRAVITON_ALLOWED_FRAMEWORKS
422
+ instance_type is not None
423
+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
424
+ and framework not in GRAVITON_ALLOWED_FRAMEWORKS
425
425
):
426
426
_validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
427
427
@@ -436,8 +436,8 @@ def config_for_framework(framework):
436
436
def _get_final_image_scope (framework , instance_type , image_scope ):
437
437
"""Return final image scope based on provided framework and instance type."""
438
438
if (
439
- framework in GRAVITON_ALLOWED_FRAMEWORKS
440
- and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
439
+ framework in GRAVITON_ALLOWED_FRAMEWORKS
440
+ and utils .get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
441
441
):
442
442
return INFERENCE_GRAVITON
443
443
if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -635,16 +635,16 @@ def _format_tag(tag_prefix, processor, py_version, container_version, inference_
635
635
636
636
@override_pipeline_parameter_var
637
637
def get_training_image_uri (
638
- region ,
639
- framework ,
640
- framework_version = None ,
641
- py_version = None ,
642
- image_uri = None ,
643
- distribution = None ,
644
- compiler_config = None ,
645
- tensorflow_version = None ,
646
- pytorch_version = None ,
647
- instance_type = None ,
638
+ region ,
639
+ framework ,
640
+ framework_version = None ,
641
+ py_version = None ,
642
+ image_uri = None ,
643
+ distribution = None ,
644
+ compiler_config = None ,
645
+ tensorflow_version = None ,
646
+ pytorch_version = None ,
647
+ instance_type = None ,
648
648
) -> str :
649
649
"""Retrieves the image URI for training.
650
650
@@ -748,26 +748,28 @@ def get_base_python_image_uri(region, py_version="310") -> str:
748
748
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
749
749
750
750
751
- def get_latest_container_image (framework : str ,
752
- image_scope : Optional [str ] = None ,
753
- instance_type : Optional [str ] = None ,
754
- py_version : Optional [str ] = None ,
755
- region : str = "us-west-2" ,
756
- version : Optional [str ] = None ,
757
- accelerator_type = None ,
758
- container_version = None ,
759
- distribution = None ,
760
- base_framework_version = None ,
761
- training_compiler_config = None ,
762
- model_id = None ,
763
- model_version = None ,
764
- hub_arn = None ,
765
- sdk_version = None ,
766
- inference_tool = None ,
767
- serverless_inference_config = None ,
768
- config_name = None ,
769
- ) -> Tuple [str , str ]:
751
+ def get_latest_container_image (
752
+ framework : str ,
753
+ image_scope : Optional [str ] = None ,
754
+ instance_type : Optional [str ] = None ,
755
+ py_version : Optional [str ] = None ,
756
+ region : str = "us-west-2" ,
757
+ version : Optional [str ] = None ,
758
+ accelerator_type = None ,
759
+ container_version = None ,
760
+ distribution = None ,
761
+ base_framework_version = None ,
762
+ training_compiler_config = None ,
763
+ model_id = None ,
764
+ model_version = None ,
765
+ hub_arn = None ,
766
+ sdk_version = None ,
767
+ inference_tool = None ,
768
+ serverless_inference_config = None ,
769
+ config_name = None ,
770
+ ) -> Tuple [str , str ]:
770
771
"""Retrieves the latest container image URI
772
+
771
773
Args:
772
774
framework (str): The name of the framework or algorithm.
773
775
image_scope (str): The image type, i.e. what it is used for.
@@ -818,31 +820,34 @@ def get_latest_container_image(framework: str,
818
820
819
821
if not version :
820
822
version = _fetch_latest_version_from_config (framework_config , image_scope )
821
- image_uri = retrieve (framework = framework ,
822
- region = region ,
823
- version = version ,
824
- instance_type = instance_type ,
825
- py_version = py_version ,
826
- accelerator_type = accelerator_type ,
827
- image_scope = image_scope ,
828
- container_version = container_version ,
829
- distribution = distribution ,
830
- base_framework_version = base_framework_version ,
831
- training_compiler_config = training_compiler_config ,
832
- model_id = model_id ,
833
- model_version = model_version ,
834
- hub_arn = hub_arn ,
835
- sdk_version = sdk_version ,
836
- inference_tool = inference_tool ,
837
- serverless_inference_config = serverless_inference_config ,
838
- config_name = config_name
839
- )
823
+ image_uri = retrieve (
824
+ framework = framework ,
825
+ region = region ,
826
+ version = version ,
827
+ instance_type = instance_type ,
828
+ py_version = py_version ,
829
+ accelerator_type = accelerator_type ,
830
+ image_scope = image_scope ,
831
+ container_version = container_version ,
832
+ distribution = distribution ,
833
+ base_framework_version = base_framework_version ,
834
+ training_compiler_config = training_compiler_config ,
835
+ model_id = model_id ,
836
+ model_version = model_version ,
837
+ hub_arn = hub_arn ,
838
+ sdk_version = sdk_version ,
839
+ inference_tool = inference_tool ,
840
+ serverless_inference_config = serverless_inference_config ,
841
+ config_name = config_name ,
842
+ )
840
843
return image_uri , version
841
844
842
845
843
- def _fetch_latest_version_from_config (framework_config : dict ,
844
- image_scope : Optional [str ] = None ) -> Optional [str ]:
845
- """ Helper function to fetch the latest version as a string from a framework's config
846
+ def _fetch_latest_version_from_config (
847
+ framework_config : dict , image_scope : Optional [str ] = None
848
+ ) -> Optional [str ]:
849
+ """Helper function to fetch the latest version as a string from a framework's config
850
+
846
851
Args:
847
852
framework_config (dict): A framework config dict.
848
853
image_scope (str): Scope of the image, eg: training, inference
@@ -863,8 +868,11 @@ def _fetch_latest_version_from_config(framework_config: dict,
863
868
bottom_version = versions [- 1 ]
864
869
if top_version == "latest" or bottom_version == "latest" :
865
870
return None
866
- elif (image_scope is not None and image_scope in framework_config
867
- and "versions" in framework_config [image_scope ]):
871
+ elif (
872
+ image_scope is not None
873
+ and image_scope in framework_config
874
+ and "versions" in framework_config [image_scope ]
875
+ ):
868
876
versions = list (framework_config [image_scope ]["versions" ].keys ())
869
877
top_version = versions [0 ]
870
878
bottom_version = versions [- 1 ]
0 commit comments