diff --git a/src/sagemaker/image_uri_config/pytorch-smp.json b/src/sagemaker/image_uri_config/pytorch-smp.json index 518da5f15d..61971e5128 100644 --- a/src/sagemaker/image_uri_config/pytorch-smp.json +++ b/src/sagemaker/image_uri_config/pytorch-smp.json @@ -7,7 +7,8 @@ "2.0": "2.0.1", "2.1": "2.1.2", "2.2": "2.3.1", - "2.2.0": "2.3.1" + "2.2.0": "2.3.1", + "2.3": "2.4.0" }, "versions": { "2.0.1": { @@ -134,6 +135,31 @@ "us-west-2": "658645717510" }, "repository": "smdistributed-modelparallel" + }, + "2.4.0": { + "py_versions": [ + "py311" + ], + "registries": { + "ap-northeast-1": "658645717510", + "ap-northeast-2": "658645717510", + "ap-northeast-3": "658645717510", + "ap-south-1": "658645717510", + "ap-southeast-1": "658645717510", + "ap-southeast-2": "658645717510", + "ca-central-1": "658645717510", + "eu-central-1": "658645717510", + "eu-north-1": "658645717510", + "eu-west-1": "658645717510", + "eu-west-2": "658645717510", + "eu-west-3": "658645717510", + "sa-east-1": "658645717510", + "us-east-1": "658645717510", + "us-east-2": "658645717510", + "us-west-1": "658645717510", + "us-west-2": "658645717510" + }, + "repository": "smdistributed-modelparallel" } } } diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index be5167dcc7..2bef305aeb 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -686,6 +686,7 @@ def get_training_image_uri( "p5" in instance_type or "2.1" in framework_version or "2.2" in framework_version + or "2.3" in framework_version ): container_version = "cu121" else: diff --git a/tests/unit/sagemaker/image_uris/test_smp_v2.py b/tests/unit/sagemaker/image_uris/test_smp_v2.py index b53a45133e..e9c8cec292 100644 --- a/tests/unit/sagemaker/image_uris/test_smp_v2.py +++ b/tests/unit/sagemaker/image_uris/test_smp_v2.py @@ -35,7 +35,7 @@ def test_smp_v2(load_config): for region in ACCOUNTS.keys(): for instance_type in CONTAINER_VERSIONS.keys(): cuda_vers = CONTAINER_VERSIONS[instance_type] - if "2.1" in version or "2.2" in version: + if "2.1" in version or "2.2" in version or "2.3" in version: cuda_vers = "cu121" uri = image_uris.get_training_image_uri(