diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 992cafcbea..17f17f1a13 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -30,7 +30,8 @@ from sagemaker.xgboost.defaults import ( XGBOOST_LATEST_VERSION, XGBOOST_SUPPORTED_VERSIONS, - XGBOOST_VERSION_1, + XGBOOST_VERSION_0_90_1, + XGBOOST_VERSION_0_90, XGBOOST_VERSION_EQUIVALENTS, ) from sagemaker.xgboost.estimator import get_xgboost_image_uri @@ -626,8 +627,10 @@ def get_image_uri(region_name, repo_name, repo_version=1): XGBOOST_LATEST_VERSION, ) - if repo_version in ["0.90", "0.90-1", "0.90-1-cpu-py3"]: - return get_xgboost_image_uri(region_name, XGBOOST_VERSION_1) + if repo_version in [XGBOOST_VERSION_0_90] + _generate_version_equivalents( + XGBOOST_VERSION_0_90_1 + ): + return get_xgboost_image_uri(region_name, XGBOOST_VERSION_0_90_1) supported_version = [ version diff --git a/src/sagemaker/xgboost/defaults.py b/src/sagemaker/xgboost/defaults.py index 9c957f4929..c287626943 100644 --- a/src/sagemaker/xgboost/defaults.py +++ b/src/sagemaker/xgboost/defaults.py @@ -14,7 +14,13 @@ from __future__ import absolute_import XGBOOST_NAME = "xgboost" -XGBOOST_VERSION_1 = "0.90-1" -XGBOOST_LATEST_VERSION = "0.90-2" -XGBOOST_SUPPORTED_VERSIONS = [XGBOOST_VERSION_1, XGBOOST_LATEST_VERSION] +XGBOOST_VERSION_0_90 = "0.90" +XGBOOST_VERSION_0_90_1 = "0.90-1" +XGBOOST_VERSION_0_90_2 = "0.90-2" +XGBOOST_LATEST_VERSION = "1.0-1" +XGBOOST_SUPPORTED_VERSIONS = [ + XGBOOST_VERSION_0_90_1, + XGBOOST_VERSION_0_90_2, + XGBOOST_LATEST_VERSION, +] XGBOOST_VERSION_EQUIVALENTS = ["-cpu-py3"]