@@ -3087,7 +3087,13 @@ def gather_object(


 @_exception_logger
-def send_object_list(object_list, dst, group=None, device=None):
+def send_object_list(
+    object_list: List[Any],
+    dst: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_dst: Optional[int] = None,
+):
     """
     Sends picklable objects in ``object_list`` synchronously.

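For orientation, a minimal sender-side sketch of the reworked signature, assuming this patch is applied. The 4-rank layout and the subgroup are illustrative assumptions, not part of the diff; ``group_dst`` names the peer by its rank inside ``group`` rather than by its global rank (the matching receive is sketched after the ``recv_object_list`` hunk below):

    import torch.distributed as dist

    # Illustrative 4-rank job; the subgroup contains global ranks 2 and 3.
    subgroup = dist.new_group(ranks=[2, 3])

    if dist.get_rank() == 2:
        # group_dst=1 is the second rank of `subgroup`, i.e. global rank 3;
        # dst=3 (the global rank) would address the same peer.
        dist.send_object_list(["foo", 12, {1: 2}], group=subgroup, group_dst=1)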
@@ -3105,7 +3111,8 @@ def send_object_list(object_list, dst, group=None, device=None):
         device (``torch.device``, optional): If not None, the objects are
             serialized and converted to tensors which are moved to the
             ``device`` before sending. Default is ``None``.
-
+        group_dst (int, optional): Destination rank on ``group``.
+            Must specify one of ``dst`` and ``group_dst`` but not both.
     Returns:
         ``None``.

@@ -3143,11 +3150,9 @@ def send_object_list(object_list, dst, group=None, device=None):
         >>> objects
         ['foo', 12, {1: 2}]
     """
-    if get_rank() == dst:
-        raise ValueError(
-            "Invalid destination rank: destination rank should not be the same as "
-            "the rank of the current process."
-        )
+    group = _group_or_default_group(group)
+    group_dst = _canonicalize_group_rank(group, dst, group_dst)
+    _check_not_self_rank(group, group_dst, "destination")

     if _rank_not_in_group(group):
         _warn_not_in_group("send_object_list")
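The underscore-prefixed helpers above are internal to ``torch.distributed``. A rough sketch of what this validation path amounts to, written against public APIs only; the function name ``canonicalize_group_rank`` below is illustrative, not a real helper:

    from typing import Optional

    import torch.distributed as dist
    from torch.distributed import ProcessGroup


    def canonicalize_group_rank(
        group: Optional[ProcessGroup],
        global_rank: Optional[int],
        group_rank: Optional[int],
    ) -> int:
        """Illustrative only: resolve exactly one of (global_rank, group_rank) to a group rank."""
        group = group if group is not None else dist.group.WORLD
        if (global_rank is None) == (group_rank is None):
            raise ValueError("Specify exactly one of the global rank and the group rank.")
        if global_rank is not None:
            # Translate the global rank into this group's local numbering.
            group_rank = dist.get_group_rank(group, global_rank)
        if group_rank == dist.get_rank(group):
            raise ValueError("Destination rank must differ from the current process's rank.")
        return group_rank

Resolving everything to a group rank once, up front, lets the rest of the function forward ``group_dst`` unchanged to the point-to-point ``send`` calls below.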
@@ -3167,7 +3172,7 @@ def send_object_list(object_list, dst, group=None, device=None):
     object_sizes_tensor = torch.cat(size_list)

     # Send object sizes
-    send(object_sizes_tensor, dst=dst, group=group)
+    send(object_sizes_tensor, group_dst=group_dst, group=group)

     # Concatenate and send serialized object tensors
     # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list
@@ -3177,11 +3182,17 @@ def send_object_list(object_list, dst, group=None, device=None):
     else:
         object_tensor = torch.cat(tensor_list)

-    send(object_tensor, dst=dst, group=group)
+    send(object_tensor, group_dst=group_dst, group=group)


 @_exception_logger
-def recv_object_list(object_list, src=None, group=None, device=None):
+def recv_object_list(
+    object_list: List[Any],
+    src: Optional[int] = None,
+    group: Optional[ProcessGroup] = None,
+    device: Optional[torch.device] = None,
+    group_src: Optional[int] = None,
+):
     """
     Receives picklable objects in ``object_list`` synchronously.

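And the matching receiver-side sketch for the new ``group_src`` keyword, continuing the illustrative ``subgroup`` from the send example above (again assuming this patch is applied):

    if dist.get_rank() == 3:
        objects = [None, None, None]
        # group_src=0 is the first rank of `subgroup`, i.e. global rank 2.
        dist.recv_object_list(objects, group=subgroup, group_src=0)
        assert objects == ["foo", 12, {1: 2}]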
@@ -3197,6 +3208,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
             the default process group will be used. Default is ``None``.
         device (``torch.device``, optional): If not None, receives on this device.
             Default is ``None``.
+        group_src (int, optional): Source rank on ``group``. Invalid to specify both ``src`` and ``group_src``.

     Returns:
         Sender rank. -1 if rank is not part of the group. If rank is part of the group,
@@ -3252,7 +3264,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
     )

     # Receive object sizes
-    rank_sizes = recv(object_sizes_tensor, src=src, group=group)
+    rank_sizes = recv(object_sizes_tensor, src=src, group=group, group_src=group_src)

     # Tensor to receive serialized objects into.
     object_tensor = torch.empty(  # type: ignore[call-overload]
@@ -3261,7 +3273,7 @@ def recv_object_list(object_list, src=None, group=None, device=None):
         device=current_device,
     )

-    rank_objects = recv(object_tensor, src=src, group=group)
+    rank_objects = recv(object_tensor, src=src, group=group, group_src=group_src)
     assert (
         rank_sizes == rank_objects
     ), "Mismatch in return ranks for object sizes and objects."