Commit 52985f3

wconstab authored and pobin6 committed
[C10D] Support group_dst/group_src in c10d send/recv object_list (pytorch#140847)
Also add mypy annotations.

Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in pytorch#140460.

Pull Request resolved: pytorch#140847
Approved by: https://github.com/H-Huang
ghstack dependencies: pytorch#140843
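As a usage illustration (not part of the commit): a minimal two-process sketch of the new keyword arguments, assuming a torchrun launch and the NCCL backend used by the test below.

# Hypothetical demo, run with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
device = torch.device("cuda:%d" % rank)
subgroup = dist.new_group([0, 1])

if rank == 0:
    objs = [{"hello": "world"}]
    # group_dst is interpreted relative to `subgroup`, not the global job
    dist.send_object_list(objs, group_dst=1, group=subgroup, device=device)
else:
    objs = [None]  # pre-sized list that recv_object_list fills in place
    dist.recv_object_list(objs, group_src=0, group=subgroup, device=device)
    assert objs == [{"hello": "world"}]

dist.destroy_process_group()

Per the updated docstrings, exactly one of ``dst``/``group_dst`` (respectively ``src``/``group_src``) may be passed.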
1 parent 8f7dd05 commit 52985f3

File tree

2 files changed: +40, -15 lines

test/distributed/test_c10d_nccl.py

Lines changed: 16 additions & 3 deletions
@@ -3928,7 +3928,10 @@ def test_broadcast_subgroup(self, group_rank):
         "set_device",
         [SetDeviceMethod.TORCH_CUDA_SET, SetDeviceMethod.COLLECTIVE_ARGUMENT],
     )
-    def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
+    @parametrize("group_rank", [True, False])
+    def test_send_recv_object_list_subgroup(
+        self, set_device: SetDeviceMethod, group_rank
+    ):
         world_size = 4
         if self.rank >= world_size:
             return
@@ -3940,12 +3943,22 @@ def test_send_recv_object_list_subgroup(self, set_device: SetDeviceMethod):
         device = torch.device("cuda:%d" % self.rank)
         if self.rank == 0 or self.rank == 2:
             x = [{}]
-            c10d.recv_object_list(x, src=self.rank + 1, group=subgroup, device=device)
+            if group_rank:
+                c10d.recv_object_list(x, group_src=1, group=subgroup, device=device)
+            else:
+                c10d.recv_object_list(
+                    x, src=self.rank + 1, group=subgroup, device=device
+                )
             expected = [{"rank": self.rank + 1}]
             self.assertEqual(x, expected)
         else:
             x = [{"rank": self.rank}]
-            c10d.send_object_list(x, dst=self.rank - 1, group=subgroup, device=device)
+            if group_rank:
+                c10d.send_object_list(x, group_dst=0, group=subgroup, device=device)
+            else:
+                c10d.send_object_list(
+                    x, dst=self.rank - 1, group=subgroup, device=device
+                )
 
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
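For reference, the group-relative ranks used in the new test arm resolve to different global ranks in each subgroup. A small sketch of that mapping using the public rank-translation helpers; the subgroup layout ([0, 1] and [2, 3]) is an assumption, since the subgroup construction is outside this excerpt.

import torch.distributed as dist

# inside the 4-process job, after init_process_group(...):
# every rank creates the same subgroups; ranks 2 and 3 use this one
subgroup = dist.new_group([2, 3])

# the receiver's group_src=1 denotes global rank 3 within this subgroup...
assert dist.get_global_rank(subgroup, group_rank=1) == 3
# ...and the sender's group_dst=0 denotes global rank 2
assert dist.get_group_rank(subgroup, global_rank=2) == 0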

torch/distributed/distributed_c10d.py

Lines changed: 24 additions & 12 deletions
@@ -3087,7 +3087,13 @@ def gather_object(
 
 
 @_exception_logger
-def send_object_list(object_list, dst, group=None, device=None):
+def send_object_list(
+    object_list: List[Any],
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_dst: Optional[int] = None,
+):
     """
     Sends picklable objects in ``object_list`` synchronously.
 
@@ -3105,7 +3111,8 @@ def send_object_list(object_list, dst, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before sending. Default is ``None``.
-
+        group_dst (int, optional): Destination rank on ``group``.
+            Must specify one of ``dst`` and ``group_dst`` but not both.
     Returns:
         ``None``.
 
@@ -3143,11 +3150,9 @@ def send_object_list(object_list, dst, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
-    if get_rank() == dst:
-        raise ValueError(
-            "Invalid destination rank: destination rank should not be the same as "
-            "the rank of the current process."
-        )
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst)
+    _check_not_self_rank(group, group_dst, "destination")
 
     if _rank_not_in_group(group):
         _warn_not_in_group("send_object_list")
@@ -3167,7 +3172,7 @@ def send_object_list(object_list, dst, group=None, device=None):
     object_sizes_tensor = torch.cat(size_list)
 
     # Send object sizes
-    send(object_sizes_tensor, dst=dst, group=group)
+    send(object_sizes_tensor, group_dst=group_dst, group=group)
 
     # Concatenate and send serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
@@ -3177,11 +3182,17 @@ def send_object_list(object_list, dst, group=None, device=None):
     else:
         object_tensor = torch.cat(tensor_list)
 
-    send(object_tensor, dst=dst, group=group)
+    send(object_tensor, group_dst=group_dst, group=group)
 
 
 @_exception_logger
-def recv_object_list(object_list, src=None, group=None, device=None):
+def recv_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Receives picklable objects in ``object_list`` synchronously.
 
@@ -3197,6 +3208,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
             the default process group will be used. Default is ``None``.
         device (``torch.device``, optional): If not None, receives on this device.
             Default is ``None``.
+        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``.
 
     Returns:
         Sender rank. -1 if rank is not part of the group. If rank is part of the group,
@@ -3252,7 +3264,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
     )
 
     # Receive object sizes
-    rank_sizes = recv(object_sizes_tensor, src=src, group=group)
+    rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src)
 
     # Tensor to receive serialized objects into.
     object_tensor = torch.empty(  # type: ignore[call-overload]
@@ -3261,7 +3273,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
         device=current_device,
     )
 
-    rank_objects = recv(object_tensor, src=src, group=group)
+    rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
     assert (
         rank_sizes == rank_objects
     ), "Mismatch in return ranks for object sizes and objects."
