Skip to content

Commit 3266957

Browse files
Aya-ZIbrafacebook-github-bot
authored andcommitted
kv_dq zero initialization to avoid NaNs from FA3 (pytorch#3632)
Summary: X-link: facebookresearch/FBGEMM#708 Pull Request resolved: pytorch#3632 Running evals with FP8 KV gives NaNs due to issues in FA3. For more context: D68708685 To reproduce: > sh ai_codesign/gen_ai/disagg_generator_launcher/start_server_moe.sh -m 17b_text_sft -a " --ffn_quantize_mode=fp8_rowwise --attn_quantize_mode=fp8_rowwise --kv_cache_quantization=8 " Mitigating these issues, change dequantize_fp8_cache initialization of output buffers from at::empty to at::zeros Reviewed By: jasonjk-park Differential Revision: D68574038 fbshipit-source-id: 3f3f5573d13f1b4046e6880363533eb1c2dfa268
1 parent 4965f35 commit 3266957

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,9 +1874,11 @@ std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
18741874
// correct block_tables. (2) From outside, keep a persistent buffer that has a
18751875
// matching shape with the original paged KV and feed the same buffer
18761876
// into this function at every layer to reuse it and prevent allocation.
1877-
auto cache_K_dq = at::empty(
1877+
1878+
// FIXME: T213958042
1879+
auto cache_K_dq = at::zeros(
18781880
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
1879-
auto cache_V_dq = at::empty(
1881+
auto cache_V_dq = at::zeros(
18801882
{B_KV, MAX_T, N_KVH, D_H}, cache_K.options().dtype(at::kBFloat16));
18811883
18821884
if (B == 0) {

0 commit comments

Comments
 (0)