Commit 220236f
[Train] Add PyTorch local mode support for multi-process training with torchrun (#56218)
This PR extends the Ray Train v2 local mode support (from #55487) to enable users to launch multiple local mode processes using torchrun for PyTorch distributed training. **With this new feature, users can easily switch between torchrun and Ray Train without modifying their training code.**

<img width="1249" height="811" alt="image" src="https://github.com/user-attachments/assets/5d998b5e-8f58-425a-b535-d4f4d0b64a5c" />

### Note

Ray Data on multiple processes is not supported; this might need to wait for #55114 or similar components.

## Key Changes

### Multi-Process Local Mode Support

- **`LocalTorchController`**: New controller that detects torchrun env variables and sets contexts accordingly
- **Torchrun Integration**: Users can now launch multiple local mode processes using the `torchrun` command
- **Environment Detection**: Automatically detects torchrun environment variables and initializes distributed training

## Usage Example

```python
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.v2.api.config import FailureConfig
import ray.train.torch


def train_func():
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)


# Configuration for local mode
use_gpu = True
scaling_config = ScalingConfig(num_workers=0, use_gpu=use_gpu)  # Local mode
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Note: Ray Data not supported with multiple processes in local mode
# For multi-process training, use PyTorch DataLoader as shown above

# Initialize the Trainer
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)

# Train the model
result = trainer.fit()
```

### Running Options

```bash
# Option 1: Single process local mode
RAY_TRAIN_V2_ENABLED=1 python test.py

# Option 2: Multi-process local mode with torchrun
RAY_TRAIN_V2_ENABLED=1 torchrun --standalone --nnodes=1 --nproc-per-node=4 test.py

# Option 3: Switch to distributed Ray Train (change num_workers=4)
# Same training code works across all modes!
```

---------

Signed-off-by: xgui <xgui@anyscale.com>
Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
1 parent 416e365 commit 220236f

8 files changed: +218 -8 lines

python/ray/train/v2/_internal/execution/local_mode/__init__.py

Whitespace-only changes.
python/ray/train/v2/_internal/execution/local_mode/torch.py

Lines changed: 92 additions & 0 deletions
```python
import logging
import os
from typing import Callable

import torch
import torch.distributed as dist

from ray.train import Result
from ray.train.v2._internal.execution.local_mode.utils import LocalController
from ray.train.v2._internal.execution.train_fn_utils import (
    LocalTrainFnUtils,
    get_train_fn_utils,
    set_train_fn_utils,
)

logger = logging.getLogger(__name__)


def has_torchrun_env() -> bool:
    """Return True if this process has torch.distributed env vars set.

    For torch.distributed.init_process_group with init_method="env://", these variables are required:
    - RANK: The rank of the current process
    - LOCAL_RANK: The local rank of the current process
    - WORLD_SIZE: Total number of processes participating in the job
    - LOCAL_WORLD_SIZE: Total number of processes participating in the job on the current node
    - MASTER_ADDR: The IP address or hostname of the master node (rank 0)
    - MASTER_PORT: A free port on the master node for communication
    """
    torch_dist_required_vars = {
        "RANK",
        "LOCAL_RANK",
        "WORLD_SIZE",
        "LOCAL_WORLD_SIZE",
        "MASTER_ADDR",
        "MASTER_PORT",
    }

    return torch_dist_required_vars.issubset(os.environ.keys())


class LocalTorchController(LocalController):
    def _set_train_fn_utils(self) -> None:
        world_size = 1
        global_rank = 0
        local_rank = 0
        nproc_per_node = 1
        node_rank = 0
        if has_torchrun_env():
            assert not dist.is_initialized(), "torch.distributed is already initialized"
            torch.distributed.init_process_group(
                backend="nccl" if torch.cuda.is_available() else "gloo"
            )
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
            local_rank = int(os.environ["LOCAL_RANK"])
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)
            nproc_per_node = int(os.environ.get("LOCAL_WORLD_SIZE"))
            node_rank = global_rank // nproc_per_node

        if world_size != 1:
            assert (
                self.datasets is None or len(self.datasets) == 0
            ), "Ray Data is not supported in local mode with multiple workers."
        set_train_fn_utils(
            LocalTrainFnUtils(
                experiment_name=self.experiment_name,
                world_size=world_size,
                world_rank=global_rank,
                local_rank=local_rank,
                local_world_size=nproc_per_node,
                node_rank=node_rank,
                dataset_shards=self.datasets,
            )
        )

    def run(self, train_func: Callable[[], None]) -> Result:
        self._set_train_fn_utils()
        train_func()
        train_fn_utils = get_train_fn_utils()
        assert isinstance(train_fn_utils, LocalTrainFnUtils)
        result = Result(
            metrics=train_fn_utils._get_last_metrics(),
            checkpoint=train_fn_utils.get_checkpoint(),
            path=None,
            error=None,
        )
        if dist.is_initialized():
            dist.destroy_process_group()
        return result
```

python/ray/train/v2/_internal/execution/local_mode_utils.py renamed to python/ray/train/v2/_internal/execution/local_mode/utils.py

File renamed without changes.

python/ray/train/v2/_internal/execution/train_fn_utils.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -166,9 +166,19 @@ def __init__(
         self,
         experiment_name: str,
         dataset_shards: Optional[Dict[str, DataIterator]] = None,
+        world_size: int = 1,
+        world_rank: int = 0,
+        local_rank: int = 0,
+        local_world_size: int = 1,
+        node_rank: int = 0,
     ):
         self._context = LocalTrainContext(
             experiment_name=experiment_name,
+            world_size=world_size,
+            world_rank=world_rank,
+            local_rank=local_rank,
+            local_world_size=local_world_size,
+            node_rank=node_rank,
         )
         self._dataset_shards = dataset_shards
         self._last_metrics = None
```

python/ray/train/v2/api/context.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -253,27 +253,36 @@ class LocalTrainContext(TrainContext):
     def __init__(
         self,
         experiment_name: str,
+        world_size: int = 1,
+        world_rank: int = 0,
+        local_rank: int = 0,
+        local_world_size: int = 1,
+        node_rank: int = 0,
     ):
         self.experiment_name = experiment_name
+        self.world_size = world_size
+        self.world_rank = world_rank
+        self.local_rank = local_rank
+        self.local_world_size = local_world_size
+        self.node_rank = node_rank
 
     def get_experiment_name(self) -> str:
         return self.experiment_name
 
     def get_world_size(self) -> int:
-        return 1
+        return self.world_size
 
     def get_world_rank(self) -> int:
-        return 0
+        return self.world_rank
 
     def get_local_rank(self) -> int:
-        return 0
+        return self.local_rank
 
     def get_local_world_size(self) -> int:
-        return 1
+        return self.local_world_size
 
     def get_node_rank(self) -> int:
-        """For local mode, we only use one node."""
-        return 0
+        return self.node_rank
 
     def get_storage(self):
         raise NotImplementedError("Local storage context not yet implemented. ")
```
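For illustration, a minimal sketch (not part of the diff; the rank values are hypothetical) of how the widened constructor feeds these getters:

```python
from ray.train.v2.api.context import LocalTrainContext

# Defaults keep the original single-process local-mode behavior.
ctx = LocalTrainContext(experiment_name="demo")
assert ctx.get_world_size() == 1 and ctx.get_node_rank() == 0

# Values a LocalTorchController would pass for rank 2 of 4 with two
# processes per node (so this process sits on the second node).
ctx = LocalTrainContext(
    experiment_name="demo",
    world_size=4,
    world_rank=2,
    local_rank=0,
    local_world_size=2,
    node_rank=1,
)
assert ctx.get_world_rank() == 2
assert ctx.get_local_world_size() == 2
assert ctx.get_node_rank() == 1
```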

python/ray/train/v2/api/data_parallel_trainer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@
 from ray.train.v2._internal.execution.context import TrainRunContext
 from ray.train.v2._internal.execution.controller import TrainController
 from ray.train.v2._internal.execution.failure_handling import create_failure_policy
-from ray.train.v2._internal.execution.local_mode_utils import LocalController
+from ray.train.v2._internal.execution.local_mode.utils import LocalController
 from ray.train.v2._internal.execution.scaling_policy import create_scaling_policy
 from ray.train.v2._internal.util import ObjectRefWrapper, construct_train_func
 from ray.train.v2.api.callback import UserCallback
```

python/ray/train/v2/tests/test_local_mode.py

Lines changed: 93 additions & 1 deletion
```diff
@@ -1,6 +1,7 @@
 import math
+import os
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import lightgbm
 import pandas as pd
@@ -38,6 +39,8 @@
 from ray.train.tests.lightning_test_utils import DummyDataModule, LinearModule
 from ray.train.tests.util import create_dict_checkpoint
 from ray.train.torch import TorchTrainer
+from ray.train.v2._internal.execution.local_mode.torch import LocalTorchController
+from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 from ray.train.v2.jax import JaxTrainer
 from ray.train.xgboost import (
@@ -522,5 +525,94 @@ def xgboost_train_fn_per_worker():
     XGBoostTrainer.get_model(result.checkpoint)
 
 
+def test_torch_distributed_variables_local_train_fn_utils():
+    """Test that torch distributed variables are correctly used to create LocalTrainFnUtils."""
+
+    # Test scenario 1: Without torch distributed environment variables
+    with patch.dict(os.environ, {}, clear=True):
+        controller = LocalTorchController("test_experiment")
+
+        def dummy_train_func():
+            train_fn_utils = get_train_fn_utils()
+            # Verify default values when no torch distributed env vars are set
+            context = train_fn_utils.get_context()
+            assert context.get_world_size() == 1
+            assert context.get_world_rank() == 0
+            assert context.get_local_rank() == 0
+            assert context.get_local_world_size() == 1
+            assert context.get_node_rank() == 0
+
+        controller.run(dummy_train_func)
+
+    # Test scenario 2: With torch distributed environment variables (CPU)
+    torch_env_vars = {
+        "RANK": "2",
+        "LOCAL_RANK": "1",
+        "WORLD_SIZE": "4",
+        "LOCAL_WORLD_SIZE": "2",
+        "MASTER_ADDR": "127.0.0.1",
+        "MASTER_PORT": "29500",
+    }
+
+    with patch.dict(os.environ, torch_env_vars, clear=True), patch(
+        "torch.distributed.is_initialized", return_value=False
+    ), patch("torch.distributed.get_world_size", return_value=4), patch(
+        "torch.distributed.get_rank", return_value=2
+    ), patch(
+        "torch.cuda.is_available", return_value=False
+    ), patch(
+        "torch.distributed.init_process_group"
+    ) as mock_init_pg:
+
+        controller = LocalTorchController("test_experiment")
+
+        def dummy_train_func():
+            train_fn_utils = get_train_fn_utils()
+            # Verify torch distributed values are correctly passed
+            context = train_fn_utils.get_context()
+            assert context.get_world_size() == 4
+            assert context.get_world_rank() == 2
+            assert context.get_local_rank() == 1
+            assert context.get_local_world_size() == 2
+            assert (
+                context.get_node_rank() == 1
+            )  # global_rank // nproc_per_node = 2 // 2 = 1
+
+        controller.run(dummy_train_func)
+
+        # Verify torch.distributed methods were called with CPU backend
+        mock_init_pg.assert_called_once_with(backend="gloo")
+
+    # Test scenario 3: With torch distributed environment variables (GPU)
+    with patch.dict(os.environ, torch_env_vars, clear=True), patch(
+        "torch.distributed.is_initialized", return_value=False
+    ), patch("torch.distributed.get_world_size", return_value=4), patch(
+        "torch.distributed.get_rank", return_value=2
+    ), patch(
+        "torch.cuda.is_available", return_value=True
+    ), patch(
+        "torch.distributed.init_process_group"
+    ) as mock_init_pg, patch(
+        "torch.cuda.set_device"
+    ) as mock_set_device:
+
+        controller = LocalTorchController("test_experiment")
+
+        def dummy_train_func():
+            train_fn_utils = get_train_fn_utils()
+            # Verify torch distributed values are correctly passed
+            context = train_fn_utils.get_context()
+            assert context.get_world_size() == 4
+            assert context.get_world_rank() == 2
+            assert context.get_local_rank() == 1
+            assert context.get_local_world_size() == 2
+            assert context.get_node_rank() == 1
+
+        controller.run(dummy_train_func)
+
+        mock_init_pg.assert_called_once_with(backend="nccl")
+        mock_set_device.assert_called_once_with(1)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", "-x", __file__]))
```

python/ray/train/v2/torch/torch_trainer.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@
 
 from ray.train import Checkpoint, DataConfig
 from ray.train.trainer import GenDataset
+from ray.train.v2._internal.execution.local_mode.torch import LocalTorchController
 from ray.train.v2.api.config import RunConfig, ScalingConfig
 from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
 from ray.util import PublicAPI
@@ -213,3 +214,9 @@ def __init__(
             resume_from_checkpoint=resume_from_checkpoint,
             metadata=metadata,
         )
+
+    def _get_local_controller(self) -> LocalTorchController:
+        return LocalTorchController(
+            experiment_name=self.run_config.name,
+            datasets=self.datasets,
+        )
```
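Putting the pieces together, a hedged sketch of how this hook is exercised (the script body is abridged and illustrative; run with `RAY_TRAIN_V2_ENABLED=1`, optionally under `torchrun`): with `num_workers=0`, the trainer executes the training function in-process through the `LocalTorchController` returned here.

```python
import ray.train
from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer


def train_func():
    ctx = ray.train.get_context()
    # Under torchrun these reflect RANK / WORLD_SIZE via LocalTorchController;
    # under a plain `python` launch they fall back to 0 / 1.
    ray.train.report({"rank": ctx.get_world_rank(), "world_size": ctx.get_world_size()})


# num_workers=0 selects local mode, so fit() runs train_func in this process.
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=0),
    run_config=RunConfig(name="local_mode_demo"),
)
result = trainer.fit()
```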
