@@ -469,57 +469,61 @@ class P2POp:
469
469
The type of ``op`` is either ``torch.distributed.isend`` or
470
470
``torch.distributed.irecv``.
471
471
tensor (Tensor): Tensor to send or receive.
472
- peer (int): Destination or source rank.
472
+ peer (int, optional ): Destination or source rank.
473
473
group (ProcessGroup, optional): The process group to work on. If None,
474
474
the default process group will be used.
475
475
tag (int, optional): Tag to match send with recv.
476
+ group_peer (int, optional): Destination or source rank.
476
477
"""
477
478
478
479
def __init__ (
479
480
self ,
480
481
op : Callable ,
481
482
tensor : torch .Tensor ,
482
- peer : int ,
483
+ peer : Optional [ int ] = None ,
483
484
group : Optional [ProcessGroup ] = None ,
484
485
tag : int = 0 ,
486
+ group_peer : Optional [int ] = None ,
485
487
):
486
488
"""Init."""
487
489
self .op = op
488
490
self .tensor = tensor
489
- self .peer = peer
490
- self .group = group
491
+ self .group = _group_or_default_group (group )
492
+ self .peer = _canonicalize_group_rank (
493
+ self .group , peer , group_peer , return_global = True
494
+ )
491
495
self .tag = tag
496
+ self .group_peer = _canonicalize_group_rank (self .group , peer , group_peer )
492
497
493
498
def __new__ (
494
499
cls ,
495
500
op : Callable ,
496
501
tensor : torch .Tensor ,
497
- peer : int ,
502
+ peer : Optional [ int ] = None ,
498
503
group : Optional [ProcessGroup ] = None ,
499
504
tag : int = 0 ,
505
+ group_peer : Optional [int ] = None ,
500
506
):
501
507
"""Create and return a new instance of the class."""
502
508
_check_op (op )
503
509
_check_single_tensor (tensor , "tensor" )
510
+
504
511
return object .__new__ (cls )
505
512
506
513
def __repr__ (self ):
507
514
my_group_rank = get_rank (self .group )
508
- peer_group_rank = (
509
- get_group_rank (self .group , self .peer ) if self .group else self .peer
510
- )
511
515
op_name = self .op .__name__
512
516
group_name = self .group .group_name if self .group else "default_pg"
513
517
if "send" in op_name :
514
518
s = my_group_rank
515
- d = peer_group_rank
519
+ d = self . group_peer
516
520
elif "recv" in op_name :
517
- s = peer_group_rank
521
+ s = self . group_peer
518
522
d = my_group_rank
519
523
else :
520
524
return super ().__repr__ ()
521
525
522
- return f"P2POp({ op_name } pg={ group_name } , s ={ s } , d ={ d } , { self .tensor .shape } , { self .tensor .dtype } )"
526
+ return f"P2POp({ op_name } pg={ group_name } , group_src ={ s } , group_dst ={ d } , { self .tensor .shape } , { self .tensor .dtype } )"
523
527
524
528
525
529
class _CollOp :
@@ -2545,7 +2549,7 @@ def _coalescing_manager(
2545
2549
work .wait () # type: ignore[possibly-undefined]
2546
2550
2547
2551
2548
- def batch_isend_irecv (p2p_op_list ) :
2552
+ def batch_isend_irecv (p2p_op_list : List [ P2POp ]) -> List [ Work ] :
2549
2553
"""
2550
2554
Send or Receive a batch of tensors asynchronously and return a list of requests.
2551
2555
@@ -2588,17 +2592,33 @@ def batch_isend_irecv(p2p_op_list):
2588
2592
_check_p2p_op_list (p2p_op_list )
2589
2593
group = p2p_op_list [0 ].group
2590
2594
device = p2p_op_list [0 ].tensor .device
2595
+
2596
+ def peer_kwarg (op : P2POp ) -> Dict [str , int ]:
2597
+ key = "group_dst" if op .op == isend else "group_src"
2598
+ return {key : op .group_peer }
2599
+
2591
2600
if device .type == "cuda" :
2592
2601
# NCCL style coalescing
2593
2602
with _coalescing_manager (group , device , async_ops = True ) as cm :
2594
2603
for p2p_op in p2p_op_list :
2595
- p2p_op .op (p2p_op .tensor , p2p_op .peer , p2p_op .group , p2p_op .tag )
2604
+ p2p_op .op (
2605
+ p2p_op .tensor ,
2606
+ group = p2p_op .group ,
2607
+ tag = p2p_op .tag ,
2608
+ ** peer_kwarg (p2p_op ),
2609
+ )
2610
+
2596
2611
return cm .works
2597
2612
else :
2598
2613
# Backward support for Gloo
2599
2614
reqs = []
2600
2615
for p2p_op in p2p_op_list :
2601
- work = p2p_op .op (p2p_op .tensor , p2p_op .peer , p2p_op .group , p2p_op .tag )
2616
+ work = p2p_op .op (
2617
+ p2p_op .tensor ,
2618
+ group = p2p_op .group ,
2619
+ tag = p2p_op .tag ,
2620
+ ** peer_kwarg (p2p_op ),
2621
+ )
2602
2622
if work :
2603
2623
reqs .append (work )
2604
2624
return reqs
0 commit comments