|
145 | 145 | ],
|
146 | 146 | }
|
147 | 147 |
|
148 |
| -PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ |
149 |
| - "1.10", |
150 |
| - "1.10.0", |
151 |
| - "1.10.2", |
152 |
| - "1.11", |
153 |
| - "1.11.0", |
154 |
| - "1.12", |
155 |
| - "1.12.0", |
156 |
| - "1.12.1", |
157 |
| - "1.13.1", |
158 |
| - "2.0.0", |
159 |
| - "2.0.1", |
160 |
| - "2.1.0", |
161 |
| - "2.2.0", |
162 |
| -] |
163 |
| - |
164 | 148 | TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
|
165 | 149 | "1.13.1",
|
166 | 150 | "2.0.0",
|
@@ -915,13 +899,6 @@ def validate_distribution(
|
915 | 899 | )
|
916 | 900 | if framework_name and framework_name == "pytorch":
|
917 | 901 | # We need to validate only for PyTorch framework
|
918 |
| - validate_pytorch_distribution( |
919 |
| - distribution=validated_distribution, |
920 |
| - framework_name=framework_name, |
921 |
| - framework_version=framework_version, |
922 |
| - py_version=py_version, |
923 |
| - image_uri=image_uri, |
924 |
| - ) |
925 | 902 | validate_torch_distributed_distribution(
|
926 | 903 | instance_type=instance_type,
|
927 | 904 | distribution=validated_distribution,
|
@@ -955,13 +932,6 @@ def validate_distribution(
|
955 | 932 | )
|
956 | 933 | if framework_name and framework_name == "pytorch":
|
957 | 934 | # We need to validate only for PyTorch framework
|
958 |
| - validate_pytorch_distribution( |
959 |
| - distribution=validated_distribution, |
960 |
| - framework_name=framework_name, |
961 |
| - framework_version=framework_version, |
962 |
| - py_version=py_version, |
963 |
| - image_uri=image_uri, |
964 |
| - ) |
965 | 935 | validate_torch_distributed_distribution(
|
966 | 936 | instance_type=instance_type,
|
967 | 937 | distribution=validated_distribution,
|
@@ -1010,63 +980,6 @@ def validate_distribution_for_instance_type(instance_type, distribution):
|
1010 | 980 | raise ValueError(err_msg)
|
1011 | 981 |
|
1012 | 982 |
|
1013 |
| -def validate_pytorch_distribution( |
1014 |
| - distribution, framework_name, framework_version, py_version, image_uri |
1015 |
| -): |
1016 |
| - """Check if pytorch distribution strategy is correctly invoked by the user. |
1017 |
| -
|
1018 |
| - Args: |
1019 |
| - distribution (dict): A dictionary with information to enable distributed training. |
1020 |
| - (Defaults to None if distributed training is not enabled.) For example: |
1021 |
| -
|
1022 |
| - .. code:: python |
1023 |
| -
|
1024 |
| - { |
1025 |
| - "pytorchddp": { |
1026 |
| - "enabled": True |
1027 |
| - } |
1028 |
| - } |
1029 |
| - framework_name (str): A string representing the name of framework selected. |
1030 |
| - framework_version (str): A string representing the framework version selected. |
1031 |
| - py_version (str): A string representing the python version selected. |
1032 |
| - image_uri (str): A string representing a Docker image URI. |
1033 |
| -
|
1034 |
| - Raises: |
1035 |
| - ValueError: if |
1036 |
| - `py_version` is not python3 or |
1037 |
| - `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS |
1038 |
| - """ |
1039 |
| - if framework_name and framework_name != "pytorch": |
1040 |
| - # We need to validate only for PyTorch framework |
1041 |
| - return |
1042 |
| - |
1043 |
| - pytorch_ddp_enabled = False |
1044 |
| - if "pytorchddp" in distribution: |
1045 |
| - pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) |
1046 |
| - if not pytorch_ddp_enabled: |
1047 |
| - # Distribution strategy other than pytorchddp is selected |
1048 |
| - return |
1049 |
| - |
1050 |
| - err_msg = "" |
1051 |
| - if not image_uri: |
1052 |
| - # ignore framework_version and py_version if image_uri is set |
1053 |
| - # in case image_uri is not set, then both are mandatory |
1054 |
| - if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: |
1055 |
| - err_msg += ( |
1056 |
| - f"Provided framework_version {framework_version} is not supported by" |
1057 |
| - " pytorchddp.\n" |
1058 |
| - "Please specify one of the supported framework versions:" |
1059 |
| - f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" |
1060 |
| - ) |
1061 |
| - if "py3" not in py_version: |
1062 |
| - err_msg += ( |
1063 |
| - f"Provided py_version {py_version} is not supported by pytorchddp.\n" |
1064 |
| - "Please specify py_version>=py3" |
1065 |
| - ) |
1066 |
| - if err_msg: |
1067 |
| - raise ValueError(err_msg) |
1068 |
| - |
1069 |
| - |
1070 | 983 | def validate_torch_distributed_distribution(
|
1071 | 984 | instance_type,
|
1072 | 985 | distribution,
|
|
0 commit comments