Commit 220236f
[Train] Add PyTorch local mode support for multi-process training with torchrun (#56218)
This PR extends the Ray Train v2 local mode support (from #55487) to
enable users to launch multiple local mode processes using torchrun for
PyTorch distributed training. **With this new feature, users can easily
switch between torchrun and Ray Train without modifying their training
code.**
<img width="1249" height="811" alt="image"
src="https://github.com/user-attachments/assets/5d998b5e-8f58-425a-b535-d4f4d0b64a5c"
/>
### Note
Ray Data is not yet supported when running multiple local mode processes.
Support may need to wait for #55114 or a similar component.
## Key Changes
### Multi-Process Local Mode Support
- **`LocalTorchController`**: New controller that detects torchrun env
variables and sets contexts accordingly
- **Torchrun Integration**: Users can now launch multiple local mode
processes using `torchrun` command
- **Environment Detection**: Automatically detects torchrun environment
variables and initializes distributed training
## Usage Example
```python
import os
import tempfile
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.v2.api.config import FailureConfig
import ray.train.torch
def train_func():
    """Train a ResNet-18 on FashionMNIST and report a checkpoint each epoch.

    The model and data loader are passed through the ``ray.train.torch``
    ``prepare_*`` helpers, and the world size / rank are read from
    ``ray.train.get_context()``.

    Runs 10 epochs. After every epoch the model ``state_dict`` is written
    to ``model.pt`` inside a temporary directory and reported together
    with the last batch loss and the epoch index. Rank 0 also prints the
    metrics.
    """
    # Model, loss, optimizer.
    # The first convolution is swapped for one taking a single input
    # channel (presumably because FashionMNIST images are grayscale).
    net = resnet18(num_classes=10)
    net.conv1 = torch.nn.Conv2d(
        1,
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )
    # [1] Prepare model.
    net = ray.train.torch.prepare_model(net)
    loss_fn = CrossEntropyLoss()
    opt = Adam(net.parameters(), lr=0.001)

    # Data: downloaded into a "data" folder under the system temp dir.
    preprocessing = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    dataset_root = os.path.join(tempfile.gettempdir(), "data")
    dataset = FashionMNIST(
        root=dataset_root,
        train=True,
        download=True,
        transform=preprocessing,
    )
    loader = DataLoader(dataset, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    loader = ray.train.torch.prepare_data_loader(loader)

    # Training
    for epoch in range(10):
        # Only touch the sampler's epoch when more than one worker is
        # participating.
        if ray.train.get_context().get_world_size() > 1:
            loader.sampler.set_epoch(epoch)

        for batch_images, batch_labels in loader:
            predictions = net(batch_images)
            batch_loss = criterion_value = loss_fn(predictions, batch_labels)
            opt.zero_grad()
            criterion_value.backward()
            opt.step()

        # [3] Report metrics and checkpoint.
        # The reported loss is that of the final batch of the epoch.
        epoch_metrics = {"loss": batch_loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as ckpt_dir:
            weights_path = os.path.join(ckpt_dir, "model.pt")
            torch.save(net.state_dict(), weights_path)
            # Report while the temporary directory still exists.
            epoch_checkpoint = ray.train.Checkpoint.from_directory(ckpt_dir)
            ray.train.report(epoch_metrics, checkpoint=epoch_checkpoint)

        if ray.train.get_context().get_world_rank() == 0:
            print(epoch_metrics)
# Configuration for local mode
use_gpu = True
# Retain a single checkpoint.
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))
# num_workers=0 selects local mode.
scaling_config = ScalingConfig(num_workers=0, use_gpu=use_gpu)

# Note: Ray Data not supported with multiple processes in local mode
# For multi-process training, use PyTorch DataLoader as shown above

# Build the trainer around the per-worker training function.
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    run_config=run_config,
    scaling_config=scaling_config,
)

# Run training and keep the returned result object.
result = trainer.fit()
```
### Running Options:
```bash
# Option 1: Single process local mode
RAY_TRAIN_V2_ENABLED=1 python test.py
# Option 2: Multi-process local mode with torchrun
RAY_TRAIN_V2_ENABLED=1 torchrun --standalone --nnodes=1 --nproc-per-node=4 test.py
# Option 3: Switch to distributed Ray Train (change num_workers=4)
# Same training code works across all modes!
```
---------
Signed-off-by: xgui <xgui@anyscale.com>
Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>

1 parent 416e365 commit 220236f
File tree
8 files changed
+218
-8
lines changed- python/ray/train/v2
- _internal/execution
- local_mode
- api
- tests
- torch
8 files changed
+218
-8
lines changedWhitespace-only changes.
Lines changed: 92 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
Lines changed: 10 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
166 | 166 | | |
167 | 167 | | |
168 | 168 | | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
169 | 174 | | |
170 | 175 | | |
171 | 176 | | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
172 | 182 | | |
173 | 183 | | |
174 | 184 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
253 | 253 | | |
254 | 254 | | |
255 | 255 | | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
256 | 261 | | |
257 | 262 | | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
258 | 268 | | |
259 | 269 | | |
260 | 270 | | |
261 | 271 | | |
262 | 272 | | |
263 | | - | |
| 273 | + | |
264 | 274 | | |
265 | 275 | | |
266 | | - | |
| 276 | + | |
267 | 277 | | |
268 | 278 | | |
269 | | - | |
| 279 | + | |
270 | 280 | | |
271 | 281 | | |
272 | | - | |
| 282 | + | |
273 | 283 | | |
274 | 284 | | |
275 | | - | |
276 | | - | |
| 285 | + | |
277 | 286 | | |
278 | 287 | | |
279 | 288 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
46 | 46 | | |
47 | 47 | | |
48 | 48 | | |
49 | | - | |
| 49 | + | |
50 | 50 | | |
51 | 51 | | |
52 | 52 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | | - | |
| 4 | + | |
4 | 5 | | |
5 | 6 | | |
6 | 7 | | |
| |||
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
| 42 | + | |
| 43 | + | |
41 | 44 | | |
42 | 45 | | |
43 | 46 | | |
| |||
522 | 525 | | |
523 | 526 | | |
524 | 527 | | |
| 528 | + | |
| 529 | + | |
| 530 | + | |
| 531 | + | |
| 532 | + | |
| 533 | + | |
| 534 | + | |
| 535 | + | |
| 536 | + | |
| 537 | + | |
| 538 | + | |
| 539 | + | |
| 540 | + | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
| 596 | + | |
| 597 | + | |
| 598 | + | |
| 599 | + | |
| 600 | + | |
| 601 | + | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
| 613 | + | |
| 614 | + | |
| 615 | + | |
| 616 | + | |
525 | 617 | | |
526 | 618 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
5 | 6 | | |
6 | 7 | | |
7 | 8 | | |
| |||
213 | 214 | | |
214 | 215 | | |
215 | 216 | | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
0 commit comments