implement collective reduce op #9437

Open · wants to merge 2 commits into master
66 changes: 43 additions & 23 deletions test/pjrt/test_collective_ops_tpu.py
@@ -103,29 +103,6 @@ def test_reduce_scatter(self, pin_layout):
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [-ordinal])

@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])

@staticmethod
def _all_to_all(pin_layout):
device = torch_xla.device()
@@ -359,6 +336,49 @@ def test_all_to_all_single(self, use_dynamo):
expected.sort().values),
f"Got {val}, expected {expected}")

@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])

@staticmethod
def _reduce():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
input = torch.tensor([xr.global_ordinal()],
dtype=torch.float,
device=device)
dist.reduce(input, dst=0, op=dist.ReduceOp.SUM)

return input.cpu()

def test_reduce(self):
results = pjrt.run_multiprocess(self._reduce)
for ordinal, value in results.items():
if ordinal == 0:
expected = sum(range(tpu.num_expected_global_devices()))
else:
expected = ordinal
np.testing.assert_array_equal(value, [expected])


if __name__ == '__main__':
absltest.main()
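
For reference, a small worked example of the expectation encoded in `test_reduce` above, assuming a hypothetical run with 4 global devices (the real test derives the count from `tpu.num_expected_global_devices()`); only rank 0 ends up with the reduced sum, every other rank keeps its own input:

```python
# Expected per-rank results of dist.reduce(..., dst=0, op=SUM) with 4 ranks.
world_size = 4  # hypothetical device count for illustration
expected = {
    rank: sum(range(world_size)) if rank == 0 else rank
    for rank in range(world_size)
}
assert expected == {0: 6, 1: 1, 2: 2, 3: 3}
```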
1 change: 0 additions & 1 deletion test/test_torch_distributed_xla_backend.py
@@ -356,7 +356,6 @@ def test_barrier(self):
dist.barrier()

@parameterized.parameters(
'reduce',
'allreduce_coalesced',
'alltoall',
'gather',
26 changes: 20 additions & 6 deletions torch_xla/distributed/xla_backend.py
@@ -1,11 +1,12 @@
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import rendezvous
import logging
import os
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions
from torch._C._distributed_c10d import ProcessGroup, ScatterOptions, ReduceScatterOptions, AllgatherOptions, ReduceOptions


def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -224,11 +225,24 @@ def _reduce_scatter_base(self, output_tensor, input_tensor, opts):
def barrier(self, opts):
return _ret_work([])

# Call site:
# https://github.com/pytorch/pytorch/blob/70f57bcb1e45d21532bdb1c44d3aab018d1cbe88/torch/distributed/distributed_c10d.py#L1417
# `reduce` is not needed by DeepSpeed for now.
def reduce(self, *args):
raise NotImplementedError
# Called by torch.distributed.reduce. Call site example:
# https://github.com/pytorch/pytorch/blob/v2.7.1/torch/distributed/distributed_c10d.py#L2925
# Tensors are reduced, but the result is only stored on the dst rank;
# the input tensors are left unchanged on all other ranks.
# This is an inefficient operation: to avoid XLA deadlocks, it performs the
# reduction redundantly on all devices and materializes the result.
def reduce(self, tensors: list[torch.Tensor], opts: ReduceOptions):
rank = xr.global_ordinal()
dst = opts.rootRank
reduce_type = self._get_reduce_type(opts.reduceOp)
for tensor in tensors:
result = xm.all_reduce(reduce_type, inputs=tensor)
torch_xla.sync()

if rank == dst:
tensor.copy_(result)

return _ret_work(tensors)

def allreduce_coalesced(self, *args):
raise NotImplementedError
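
For context, a minimal usage sketch of the new code path from a training script's perspective. It mirrors the semantics described in the comment in `xla_backend.py` (the reduced value is copied into the tensor only on `dst`); it assumes a multi-process launch with the `xla://` init method and is not part of the PR itself:

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr


def reduce_example():
  # Assumes each process was launched with the xla backend available.
  dist.init_process_group("xla", init_method="xla://")
  device = torch_xla.device()

  # Each rank contributes its own ordinal.
  t = torch.tensor([float(xr.global_ordinal())], device=device)
  dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)

  # After the call, rank 0 holds sum(0..world_size-1); every other rank
  # still sees its original value, because the backend only copies the
  # reduced result into the tensor on the destination rank.
  return t.cpu()
```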