
Commit 78ec2c3

Implement support for passing dictionary arguments in Pipeline Parallel
1 parent 2beec18 commit 78ec2c3

File tree

6 files changed: 406 additions, 18 deletions

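
To illustrate the feature at a user level: before this change, the values passed between pipeline stages had to be a single tensor or a tuple of tensors; with it, a stage's forward may also return a dict of named tensors. A minimal hypothetical sketch follows (the stage and key names are illustrative, not part of the commit): the runtime tags each output tensor with its dict key, ships the values as a flat tuple over p2p, and rebuilds the dict as the next stage's input.

import paddle
from paddle import nn


class DecoderStage(nn.Layer):  # hypothetical stage inside a PipelineLayer model
    def forward(self, inputs):
        # inputs arrives as a dict because the previous stage returned one
        hidden, attn_mask = inputs["hidden"], inputs["attn_mask"]
        hidden = hidden + 1.0  # placeholder compute
        # returning a dict of tensors is what this commit enables across stages
        return {"hidden": hidden, "attn_mask": attn_mask}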

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 71 additions & 14 deletions
@@ -757,27 +757,37 @@ def forward_backward_pipeline(
                 schedule += f"f{step_id};"
                 logger.info(f"forward step for micro step {step_id}")
                 continue
+
             input_tensor = self._p2p_helper.recv_forward(
                 self.is_pipeline_first_stage(),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
 
+            input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)
+
             self._record_stamp("F", step_id, '"B"', self._forward_color)
             output_tensor, _, _ = self._forward_step(
-                input_tensor, micro_dataset, step_id=step_id
+                input_tensor=input_tensor_dict if use_dict else input_tensor,
+                micro_dataset=micro_dataset,
+                step_id=step_id,
             )
+
+            # convert dict to tuple whose tensor element has a key attribution
+            output_tensor_tuple = dict_to_tuple_helper(output_tensor)
+
             self._record_stamp("F", step_id, '"E"', self._forward_color)
+            # fwd output dict -> send tuple
             self._p2p_helper.send_forward(
-                output_tensor,
-                self.is_pipeline_last_stage(),
+                output_tensor=output_tensor_tuple,
+                pp_last_stage=self.is_pipeline_last_stage(),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
 
             input_buffers.append(input_tensor)
-            output_buffers.append(output_tensor)
+            output_buffers.append(output_tensor_tuple)
 
             if not self.is_pipeline_last_stage():
-                self._release_output(output_tensor)
+                self._release_output(output_tensor_tuple)
 
         if steady_steps > 0 and not static_scheduler:
             input_tensor = self._p2p_helper.recv_forward(
@@ -794,27 +804,33 @@ def forward_backward_pipeline(
                 continue
             last_iter = i == (steady_steps - 1)
 
+            input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)
+
             self._record_stamp(
                 "F", startup_steps + i, '"B"', self._forward_color
             )
             output_tensor, _, _ = self._forward_step(
-                input_tensor, micro_dataset, step_id=startup_steps + i
+                input_tensor=input_tensor_dict if use_dict else input_tensor,
+                micro_dataset=micro_dataset,
+                step_id=startup_steps + i,
             )
             self._record_stamp(
                 "F", startup_steps + i, '"E"', self._forward_color
             )
 
+            output_tensor_tuple = dict_to_tuple_helper(output_tensor)
+
             output_tensor_grad = self._p2p_helper.send_forward_recv_backward(
-                output_tensor,
+                output_tensor_tuple,
                 self.is_pipeline_last_stage(),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
 
             input_buffers.append(input_tensor)
-            output_buffers.append(output_tensor)
+            output_buffers.append(output_tensor_tuple)
 
             if not self.is_pipeline_last_stage():
-                self._release_output(output_tensor)
+                self._release_output(output_tensor_tuple)
 
             input_tensor, output_tensor = input_buffers.pop(
                 0
@@ -1692,18 +1708,22 @@ def _forward_step_helper(
 
         input_tensor = self._get_forward_input(virtual_pp_rank)
 
+        input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)
+
         output_tensor, schedule_chunk, loss_fn_node = self._forward_step(
-            input_tensor,
+            input_tensor_dict if use_dict else input_tensor,
             micro_dataset,
-            virtual_pp_rank,
+            virtual_pp_rank,  # chunk_id
             step_id=micro_step,
             overlap_schedule_mode=overlap_schedule_mode,
         )
 
+        output_tensor_tuple = dict_to_tuple_helper(output_tensor)
+
         self._store_forward_outputs(
-            virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node
+            virtual_pp_rank, output_tensor_tuple, schedule_chunk, loss_fn_node
         )
-        return output_tensor
+        return output_tensor_tuple
 
     def _overlap_comm_grads(self):
         if self._comm_overlap:
@@ -2953,7 +2973,6 @@ def forward_backward_pipeline(
             )
         )
 
-        # run startup steps
         for micro_step in range(num_steps):
             output_tensor = self._forward_step_helper(micro_dataset, micro_step)
             # determine whether recv forward tensor or not
@@ -3433,3 +3452,41 @@ def forward_backward_pipeline(
         self.processed_steps += 1
         self._check_user_hooks_status_at_step_end()
         return train_loss
+
+
+def tuple_to_dict_helper(input_tensor):
+    # recv tuple -> fwd input dict
+    use_dict = False
+    if isinstance(input_tensor, tuple):
+        use_dict = hasattr(input_tensor[0], "key")
+    else:  # single tensor
+        use_dict = hasattr(input_tensor, "key")
+    if use_dict:
+        input_tensor = convert_tensor_tuple_to_dict(input_tensor)
+    return input_tensor, use_dict
+
+
+def dict_to_tuple_helper(output_tensor):
+    if isinstance(output_tensor, dict):
+        output_tensor_tuple = convert_tensor_dict_to_tuple(
+            output_tensor_dict=output_tensor
+        )
+    else:  # single tensor or tensor tuple
+        output_tensor_tuple = output_tensor
+    return output_tensor_tuple
+
+
+def convert_tensor_dict_to_tuple(output_tensor_dict):
+    for key, tensor in output_tensor_dict.items():
+        tensor.key = key
+
+    return tuple(output_tensor_dict.values())
+
+
+def convert_tensor_tuple_to_dict(input_tensor_tuple):
+    input_tensor_dict = {}
+    for tensor in input_tensor_tuple:
+        key = tensor.key
+        input_tensor_dict[key] = tensor
+        delattr(tensor, "key")
+    return input_tensor_dict
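
The four module-level helpers above implement the dict <-> tuple conversion on either side of the p2p send. A small runnable sketch of that round trip, assuming this commit is applied so the helpers are importable from pipeline_parallel (the tensor values are arbitrary placeholders):

import paddle
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
    dict_to_tuple_helper,
    tuple_to_dict_helper,
)

outputs = {"hidden": paddle.ones([2, 4]), "attn_mask": paddle.zeros([2, 4])}

# sending side: tag every tensor with its key and flatten the dict to a tuple
sent = dict_to_tuple_helper(outputs)
assert isinstance(sent, tuple) and sent[0].key == "hidden"

# receiving side: the key attribute carried by each tensor rebuilds the dict
# (and is removed again from the tensor by delattr)
received, use_dict = tuple_to_dict_helper(sent)
assert use_dict and list(received.keys()) == ["hidden", "attn_mask"]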

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

Lines changed: 53 additions & 4 deletions
@@ -24,6 +24,10 @@
     _get_global_group,
     _warn_cur_rank_not_in_group,
 )
+from paddle.distributed.communication.serialization_utils import (
+    convert_object_to_tensor,
+    convert_tensor_to_object,
+)
 from paddle.framework.recall_error import check_naninf
 from paddle.utils import strtobool

@@ -58,10 +62,12 @@ def __init__(self):
     def init_or_erase_meta(self):
         self.send_shape_message = None
         self.send_dtype_message = None
+        self.send_key_message = None
 
         self.recv_shape_message = None
         self.recv_dtype_message = None
         self.recv_stop_gradient = None
+        self.recv_key_message = None
 
         self.has_send_meta = False
         self.has_recv_meta = False
@@ -99,17 +105,31 @@ def recv_meta(self, group, reverse=False, broadcast=False):
         shapes = []
         dtypes = []
         stop_grads = []
+        keys = []
 
         for _ in range(tensor_num):
             shape_len = data.pop(0)
             shape = data[:shape_len]
             data = data[shape_len:]
             dtype_number = data.pop(0)
             stop_gradient = bool(data.pop(0))
+            # ------------------tensor key meta send-------------
+            key_len = data.pop(0)
+            key_data = data[:key_len]
+            if key_len > 0:
+                key = convert_tensor_to_object(
+                    paddle.to_tensor(key_data).astype("uint8"),
+                    paddle.to_tensor(key_len),
+                )
+            else:
+                key = None
+            data = data[key_len:]
+            # ------------------tensor key meta send-------------
 
             shapes.append(shape)
             dtypes.append(dtype_number)
             stop_grads.append(stop_gradient)
+            keys.append(key)
 
         assert (
             len(data) == 0
@@ -119,10 +139,12 @@ def recv_meta(self, group, reverse=False, broadcast=False):
             self.recv_shape_message = shapes[0]
             self.recv_dtype_message = dtypes[0]
             self.recv_stop_gradient = stop_grads[0]
+            self.recv_key_message = keys[0]
         else:
             self.recv_shape_message = tuple(shapes)
             self.recv_dtype_message = tuple(dtypes)
             self.recv_stop_gradient = tuple(stop_grads)
+            self.recv_key_message = tuple(keys)
 
     def send_meta(self, tensor, group, reverse=False, broadcast=False):
         if reverse:
@@ -152,12 +174,24 @@ def send_meta(self, tensor, group, reverse=False, broadcast=False):
 
         for t in tensors_to_send:
             assert isinstance(t, paddle.Tensor)
+            # ------------------tensor key meta send-------------
+            if hasattr(t, "key"):
+                current_tensor_name = t.key
+                key_data_tensor, _ = convert_object_to_tensor(
+                    current_tensor_name
+                )
+                key_data = key_data_tensor.numpy().tolist()
+            else:
+                key_data = []
+            # ------------------tensor key meta send-------------
             data.extend(
                 [
                     len(t.shape),
                     *t.shape,
                     paddle_2_number(t.dtype),
                     int(t.stop_gradient),
+                    len(key_data),
+                    *key_data,
                 ]
             )

@@ -184,35 +218,44 @@ def send_meta(self, tensor, group, reverse=False, broadcast=False):
 
     def _obtain_send_message(self, tensor):
         if isinstance(tensor, paddle.Tensor):
-            return tensor.shape, paddle_2_number(tensor.dtype)
+            key = tensor.key if hasattr(tensor, "key") else None
+            return tensor.shape, paddle_2_number(tensor.dtype), key
         else:
            shapes = []
            dtypes = []
+            keys = []
            for d in tensor:
                assert isinstance(d, paddle.Tensor)
                if d.stop_gradient:
                    continue
-                shape, dtype = self._obtain_send_message(d)
+                shape, dtype, key = self._obtain_send_message(d)
                shapes.append(shape)
                dtypes.append(dtype)
-            return tuple(shapes), tuple(dtypes)
+                keys.append(key)
+            return tuple(shapes), tuple(dtypes), tuple(keys)
 
     def set_send_message(self, tensor):
         (
             self.send_shape_message,
             self.send_dtype_message,
+            self.send_key_message,  # (key1_str, key2_str, key3_str ... )
         ) = self._obtain_send_message(tensor)
 
     def check_send_message(self, tensor):
         if self.send_shape_message is None or self.send_dtype_message is None:
             return
-        actual_shape, actual_dtype = self._obtain_send_message(tensor)
+        actual_shape, actual_dtype, actual_key = self._obtain_send_message(
+            tensor
+        )
         assert (
             self.send_shape_message == actual_shape
         ), f"send_shape_message: {self.send_shape_message}, actual_shape: {actual_shape}"
         assert (
             self.send_dtype_message == actual_dtype
         ), f"send_dtype_message: {self.send_dtype_message}, actual_dtype: {actual_dtype}"
+        assert (
+            self.send_key_message == actual_key
+        ), f"send_key_message: {self.send_key_message}, actual_key: {actual_key}"
 
     def __repr__(self):
         return f"send_shape_message: {self.send_shape_message}, send_dtype_message: {self.send_dtype_message}, recv_shape_message: {self.recv_shape_message}, recv_dtype_message: {self.recv_dtype_message}, recv_stop_gradient: {self.recv_stop_gradient}"
@@ -619,9 +662,11 @@ def _p2p_helper(
     recv_shape_msg = send_recv_meta.recv_shape_message
     recv_dtype_msg = send_recv_meta.recv_dtype_message
     recv_stop_gradient = send_recv_meta.recv_stop_gradient
+    recv_key_msg = send_recv_meta.recv_key_message
 
     send_shape_msg = send_recv_meta.send_shape_message
     send_dtype_msg = send_recv_meta.send_dtype_message
+    # backward has no key meta message
 
     # model parallel message
     mp_group = _hcg.get_model_parallel_group()
@@ -636,13 +681,17 @@ def _p2p_helper(
                     shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx])
                 )
                 tmp.stop_gradient = recv_stop_gradient[idx]
+                if recv_key_msg[idx] is not None:
+                    tmp.key = recv_key_msg[idx]
                 tensor_recv_prev.append(tmp)
             tensor_recv_prev = tuple(tensor_recv_prev)
         else:
             tensor_recv_prev = paddle.empty(
                 shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)
             )
             tensor_recv_prev.stop_gradient = recv_stop_gradient
+            if recv_key_msg is not None:
+                tensor_recv_prev.key = recv_key_msg
 
     if recv_next:
         if dynamic_shape:
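
On the meta side, the tensor key rides along with the existing shape/dtype message: send_meta pickles the key string into a uint8 tensor and appends its byte values plus their count to the flat integer payload, and recv_meta slices those bytes back out and unpickles them; the recovered key is then re-attached to the freshly allocated recv tensor in _p2p_helper. A standalone sketch of just the encode/decode step, mirroring the calls above (the key string is a made-up example):

import paddle
from paddle.distributed.communication.serialization_utils import (
    convert_object_to_tensor,
    convert_tensor_to_object,
)

key = "hidden"  # hypothetical tensor key

# sender: serialize the key and splice it into the integer meta payload
key_data_tensor, _ = convert_object_to_tensor(key)
key_data = key_data_tensor.numpy().tolist()
payload = [len(key_data), *key_data]

# receiver: pop the length, slice the bytes, and deserialize the key
key_len = payload.pop(0)
recovered = convert_tensor_to_object(
    paddle.to_tensor(payload[:key_len]).astype("uint8"),
    paddle.to_tensor(key_len),
)
assert recovered == key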

test/collective/fleet/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -836,3 +836,17 @@ if((WITH_GPU) AND LOCAL_ALL_PLAT)
   )
   set_tests_properties(test_shutdown_process_group PROPERTIES TIMEOUT "200")
 endif()
+if((WITH_GPU) AND LOCAL_ALL_PLAT)
+  bash_test_modules(
+    test_pp_send_recv_dict
+    START_BASH
+    ../../legacy_test/dist_test.sh
+    TIMEOUT
+    "500"
+    LABELS
+    "RUN_TYPE=DIST"
+    ENVS
+    "PADDLE_DIST_UT_PORT=21282;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
+  )
+  set_tests_properties(test_pp_send_recv_dict PROPERTIES TIMEOUT "500")
+endif()
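
This registers test_pp_send_recv_dict as a distributed, multi-process test driven by dist_test.sh. Once Paddle is built with GPU support, it should be runnable from the build directory like the other collective fleet tests, for example via ctest (exact invocation depends on the local build setup):

ctest -R test_pp_send_recv_dict --output-on-failure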
