Skip to content

Commit 3283385

Browse files
xuxinyi389 and zty-king
authored
[AutoParallel] send/recv_object_list function and serialize method for placement object (PaddlePaddle#72098)
* add `__reduce__` method * support send/recv_object_list * polish api and add tests * typing-fix * typing-fix * fix tests --------- Co-authored-by: zty-king <17786324919@163.com>
1 parent 30e5f93 commit 3283385

File tree

9 files changed

+382
-4
lines changed

9 files changed

+382
-4
lines changed

paddle/fluid/pybind/auto_parallel_py.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,9 @@ void BindAutoParallel(py::module *m) {
487487
py::arg("memo"))
488488
.def(py::self == py::self) // NOLINT
489489
.def(py::self != py::self); // NOLINT
490+
Shard.def("__reduce__", [Shard](const phi::distributed::Shard &self) {
491+
return py::make_tuple(Shard, py::make_tuple(self.get_dim()));
492+
});
490493

491494
auto Replicate =
492495
py::class_<phi::distributed::Replicate,
@@ -522,6 +525,10 @@ void BindAutoParallel(py::module *m) {
522525
py::arg("memo"))
523526
.def(py::self == py::self) // NOLINT
524527
.def(py::self != py::self); // NOLINT
528+
Replicate.def("__reduce__",
529+
[Replicate](const phi::distributed::Replicate &self) {
530+
return py::make_tuple(Replicate, py::make_tuple());
531+
});
525532

526533
auto Partial =
527534
py::class_<phi::distributed::Partial,
@@ -562,6 +569,9 @@ void BindAutoParallel(py::module *m) {
562569
py::arg("memo"))
563570
.def(py::self == py::self) // NOLINT
564571
.def(py::self != py::self); // NOLINT
572+
Partial.def("__reduce__", [Partial](const phi::distributed::Partial &self) {
573+
return py::make_tuple(Partial, py::make_tuple(self.get_reduce_type()));
574+
});
565575

566576
g_placement_shard_pytype = reinterpret_cast<PyTypeObject *>(Shard.ptr());
567577
g_placement_replicated_pytype =

python/paddle/distributed/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,13 @@
9494
is_initialized,
9595
isend,
9696
recv,
97+
recv_object_list,
9798
reduce,
9899
reduce_scatter,
99100
scatter,
100101
scatter_object_list,
101102
send,
103+
send_object_list,
102104
stream,
103105
wait,
104106
)
@@ -167,6 +169,8 @@
167169
"destroy_process_group",
168170
"isend",
169171
"irecv",
172+
"send_object_list",
173+
"recv_object_list",
170174
"reduce_scatter",
171175
"is_available",
172176
"get_backend",

python/paddle/distributed/communication/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
is_initialized,
2626
wait,
2727
)
28-
from .recv import irecv, recv # noqa: F401
28+
from .recv import irecv, recv, recv_object_list # noqa: F401
2929
from .reduce import ReduceOp, reduce # noqa: F401
3030
from .reduce_scatter import reduce_scatter # noqa: F401
3131
from .scatter import scatter, scatter_object_list # noqa: F401
32-
from .send import isend, send # noqa: F401
32+
from .send import isend, send, send_object_list # noqa: F401

python/paddle/distributed/communication/group.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,23 @@ def get_group_rank(self, rank: int) -> int | Literal[-1]:
9191
else:
9292
return -1
9393

