Commit c71c15f

wconstab authored and youssef62 committed
[C10D] Support group ranks in P2POp and batch_isend_irecv (pytorch#141054)
Changes the semantics of P2POp's __repr__: s and d are now group ranks instead of global ranks. I think this is OK since I also updated the field names to make this obvious. Also adds mypy annotations.

Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in pytorch#140460.

Pull Request resolved: pytorch#141054
Approved by: https://github.com/kwen2501
1 parent 5edda56 commit c71c15f
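
For illustration, a minimal usage sketch of the new group_peer argument. Only P2POp, batch_isend_irecv, and group_peer come from this change; the surrounding setup (backend, subgroup layout, tensor contents) is assumed for the example.

import torch
import torch.distributed as dist

# Sketch only: assumes dist.init_process_group() was already called with the
# gloo backend and world_size >= 2, and that this process is global rank 0 or 1.
subgroup = dist.new_group(ranks=[0, 1])  # new_group is collective: call it on every rank

if dist.get_rank() == 0:
    x = torch.empty(10)
    # group_peer=1 is a rank *within* subgroup (it coincides with global
    # rank 1 here only because the subgroup is built from ranks [0, 1]).
    op = dist.P2POp(dist.irecv, x, group=subgroup, group_peer=1)
else:
    x = torch.ones(10)
    op = dist.P2POp(dist.isend, x, group=subgroup, group_peer=0)

for work in dist.batch_isend_irecv([op]):
    work.wait()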

2 files changed: +68 -14 lines

test/distributed/test_c10d_nccl.py

Lines changed: 34 additions & 0 deletions
@@ -3869,6 +3869,40 @@ def test_send_recv_subgroup(self, async_op, group_rank):
         else:
             c10d.send(x, dst=self.rank - 1, group=subgroup)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    @parametrize("group_rank", [True, False])
+    def test_batch_send_recv_subgroup(self, group_rank):
+        world_size = 4
+        if self.rank >= world_size:
+            return
+        subgroup = self._init_two_pg2_subgroups(world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        ops = []
+        if self.rank == 0 or self.rank == 2:
+            x = torch.empty((10,), device=device)
+            if group_rank:
+                ops.append(c10d.P2POp(dist.irecv, x, group=subgroup, group_peer=1))
+            else:
+                ops.append(
+                    c10d.P2POp(dist.irecv, x, peer=self.rank + 1, group=subgroup)
+                )
+
+            for work in dist.batch_isend_irecv(ops):
+                work.wait()
+            expected = torch.ones((10,), device=device) * (self.rank + 1)
+            self.assertEqual(x, expected)
+        else:
+            x = torch.ones((10,), device=device) * self.rank
+            if group_rank:
+                ops.append(c10d.P2POp(dist.isend, x, group=subgroup, group_peer=0))
+            else:
+                ops.append(
+                    c10d.P2POp(dist.isend, x, peer=self.rank - 1, group=subgroup)
+                )
+            for work in dist.batch_isend_irecv(ops):
+                work.wait()
+
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
     @parametrize("group_rank", [True, False])

torch/distributed/distributed_c10d.py

Lines changed: 34 additions & 14 deletions
@@ -469,57 +469,61 @@ class P2POp:
             The type of ``op`` is either ``torch.distributed.isend`` or
             ``torch.distributed.irecv``.
         tensor (Tensor): Tensor to send or receive.
-        peer (int): Destination or source rank.
+        peer (int, optional): Destination or source rank.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with recv.
+        group_peer (int, optional): Destination or source rank.
     """
 
     def __init__(
         self,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Init."""
         self.op = op
         self.tensor = tensor
-        self.peer = peer
-        self.group = group
+        self.group = _group_or_default_group(group)
+        self.peer = _canonicalize_group_rank(
+            self.group, peer, group_peer, return_global=True
+        )
         self.tag = tag
+        self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer)
 
     def __new__(
         cls,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Create and return a new instance of the class."""
         _check_op(op)
         _check_single_tensor(tensor, "tensor")
+
         return object.__new__(cls)
 
     def __repr__(self):
         my_group_rank = get_rank(self.group)
-        peer_group_rank = (
-            get_group_rank(self.group, self.peer) if self.group else self.peer
-        )
         op_name = self.op.__name__
         group_name = self.group.group_name if self.group else "default_pg"
         if "send" in op_name:
             s = my_group_rank
-            d = peer_group_rank
+            d = self.group_peer
         elif "recv" in op_name:
-            s = peer_group_rank
+            s = self.group_peer
             d = my_group_rank
         else:
             return super().__repr__()
 
-        return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})"
+        return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})"
 
 
 class _CollOp:

@@ -2545,7 +2549,7 @@ def _coalescing_manager(
         work.wait()  # type: ignore[possibly-undefined]
 
 
-def batch_isend_irecv(p2p_op_list):
+def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
     """
     Send or Receive a batch of tensors asynchronously and return a list of requests.
 

@@ -2588,17 +2592,33 @@
     _check_p2p_op_list(p2p_op_list)
     group = p2p_op_list[0].group
     device = p2p_op_list[0].tensor.device
+
+    def peer_kwarg(op: P2POp) -> Dict[str, int]:
+        key = "group_dst" if op.op == isend else "group_src"
+        return {key: op.group_peer}
+
     if device.type == "cuda":
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
-                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+                p2p_op.op(
+                    p2p_op.tensor,
+                    group=p2p_op.group,
+                    tag=p2p_op.tag,
+                    **peer_kwarg(p2p_op),
+                )
+
         return cm.works
     else:
         # Backward support for Gloo
         reqs = []
         for p2p_op in p2p_op_list:
-            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            work = p2p_op.op(
+                p2p_op.tensor,
+                group=p2p_op.group,
+                tag=p2p_op.tag,
+                **peer_kwarg(p2p_op),
+            )
             if work:
                 reqs.append(work)
         return reqs
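
For context, the difference between the global peer (the existing peer argument) and the new group_peer can be illustrated with the existing public rank-translation helpers. A minimal sketch, assuming a job with at least 4 ranks and run on a rank that belongs to the subgroup:

import torch.distributed as dist

# new_group is collective: every rank in the job must call it.
subgroup = dist.new_group(ranks=[2, 3])

# Global rank 3 is group rank 1 inside this subgroup (and vice versa), so
# P2POp(dist.isend, t, peer=3, group=subgroup) and
# P2POp(dist.isend, t, group=subgroup, group_peer=1) address the same process.
assert dist.get_group_rank(subgroup, 3) == 1
assert dist.get_global_rank(subgroup, 1) == 3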
