24
24
_get_global_group ,
25
25
_warn_cur_rank_not_in_group ,
26
26
)
27
+ from paddle .distributed .communication .serialization_utils import (
28
+ convert_object_to_tensor ,
29
+ convert_tensor_to_object ,
30
+ )
27
31
from paddle .framework .recall_error import check_naninf
28
32
from paddle .utils import strtobool
29
33
@@ -58,10 +62,12 @@ def __init__(self):
58
62
def init_or_erase_meta (self ):
59
63
self .send_shape_message = None
60
64
self .send_dtype_message = None
65
+ self .send_key_message = None
61
66
62
67
self .recv_shape_message = None
63
68
self .recv_dtype_message = None
64
69
self .recv_stop_gradient = None
70
+ self .recv_key_message = None
65
71
66
72
self .has_send_meta = False
67
73
self .has_recv_meta = False
@@ -99,17 +105,31 @@ def recv_meta(self, group, reverse=False, broadcast=False):
99
105
shapes = []
100
106
dtypes = []
101
107
stop_grads = []
108
+ keys = []
102
109
103
110
for _ in range (tensor_num ):
104
111
shape_len = data .pop (0 )
105
112
shape = data [:shape_len ]
106
113
data = data [shape_len :]
107
114
dtype_number = data .pop (0 )
108
115
stop_gradient = bool (data .pop (0 ))
116
+ # ------------------tensor key meta recv-------------
117
+ key_len = data .pop (0 )
118
+ key_data = data [:key_len ]
119
+ if key_len > 0 :
120
+ key = convert_tensor_to_object (
121
+ paddle .to_tensor (key_data ).astype ("uint8" ),
122
+ paddle .to_tensor (key_len ),
123
+ )
124
+ else :
125
+ key = None
126
+ data = data [key_len :]
127
+ # ------------------tensor key meta recv-------------
109
128
110
129
shapes .append (shape )
111
130
dtypes .append (dtype_number )
112
131
stop_grads .append (stop_gradient )
132
+ keys .append (key )
113
133
114
134
assert (
115
135
len (data ) == 0
@@ -119,10 +139,12 @@ def recv_meta(self, group, reverse=False, broadcast=False):
119
139
self .recv_shape_message = shapes [0 ]
120
140
self .recv_dtype_message = dtypes [0 ]
121
141
self .recv_stop_gradient = stop_grads [0 ]
142
+ self .recv_key_message = keys [0 ]
122
143
else :
123
144
self .recv_shape_message = tuple (shapes )
124
145
self .recv_dtype_message = tuple (dtypes )
125
146
self .recv_stop_gradient = tuple (stop_grads )
147
+ self .recv_key_message = tuple (keys )
126
148
127
149
def send_meta (self , tensor , group , reverse = False , broadcast = False ):
128
150
if reverse :
@@ -152,12 +174,24 @@ def send_meta(self, tensor, group, reverse=False, broadcast=False):
152
174
153
175
for t in tensors_to_send :
154
176
assert isinstance (t , paddle .Tensor )
177
+ # ------------------tensor key meta send-------------
178
+ if hasattr (t , "key" ):
179
+ current_tensor_name = t .key
180
+ key_data_tensor , _ = convert_object_to_tensor (
181
+ current_tensor_name
182
+ )
183
+ key_data = key_data_tensor .numpy ().tolist ()
184
+ else :
185
+ key_data = []
186
+ # ------------------tensor key meta send-------------
155
187
data .extend (
156
188
[
157
189
len (t .shape ),
158
190
* t .shape ,
159
191
paddle_2_number (t .dtype ),
160
192
int (t .stop_gradient ),
193
+ len (key_data ),
194
+ * key_data ,
161
195
]
162
196
)
163
197
@@ -184,35 +218,44 @@ def send_meta(self, tensor, group, reverse=False, broadcast=False):
184
218
185
219
def _obtain_send_message(self, tensor):
    """Describe ``tensor`` as ``(shape, dtype_number, key)``.

    For a single ``paddle.Tensor`` the result is its shape, the value
    ``paddle_2_number`` returns for its dtype, and its ``key`` attribute
    (``None`` when the tensor carries no such attribute).

    For any other iterable, every element must be a ``paddle.Tensor``
    (enforced with ``assert``). Elements whose ``stop_gradient`` is true
    are left out, and the result is three parallel tuples
    ``(shapes, dtypes, keys)`` built from the remaining elements.
    """
    if not isinstance(tensor, paddle.Tensor):
        described = []
        for item in tensor:
            assert isinstance(item, paddle.Tensor)
            if not item.stop_gradient:
                # Each entry is the (shape, dtype, key) triple of one tensor.
                described.append(self._obtain_send_message(item))
        return (
            tuple(entry[0] for entry in described),
            tuple(entry[1] for entry in described),
            tuple(entry[2] for entry in described),
        )

    # A tensor without a ``key`` attribute is reported with key ``None``.
    tensor_key = getattr(tensor, "key", None)
    return tensor.shape, paddle_2_number(tensor.dtype), tensor_key
199
236
200
237
def set_send_message(self, tensor):
    """Record the send-side metadata of ``tensor`` on this object.

    Stores the three values returned by ``_obtain_send_message`` in
    ``send_shape_message``, ``send_dtype_message`` and
    ``send_key_message`` respectively.
    """
    shape_msg, dtype_msg, key_msg = self._obtain_send_message(tensor)
    self.send_shape_message = shape_msg
    self.send_dtype_message = dtype_msg
    # For a tuple of tensors this is a tuple of keys: (key1, key2, key3, ...)
    self.send_key_message = key_msg
205
243
206
244
def check_send_message(self, tensor):
    """Assert that ``tensor`` still matches the recorded send metadata.

    Does nothing when either ``send_shape_message`` or
    ``send_dtype_message`` is ``None``. Otherwise the shape, dtype and key
    returned by ``_obtain_send_message(tensor)`` are compared against the
    stored ``send_*_message`` values.

    Raises:
        AssertionError: if the shape, dtype or key differs from the
            recorded value.
    """
    has_recorded_meta = (
        self.send_shape_message is not None
        and self.send_dtype_message is not None
    )
    if not has_recorded_meta:
        return

    current_meta = self._obtain_send_message(tensor)
    actual_shape, actual_dtype, actual_key = current_meta

    assert (
        self.send_shape_message == actual_shape
    ), f"send_shape_message: {self.send_shape_message} , actual_shape: {actual_shape} "
    assert (
        self.send_dtype_message == actual_dtype
    ), f"send_dtype_message: {self.send_dtype_message} , actual_dtype: {actual_dtype} "
    assert (
        self.send_key_message == actual_key
    ), f"send_key_message: {self.send_key_message} , actual_key: {actual_key} "
216
259
217
260
def __repr__ (self ):
218
261
return f"send_shape_message: { self .send_shape_message } , send_dtype_message: { self .send_dtype_message } , recv_shape_message: { self .recv_shape_message } , recv_dtype_message: { self .recv_dtype_message } , recv_stop_gradient: { self .recv_stop_gradient } "
@@ -619,9 +662,11 @@ def _p2p_helper(
619
662
recv_shape_msg = send_recv_meta .recv_shape_message
620
663
recv_dtype_msg = send_recv_meta .recv_dtype_message
621
664
recv_stop_gradient = send_recv_meta .recv_stop_gradient
665
+ recv_key_msg = send_recv_meta .recv_key_message
622
666
623
667
send_shape_msg = send_recv_meta .send_shape_message
624
668
send_dtype_msg = send_recv_meta .send_dtype_message
669
+ # backward has no key meta message
625
670
626
671
# model parallel message
627
672
mp_group = _hcg .get_model_parallel_group ()
@@ -636,13 +681,17 @@ def _p2p_helper(
636
681
shape = shape , dtype = number_2_dtype (recv_dtype_msg [idx ])
637
682
)
638
683
tmp .stop_gradient = recv_stop_gradient [idx ]
684
+ if recv_key_msg [idx ] is not None :
685
+ tmp .key = recv_key_msg [idx ]
639
686
tensor_recv_prev .append (tmp )
640
687
tensor_recv_prev = tuple (tensor_recv_prev )
641
688
else :
642
689
tensor_recv_prev = paddle .empty (
643
690
shape = recv_shape_msg , dtype = number_2_dtype (recv_dtype_msg )
644
691
)
645
692
tensor_recv_prev .stop_gradient = recv_stop_gradient
693
+ if recv_key_msg is not None :
694
+ tensor_recv_prev .key = recv_key_msg
646
695
647
696
if recv_next :
648
697
if dynamic_shape :
0 commit comments