
Commit 570c678

Rui Wang Napieralski (icywang86rui) and Rui Wang Napieralski authored
change: refactor distribution config construction (#2099)
Co-authored-by: Rui Wang Napieralski <[email protected]>
1 parent 1d84c6b commit 570c678

4 files changed, +86 -70 lines changed
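
The refactor moves the duplicated distribution-to-hyperparameter mapping out of the PyTorch and TensorFlow estimators and into a single Framework._distribution_configuration helper; the estimators' public behavior is unchanged. A minimal sketch of that public behavior, in Python (the entry point, role, versions, and instance settings below are illustrative placeholders; the hyperparameter key names come from the unit test added in this commit):

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",  # placeholder training script
    role="arn:aws:iam::123456789012:role/ExampleSageMakerRole",  # placeholder role ARN
    framework_version="1.6.0",
    py_version="py3",
    instance_count=2,
    instance_type="ml.p3.16xlarge",
    distribution={"mpi": {"enabled": True, "processes_per_host": 8}},
)

# Both before and after this change, the distribution settings end up in
# hyperparameters() under the same keys (values are JSON-encoded):
#   sagemaker_mpi_enabled, sagemaker_mpi_num_of_processes_per_host,
#   sagemaker_mpi_custom_mpi_options
print(estimator.hyperparameters())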

src/sagemaker/estimator.py

Lines changed: 45 additions & 0 deletions

@@ -49,6 +49,7 @@
     UploadedCode,
     validate_source_dir,
     _region_supports_debugger,
+    get_mp_parameters,
 )
 from sagemaker.inputs import TrainingInput
 from sagemaker.job import _Job

@@ -2539,6 +2540,50 @@ def transformer(
             sagemaker_session=self.sagemaker_session,
         )

+    def _distribution_configuration(self, distribution):
+        """Returns a dict of distribution configurations.
+
+        Args:
+            distribution (dict): A dictionary with information on how to run distributed training.
+
+        Returns:
+            dict that contains the distribution configuration for the training job.
+        """
+        distribution_config = {}
+
+        if "parameter_server" in distribution:
+            ps_enabled = distribution.get("parameter_server").get("enabled", False)
+            distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled
+
+        if "mpi" in distribution:
+            mpi_dict = distribution["mpi"]
+            mpi_enabled = mpi_dict.get("enabled", False)
+            distribution_config[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
+
+            if mpi_dict.get("processes_per_host"):
+                distribution_config[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
+                    "processes_per_host"
+                )
+
+            distribution_config[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
+                "custom_mpi_options", ""
+            )
+
+            if get_mp_parameters(distribution):
+                distribution_config["mp_parameters"] = get_mp_parameters(distribution)
+
+        elif "modelparallel" in distribution.get("smdistributed", {}):
+            raise ValueError("Cannot use Model Parallelism without MPI enabled!")
+
+        if "smdistributed" in distribution:
+            # smdistributed strategy selected
+            smdistributed = distribution["smdistributed"]
+            smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
+            distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
+            distribution_config[self.INSTANCE_TYPE] = self.instance_type
+
+        return distribution_config
+

 def _s3_uri_prefix(channel_name, s3_data):
     """Placeholder docstring"""

src/sagemaker/pytorch/estimator.py

Lines changed: 3 additions & 34 deletions

@@ -26,7 +26,6 @@
     validate_version_or_image_args,
     warn_if_parameter_server_with_multi_gpu,
     validate_smdistributed,
-    get_mp_parameters,
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel

@@ -190,39 +189,9 @@ def __init__(
     def hyperparameters(self):
         """Return hyperparameters used by your custom PyTorch code during model training."""
         hyperparameters = super(PyTorch, self).hyperparameters()
-        additional_hyperparameters = {}
-
-        if "parameter_server" in self.distribution:
-            ps_enabled = self.distribution.get("parameter_server").get("enabled", False)
-            additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
-
-        if "mpi" in self.distribution:
-            mpi_dict = self.distribution["mpi"]
-            mpi_enabled = mpi_dict.get("enabled", False)
-            additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
-
-            if mpi_dict.get("processes_per_host"):
-                additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
-                    "processes_per_host"
-                )
-
-            additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
-                "custom_mpi_options", ""
-            )
-
-            if get_mp_parameters(self.distribution):
-                additional_hyperparameters["mp_parameters"] = get_mp_parameters(self.distribution)
-
-        elif "modelparallel" in self.distribution.get("smdistributed", {}):
-            raise ValueError("Cannot use Model Parallelism without MPI enabled!")
-
-        if "smdistributed" in self.distribution:
-            # smdistributed strategy selected
-            smdistributed = self.distribution["smdistributed"]
-            smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
-            additional_hyperparameters[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
-            additional_hyperparameters[self.INSTANCE_TYPE] = self.instance_type
-
+        additional_hyperparameters = self._distribution_configuration(
+            distribution=self.distribution
+        )
         hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
         return hyperparameters

src/sagemaker/tensorflow/estimator.py

Lines changed: 4 additions & 36 deletions

@@ -320,44 +320,12 @@ def create_model(
     def hyperparameters(self):
         """Return hyperparameters used by your custom TensorFlow code during model training."""
         hyperparameters = super(TensorFlow, self).hyperparameters()
-        additional_hyperparameters = {}
-
-        if "parameter_server" in self.distribution:
-            ps_enabled = self.distribution["parameter_server"].get("enabled", False)
-            additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
-
-        mpi_enabled = False
-        if "mpi" in self.distribution:
-            mpi_dict = self.distribution["mpi"]
-            mpi_enabled = mpi_dict.get("enabled", False)
-            additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
-
-            if mpi_dict.get("processes_per_host"):
-                additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
-                    "processes_per_host"
-                )
-
-            additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
-                "custom_mpi_options", ""
-            )
-
-            if fw.get_mp_parameters(self.distribution):
-                additional_hyperparameters["mp_parameters"] = fw.get_mp_parameters(
-                    self.distribution
-                )
-
-        elif "modelparallel" in self.distribution.get("smdistributed", {}):
-            raise ValueError("Cannot use Model Parallelism without MPI enabled!")
-
-        if "smdistributed" in self.distribution:
-            # smdistributed strategy selected
-            smdistributed = self.distribution["smdistributed"]
-            smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
-            additional_hyperparameters[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
-            additional_hyperparameters[self.INSTANCE_TYPE] = self.instance_type
+        additional_hyperparameters = self._distribution_configuration(self.distribution)

         if self.model_dir is not False:
-            self.model_dir = self.model_dir or self._default_s3_path("model", mpi=mpi_enabled)
+            self.model_dir = self.model_dir or self._default_s3_path(
+                "model", mpi=additional_hyperparameters.get(self.LAUNCH_MPI_ENV_NAME, False)
+            )
             additional_hyperparameters["model_dir"] = self.model_dir

         hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
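
An easy-to-miss detail in the TensorFlow change: the old code kept a local mpi_enabled flag only to pick the default model_dir, and the new code reads the same value back out of the helper's result. A standalone sketch of that equivalence (the dict literal stands in for the helper's return value; "sagemaker_mpi_enabled" is the value of LAUNCH_MPI_ENV_NAME, as the unit test below confirms):

distribution = {"mpi": {"enabled": True}}

# Old path: a local flag derived directly from the distribution dict.
mpi_enabled_old = distribution.get("mpi", {}).get("enabled", False)

# New path: the same flag is read from the helper's output instead.
additional_hyperparameters = {"sagemaker_mpi_enabled": True, "sagemaker_mpi_custom_mpi_options": ""}
mpi_enabled_new = additional_hyperparameters.get("sagemaker_mpi_enabled", False)

assert mpi_enabled_old == mpi_enabled_new  # the model_dir default is unchanged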

tests/unit/test_estimator.py

Lines changed: 34 additions & 0 deletions

@@ -115,6 +115,12 @@

 LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}

+DISTRIBUTION_PS_ENABLED = {"parameter_server": {"enabled": True}}
+DISTRIBUTION_MPI_ENABLED = {
+    "mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2}
+}
+DISTRIBUTION_SM_DDP_ENABLED = {"smdistributed": {"dataparallel": {"enabled": True}}}
+

 class DummyFramework(Framework):
     _framework_name = "dummy"

@@ -3209,3 +3215,31 @@ def test_estimator_local_mode_ok(sagemaker_local_session):
         sagemaker_session=sagemaker_local_session,
         base_job_name="base_job_name",
     )
+
+
+def test_framework_distribution_configuration(sagemaker_session):
+    framework = DummyFramework(
+        entry_point="script",
+        role=ROLE,
+        sagemaker_session=sagemaker_session,
+        instance_count=INSTANCE_COUNT,
+        instance_type=INSTANCE_TYPE,
+    )
+    actual_ps = framework._distribution_configuration(distribution=DISTRIBUTION_PS_ENABLED)
+    expected_ps = {"sagemaker_parameter_server_enabled": True}
+    assert actual_ps == expected_ps
+
+    actual_mpi = framework._distribution_configuration(distribution=DISTRIBUTION_MPI_ENABLED)
+    expected_mpi = {
+        "sagemaker_mpi_enabled": True,
+        "sagemaker_mpi_num_of_processes_per_host": 2,
+        "sagemaker_mpi_custom_mpi_options": "options",
+    }
+    assert actual_mpi == expected_mpi
+
+    actual_ddp = framework._distribution_configuration(distribution=DISTRIBUTION_SM_DDP_ENABLED)
+    expected_ddp = {
+        "sagemaker_distributed_dataparallel_enabled": True,
+        "sagemaker_instance_type": INSTANCE_TYPE,
+    }
+    assert actual_ddp == expected_ddp
