Skip to content

Commit aae6c5a

Browse files
committed
add fp8 padding for router replay
1 parent ef13027 commit aae6c5a

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

verl/utils/megatron/router_replay_utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,15 +251,22 @@ def merge_router_topk_indices(attention_mask, input_ids, mini_layer_topk_idx_lis
251251
.contiguous()
252252
)
253253

254+
fp8 = tf_config.fp8
255+
use_fp8_padding = fp8 in ["e4m3", "hybrid"]
256+
254257
if input_ids.is_nested:
255258
batch_size = input_ids.shape[0]
256-
_, packed_seq_params, _ = preprocess_thd_engine(input_ids, pre_process=True)
259+
_, packed_seq_params, _ = preprocess_thd_engine(
260+
input_ids, pre_process=True, use_fp8_padding=use_fp8_padding
261+
)
257262
layers_topk_idx = postprocess_thd_engine(
258263
layers_topk_idx, packed_seq_params, input_ids, batch_size, post_process=True
259264
)
260265
else:
261266
batch_size, seq_len = attention_mask.shape[:2]
262-
_, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=True)
267+
_, packed_seq_params = preprocess_packed_seqs(
268+
input_ids, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding
269+
)
263270
layers_topk_idx = postprocess_packed_seqs(
264271
layers_topk_idx, packed_seq_params, attention_mask, batch_size, seq_len, post_process=True
265272
)
@@ -286,10 +293,17 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N
286293
None: The function updates internal RouterReplay instances in-place.
287294
"""
288295
with torch.no_grad():
296+
fp8 = tf_config.fp8
297+
use_fp8_padding = fp8 in ["e4m3", "hybrid"]
298+
289299
if layers_topk_idx.is_nested:
290-
layers_topk_idx_rmpad, _, _ = preprocess_thd_engine(layers_topk_idx, pre_process=True)
300+
layers_topk_idx_rmpad, _, _ = preprocess_thd_engine(
301+
layers_topk_idx, pre_process=True, use_fp8_padding=use_fp8_padding
302+
)
291303
else:
292-
layers_topk_idx_rmpad, _ = preprocess_packed_seqs(layers_topk_idx, attention_mask, pre_process=True)
304+
layers_topk_idx_rmpad, _ = preprocess_packed_seqs(
305+
layers_topk_idx, attention_mask, pre_process=True, use_fp8_padding=use_fp8_padding
306+
)
293307
layers_topk_idx_rmpad = layers_topk_idx_rmpad.contiguous() # 1, dynamic_bs_all, layer_num, topk
294308

295309
# 1, dynamic_bs_split, layer_num, topk

0 commit comments

Comments (0)