|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | | -from typing import TYPE_CHECKING |
| 17 | +from typing import TYPE_CHECKING, Any |
18 | 18 |
|
| 19 | +import paddle |
19 | 20 | from paddle.distributed.communication import stream |
| 21 | +from paddle.distributed.communication.group import ( |
| 22 | + _get_global_group, |
| 23 | + _warn_cur_rank_not_in_group, |
| 24 | +) |
| 25 | +from paddle.distributed.communication.serialization_utils import ( |
| 26 | + convert_object_to_tensor, |
| 27 | +) |
20 | 28 |
|
21 | 29 | if TYPE_CHECKING: |
22 | 30 | from paddle import Tensor |
@@ -101,3 +109,72 @@ def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task | None: |
101 | 109 |
|
102 | 110 | """ |
103 | 111 | return send(tensor, dst, group, sync_op=False) |
| 112 | + |
| 113 | + |
def send_object_list(
    object_list: list[Any],
    dst: int | None = None,
    group: Group | None = None,
    dst_in_group: int | None = None,
):
    """
    Serialize and send a list of Python objects to a peer rank.

    Each object is pickled into a tensor; the per-object byte sizes are sent
    first (as an int64 tensor) so the receiver can allocate buffers, followed
    by the concatenated payload itself.

    Args:
        object_list (list): The list of Python objects to send.
        dst (int, optional): The destination rank id. Default: 0.
        group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
        dst_in_group (int, optional): The destination rank within the group. Cannot be specified together with dst. Default: None.

    Returns:
        This function does not return any value.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> if dist.get_rank() == 0:
            ...     data = ["hello", {"key": 100}, [1, 2, 3]]
            ...     dist.send_object_list(data, dst=1)
            >>> else:
            ...     data = [None] * 3  # type: ignore
            ...     dist.recv_object_list(data, src=0)
            >>> print(data)
            >>> # ["hello", {"key": 100}, [1, 2, 3]] (2 GPUs)
    """
    if not object_list:
        raise ValueError("object_list cannot be None or empty")

    if group is None:
        group = _get_global_group()
    # Nothing to do (a warning is emitted) when this rank is outside the group.
    if _warn_cur_rank_not_in_group(group):
        return

    # Resolve the destination: either a group-local rank translated to a
    # global one, or an explicit/default global rank — never both.
    if dst_in_group is None:
        dst = 0 if dst is None else dst
    else:
        if dst is not None:
            raise ValueError(
                "Cannot specify both 'dst' and 'dst_in_group' arguments."
            )
        dst = group.get_global_rank(dst_in_group)

    # Pickle every object into a tensor, tracking each serialized byte count.
    payload_tensors = []
    byte_counts = []
    for obj in object_list:
        obj_tensor, obj_size = convert_object_to_tensor(obj)
        payload_tensors.append(obj_tensor)
        byte_counts.append(obj_size.item())

    # Ship the sizes first so the receiver can pre-allocate its buffers.
    send(paddle.to_tensor(byte_counts, dtype='int64'), dst=dst, group=group)

    # Then ship the serialized data as one tensor (skip the concat when
    # there is only a single object).
    if len(payload_tensors) == 1:
        payload = payload_tensors[0]
    else:
        payload = paddle.concat(payload_tensors)
    send(payload, dst=dst, group=group)
0 commit comments