Skip to content

Convert pytorchddp distribution to smdistributed distribution #4698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 2 additions & 97 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,6 @@
],
}

# PyTorch framework versions accepted by ``validate_pytorch_distribution`` when
# the ``pytorchddp`` distribution strategy is enabled and no ``image_uri`` is
# given; any other version raises a ValueError there.
PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [
    "1.10",
    "1.10.0",
    "1.10.2",
    "1.11",
    "1.11.0",
    "1.12",
    "1.12.0",
    "1.12.1",
    "1.13.1",
    "2.0.0",
    "2.0.1",
    "2.1.0",
    "2.2.0",
]

TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
"1.13.1",
"2.0.0",
Expand Down Expand Up @@ -795,7 +779,6 @@ def _validate_smdataparallel_args(

Raises:
ValueError: if
(`instance_type` is not in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES or
`py_version` is not python3 or
`framework_version` is not in SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSION
"""
Expand All @@ -806,17 +789,10 @@ def _validate_smdataparallel_args(
if not smdataparallel_enabled:
return

is_instance_type_supported = instance_type in SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES

err_msg = ""

if not is_instance_type_supported:
# instance_type is required
err_msg += (
f"Provided instance_type {instance_type} is not supported by smdataparallel.\n"
"Please specify one of the supported instance types:"
f"{SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES}\n"
)
if not instance_type:
err_msg += "Please specify an instance_type for smdataparallel.\n"

if not image_uri:
# ignore framework_version & py_version if image_uri is set
Expand Down Expand Up @@ -928,13 +904,6 @@ def validate_distribution(
)
if framework_name and framework_name == "pytorch":
# We need to validate only for PyTorch framework
validate_pytorch_distribution(
distribution=validated_distribution,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
image_uri=image_uri,
)
validate_torch_distributed_distribution(
instance_type=instance_type,
distribution=validated_distribution,
Expand Down Expand Up @@ -968,13 +937,6 @@ def validate_distribution(
)
if framework_name and framework_name == "pytorch":
# We need to validate only for PyTorch framework
validate_pytorch_distribution(
distribution=validated_distribution,
framework_name=framework_name,
framework_version=framework_version,
py_version=py_version,
image_uri=image_uri,
)
validate_torch_distributed_distribution(
instance_type=instance_type,
distribution=validated_distribution,
Expand Down Expand Up @@ -1023,63 +985,6 @@ def validate_distribution_for_instance_type(instance_type, distribution):
raise ValueError(err_msg)


def validate_pytorch_distribution(
    distribution, framework_name, framework_version, py_version, image_uri
):
    """Validate the user-supplied ``pytorchddp`` distribution configuration.

    Args:
        distribution (dict): A dictionary with information to enable distributed
            training. (Defaults to None if distributed training is not enabled.)
            For example:

            .. code:: python

                {
                    "pytorchddp": {
                        "enabled": True
                    }
                }
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if
            `py_version` is not python3 or
            `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
    """
    # Only the PyTorch framework uses the pytorchddp strategy.
    if framework_name and framework_name != "pytorch":
        return

    # Nothing to validate unless pytorchddp is present and enabled.
    ddp_enabled = False
    if "pytorchddp" in distribution:
        ddp_enabled = distribution["pytorchddp"].get("enabled", False)
    if not ddp_enabled:
        return

    # When a custom image_uri is supplied, framework_version and py_version are
    # ignored; otherwise both must be supported.
    problems = []
    if not image_uri:
        if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS:
            problems.append(
                f"Provided framework_version {framework_version} is not supported by"
                " pytorchddp.\n"
                "Please specify one of the supported framework versions:"
                f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n"
            )
        if "py3" not in py_version:
            problems.append(
                f"Provided py_version {py_version} is not supported by pytorchddp.\n"
                "Please specify py_version>=py3"
            )
    if problems:
        raise ValueError("".join(problems))


def validate_torch_distributed_distribution(
instance_type,
distribution,
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,20 @@ def __init__(
kwargs["entry_point"] = entry_point

if distribution is not None:
# rewrite pytorchddp to smdistributed
if "pytorchddp" in distribution:
if "smdistributed" in distribution:
raise ValueError(
"Cannot use both pytorchddp and smdistributed "
"distribution options together.",
distribution,
)

# convert pytorchddp distribution into smdistributed distribution
distribution = distribution.copy()
distribution["smdistributed"] = {"dataparallel": distribution["pytorchddp"]}
del distribution["pytorchddp"]

distribution = validate_distribution(
distribution,
self.instance_groups,
Expand Down
75 changes: 2 additions & 73 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,17 +854,14 @@ def test_validate_smdataparallel_args_raises():

# Cases {PT|TF2}
# 1. None instance type
# 2. incorrect instance type
# 3. incorrect python version
# 4. incorrect framework version
# 2. incorrect python version
# 3. incorrect framework version

bad_args = [
(None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled),
("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled),
(None, "pytorch", "1.6.0", "py3", smdataparallel_enabled),
("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled),
("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled),
]
Expand Down Expand Up @@ -966,74 +963,6 @@ def test_validate_smdataparallel_args_not_raises():
)


def test_validate_pytorchddp_not_raises():
    """Valid pytorchddp configurations must pass validation silently."""
    # Case 1: Framework is not PyTorch — validation is skipped entirely.
    fw_utils.validate_pytorch_distribution(
        distribution=None,
        framework_name="tensorflow",
        framework_version="2.9.1",
        py_version="py3",
        image_uri="custom-container",
    )
    # Case 2: Framework is PyTorch, but pytorchddp is explicitly disabled.
    ddp_disabled = {"pytorchddp": {"enabled": False}}
    fw_utils.validate_pytorch_distribution(
        distribution=ddp_disabled,
        framework_name="pytorch",
        framework_version="1.10",
        py_version="py3",
        image_uri="custom-container",
    )
    # Case 3: pytorchddp enabled with each supported framework/py version combo.
    ddp_enabled = {"pytorchddp": {"enabled": True}}
    supported_versions = [
        "1.10",
        "1.10.0",
        "1.10.2",
        "1.11",
        "1.11.0",
        "1.12",
        "1.12.0",
        "1.12.1",
        "1.13.1",
        "2.0.0",
        "2.0.1",
        "2.1.0",
        "2.2.0",
    ]
    for fw_version in supported_versions:
        fw_utils.validate_pytorch_distribution(
            distribution=ddp_enabled,
            framework_name="pytorch",
            framework_version=fw_version,
            py_version="py3",
            image_uri="custom-container",
        )


def test_validate_pytorchddp_raises():
    """Unsupported framework or python versions must raise ValueError."""
    ddp_enabled = {"pytorchddp": {"enabled": True}}
    # (framework_version, py_version) pairs that validation must reject:
    # Case 1: unsupported framework version; Case 2: unsupported py version.
    bad_combos = [
        ("1.8", "py3"),
        ("1.10", "py2"),
    ]
    for fw_version, python_version in bad_combos:
        with pytest.raises(ValueError):
            fw_utils.validate_pytorch_distribution(
                distribution=ddp_enabled,
                framework_name="pytorch",
                framework_version=fw_version,
                py_version=python_version,
                image_uri=None,
            )


def test_validate_torch_distributed_not_raises():
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,14 +801,15 @@ def test_pytorch_ddp_distribution_configuration(
distribution=pytorch.distribution
)
expected_torch_ddp = {
"sagemaker_pytorch_ddp_enabled": True,
"sagemaker_distributed_dataparallel_enabled": True,
"sagemaker_distributed_dataparallel_custom_mpi_options": "",
"sagemaker_instance_type": test_instance_type,
}
assert actual_pytorch_ddp == expected_torch_ddp


def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session):
unsupported_framework_version = "1.9.1"
unsupported_framework_version = "1.5.0"
unsupported_py_version = "py2"
with pytest.raises(ValueError) as error:
_pytorch_estimator(
Expand Down