94+
def get_global_rank(self, rank: int) -> int | Literal[-1]:
95+
"""
96+
Get the global rank of a process within a group.
97+
98+
Args:
99+
rank (int): The local rank within the group.
100+
101+
Returns:
102+
If the current process is a member of the group, returns the corresponding global rank;
103+
otherwise returns -1.
104+
105+
"""
106+
if self.is_member():
107+
return self.ranks[rank]
108+
else:
109+
return -1
110+
94111
def __repr__(self) -> str:
95112
debug_str = (
96113
f"rank: {self.rank}, nranks: {self.nranks}, id: {self.id}, ranks: "

python/paddle/distributed/communication/recv.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, Any
1818

19+
import paddle
1920
from paddle.distributed.communication import stream
21+
from paddle.distributed.communication.group import (
22+
_get_global_group,
23+
_warn_cur_rank_not_in_group,
24+
)
25+
from paddle.distributed.communication.serialization_utils import (
26+
convert_tensor_to_object,
27+
)
2028

2129
if TYPE_CHECKING:
2230
from paddle import Tensor
@@ -102,3 +110,69 @@ def irecv(
102110
>>> # [7, 8, 9] (2 GPUs)
103111
"""
104112
return recv(tensor, src, group, sync_op=False)
113+
114+
115+
def recv_object_list(
    object_list: list[Any],
    src: int | None = None,
    group: Group | None = None,
    src_in_group: int | None = None,
):
    """
    Receive a list of Python objects sent by ``send_object_list``.

    The wire protocol mirrors the sender: first an int64 tensor holding the
    serialized byte size of each object is received, then a single uint8
    tensor carrying the concatenated payloads, which is sliced apart and
    deserialized into ``object_list``.

    Args:
        object_list (list): Pre-allocated list whose entries are overwritten
            with the received objects; its length must match the sender's.
        src (int, optional): The global source rank id. Default: 0.
        group (Group, optional): The group instance returned by ``new_group``,
            or None for the global default group. Default: None.
        src_in_group (int, optional): The source rank local to ``group``;
            mutually exclusive with ``src``. Default: None.

    Returns:
        This function does not return any value.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> if dist.get_rank() == 0:
            ...     data = ["hello", {"key": 100}, [1, 2, 3]]
            ...     dist.send_object_list(data, dst=1)
            >>> else:
            ...     data = [None] * 3  # type: ignore
            ...     dist.recv_object_list(data, src=0)
            >>> print(data)
            >>> # ["hello", {"key": 100}, [1, 2, 3]] (2 GPUs)
    """
    if not object_list:
        raise ValueError("object_list cannot be None or empty")

    group = group if group is not None else _get_global_group()
    if _warn_cur_rank_not_in_group(group):
        return

    # Resolve the effective global source rank.
    if src_in_group is None:
        src = 0 if src is None else src
    else:
        if src is not None:
            raise ValueError(
                "Cannot specify both 'src' and 'src_in_group' arguments."
            )
        src = group.get_global_rank(src_in_group)

    num_objects = len(object_list)

    # Step 1: receive the per-object byte sizes.
    sizes = paddle.empty((num_objects,), dtype='int64')
    recv(sizes, src=src, group=group)

    # Step 2: receive the concatenated serialized payload.
    payload = paddle.empty((paddle.sum(sizes).item(),), dtype=paddle.uint8)
    recv(payload, src=src, group=group)

    # Step 3: slice the payload apart and deserialize in place.
    cursor = 0
    for idx in range(num_objects):
        nbytes = sizes[idx].item()
        chunk = payload[cursor : cursor + nbytes]
        object_list[idx] = convert_tensor_to_object(chunk, nbytes)
        cursor += nbytes

python/paddle/distributed/communication/send.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import TYPE_CHECKING
17+
from typing import TYPE_CHECKING, Any
1818

19+
import paddle
1920
from paddle.distributed.communication import stream
21+
from paddle.distributed.communication.group import (
22+
_get_global_group,
23+
_warn_cur_rank_not_in_group,
24+
)
25+
from paddle.distributed.communication.serialization_utils import (
26+
convert_object_to_tensor,
27+
)
2028

2129
if TYPE_CHECKING:
2230
from paddle import Tensor
@@ -101,3 +109,72 @@ def isend(tensor: Tensor, dst: int, group: Group | None = None) -> task | None:
101109
102110
"""
103111
return send(tensor, dst, group, sync_op=False)
112+
113+
114+
def send_object_list(
    object_list: list[Any],
    dst: int | None = None,
    group: Group | None = None,
    dst_in_group: int | None = None,
):
    """
    Send a list of Python objects to the receiver.

    Each object is serialized to a uint8 tensor; an int64 tensor with the
    per-object byte sizes is sent first, followed by a single tensor holding
    the concatenated payloads. ``recv_object_list`` implements the matching
    receive side of this protocol.

    Args:
        object_list (list): The list of Python objects to send.
        dst (int, optional): The global destination rank id. Default: 0.
        group (Group, optional): The group instance returned by ``new_group``,
            or None for the global default group. Default: None.
        dst_in_group (int, optional): The destination rank local to ``group``;
            mutually exclusive with ``dst``. Default: None.

    Returns:
        This function does not return any value.

    Examples:
        .. code-block:: python

            >>> # doctest: +REQUIRES(env: DISTRIBUTED)
            >>> import paddle
            >>> import paddle.distributed as dist

            >>> dist.init_parallel_env()
            >>> if dist.get_rank() == 0:
            ...     data = ["hello", {"key": 100}, [1, 2, 3]]
            ...     dist.send_object_list(data, dst=1)
            >>> else:
            ...     data = [None] * 3  # type: ignore
            ...     dist.recv_object_list(data, src=0)
            >>> print(data)
            >>> # ["hello", {"key": 100}, [1, 2, 3]] (2 GPUs)
    """
    if not object_list:
        raise ValueError("object_list cannot be None or empty")

    group = group if group is not None else _get_global_group()
    if _warn_cur_rank_not_in_group(group):
        return

    # Resolve the effective global destination rank.
    if dst_in_group is None:
        dst = 0 if dst is None else dst
    else:
        if dst is not None:
            raise ValueError(
                "Cannot specify both 'dst' and 'dst_in_group' arguments."
            )
        dst = group.get_global_rank(dst_in_group)

    # Serialize every object, tracking each payload and its byte size.
    payloads = []
    byte_counts = []
    for obj in object_list:
        tensor, size = convert_object_to_tensor(obj)
        payloads.append(tensor)
        byte_counts.append(size.item())

    # Announce the per-object sizes before the data itself.
    send(paddle.to_tensor(byte_counts, dtype='int64'), dst=dst, group=group)

    # Ship the payloads as one contiguous tensor.
    payload = payloads[0] if len(payloads) == 1 else paddle.concat(payloads)
    send(payload, dst=dst, group=group)

test/auto_parallel/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
140140
py_test_modules(test_moe_utils MODULES test_moe_utils)
141141
set_tests_properties(test_moe_utils PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
142142
TIMEOUT 30)
143+
py_test_modules(test_object_list_communication MODULES
144+
test_object_list_communication)
145+
set_tests_properties(test_object_list_communication
146+
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
143147
# End of unittests WITH multi cards and timeout
144148

145149
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout

0 commit comments

Comments
 (0)