|
26 | 26 | from sagemaker.inputs import FileSystemInput, TrainingInput |
27 | 27 | from sagemaker.model import NEO_IMAGE_ACCOUNT |
28 | 28 | from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix |
29 | | -from sagemaker.xgboost.defaults import ( |
30 | | - XGBOOST_1P_VERSIONS, |
31 | | - XGBOOST_LATEST_VERSION, |
32 | | - XGBOOST_NAME, |
33 | | - XGBOOST_SUPPORTED_VERSIONS, |
34 | | - XGBOOST_VERSION_EQUIVALENTS, |
35 | | -) |
36 | | -from sagemaker.xgboost.estimator import get_xgboost_image_uri |
37 | 29 |
|
38 | 30 | logger = logging.getLogger(__name__) |
39 | 31 |
|
@@ -622,76 +614,5 @@ def get_image_uri(region_name, repo_name, repo_version=1): |
622 | 614 | "in SageMaker Python SDK v2." |
623 | 615 | ) |
624 | 616 |
|
625 | | - repo_version = str(repo_version) |
626 | | - |
627 | | - if repo_name == XGBOOST_NAME: |
628 | | - |
629 | | - if repo_version in XGBOOST_1P_VERSIONS: |
630 | | - _warn_newer_xgboost_image() |
631 | | - return "{}/{}:{}".format(registry(region_name, repo_name), repo_name, repo_version) |
632 | | - |
633 | | - if "-" not in repo_version: |
634 | | - xgboost_version_matches = [ |
635 | | - version |
636 | | - for version in XGBOOST_SUPPORTED_VERSIONS |
637 | | - if repo_version == version.split("-")[0] |
638 | | - ] |
639 | | - if xgboost_version_matches: |
640 | | - # Assumes that XGBOOST_SUPPORTED_VERSION is sorted from oldest version to latest. |
641 | | - # When SageMaker version is not specified, we use the oldest one that matches |
642 | | - # XGBoost version for backward compatibility. |
643 | | - repo_version = xgboost_version_matches[0] |
644 | | - |
645 | | - supported_framework_versions = [ |
646 | | - version |
647 | | - for version in XGBOOST_SUPPORTED_VERSIONS |
648 | | - if repo_version in _generate_version_equivalents(version) |
649 | | - ] |
650 | | - |
651 | | - if not supported_framework_versions: |
652 | | - raise ValueError( |
653 | | - "SageMaker XGBoost version {} is not supported. Supported versions: {}".format( |
654 | | - repo_version, ", ".join(XGBOOST_SUPPORTED_VERSIONS) |
655 | | - ) |
656 | | - ) |
657 | | - |
658 | | - if not _is_latest_xgboost_version(repo_version): |
659 | | - _warn_newer_xgboost_image() |
660 | | - |
661 | | - return get_xgboost_image_uri(region_name, supported_framework_versions[-1]) |
662 | | - |
663 | 617 | repo = "{}:{}".format(repo_name, repo_version) |
664 | 618 | return "{}/{}".format(registry(region_name, repo_name), repo) |
665 | | - |
666 | | - |
667 | | -def _warn_newer_xgboost_image(): |
668 | | - """Print a warning when there is a newer XGBoost image""" |
669 | | - logging.warning( |
670 | | - "There is a more up to date SageMaker XGBoost image. " |
671 | | - "To use the newer image, please set 'repo_version'=" |
672 | | - "'%s'. For example:\n" |
673 | | - "\tget_image_uri(region, '%s', '%s').", |
674 | | - XGBOOST_LATEST_VERSION, |
675 | | - XGBOOST_NAME, |
676 | | - XGBOOST_LATEST_VERSION, |
677 | | - ) |
678 | | - |
679 | | - |
680 | | -def _is_latest_xgboost_version(repo_version): |
681 | | - """Compare xgboost image version with latest version |
682 | | -
|
683 | | - Args: |
684 | | - repo_version: |
685 | | - """ |
686 | | - if repo_version in XGBOOST_1P_VERSIONS: |
687 | | - return False |
688 | | - return repo_version in _generate_version_equivalents(XGBOOST_LATEST_VERSION) |
689 | | - |
690 | | - |
691 | | -def _generate_version_equivalents(version): |
692 | | - """Returns a list of version equivalents for XGBoost |
693 | | -
|
694 | | - Args: |
695 | | - version: |
696 | | - """ |
697 | | - return [version + suffix for suffix in XGBOOST_VERSION_EQUIVALENTS] + [version] |
0 commit comments