@@ -251,15 +251,22 @@ def merge_router_topk_indices(attention_mask, input_ids, mini_layer_topk_idx_lis
251251 .contiguous ()
252252 )
253253
254+ fp8 = tf_config .fp8
255+ use_fp8_padding = fp8 in ["e4m3" , "hybrid" ]
256+
254257 if input_ids .is_nested :
255258 batch_size = input_ids .shape [0 ]
256- _ , packed_seq_params , _ = preprocess_thd_engine (input_ids , pre_process = True )
259+ _ , packed_seq_params , _ = preprocess_thd_engine (
260+ input_ids , pre_process = True , use_fp8_padding = use_fp8_padding
261+ )
257262 layers_topk_idx = postprocess_thd_engine (
258263 layers_topk_idx , packed_seq_params , input_ids , batch_size , post_process = True
259264 )
260265 else :
261266 batch_size , seq_len = attention_mask .shape [:2 ]
262- _ , packed_seq_params = preprocess_packed_seqs (input_ids , attention_mask , pre_process = True )
267+ _ , packed_seq_params = preprocess_packed_seqs (
268+ input_ids , attention_mask , pre_process = True , use_fp8_padding = use_fp8_padding
269+ )
263270 layers_topk_idx = postprocess_packed_seqs (
264271 layers_topk_idx , packed_seq_params , attention_mask , batch_size , seq_len , post_process = True
265272 )
@@ -286,10 +293,17 @@ def set_router_replay_data(layers_topk_idx, attention_mask, tf_config, vp_rank=N
286293 None: The function updates internal RouterReplay instances in-place.
287294 """
288295 with torch .no_grad ():
296+ fp8 = tf_config .fp8
297+ use_fp8_padding = fp8 in ["e4m3" , "hybrid" ]
298+
289299 if layers_topk_idx .is_nested :
290- layers_topk_idx_rmpad , _ , _ = preprocess_thd_engine (layers_topk_idx , pre_process = True )
300+ layers_topk_idx_rmpad , _ , _ = preprocess_thd_engine (
301+ layers_topk_idx , pre_process = True , use_fp8_padding = use_fp8_padding
302+ )
291303 else :
292- layers_topk_idx_rmpad , _ = preprocess_packed_seqs (layers_topk_idx , attention_mask , pre_process = True )
304+ layers_topk_idx_rmpad , _ = preprocess_packed_seqs (
305+ layers_topk_idx , attention_mask , pre_process = True , use_fp8_padding = use_fp8_padding
306+ )
293307 layers_topk_idx_rmpad = layers_topk_idx_rmpad .contiguous () # 1, dynamic_bs_all, layer_num, topk
294308
295309 # 1, dynamic_bs_split, layer_num, topk
0 commit comments