Commit 9b718e1

Browse files
authored
Correct Attention FLOPS estimation in flops_counter.py (#4929)
### What does this PR do?

Fix the attention FLOPS calculation for causal LLMs.

### Problem

The current attention FLOPS calculations for all causal LLMs are missing the `/2` factor for the causal (lower-triangular) attention mask. This causes a **2× overestimation** of attention FLOPS. Additionally, DeepSeek V3's MLA attention incorrectly uses the same dimension for both the Q @ K^T and attn @ V operations, even though `v_head_dim` differs from `q_head_dim`.

### Changes Summary

| Function | Models | Change |
|----------|--------|--------|
| `_estimate_qwen2_flops` | qwen2, llama, qwen3, mistral, etc. | `12 *` → `6 *` |
| `_estimate_qwen3_vl_flops` | qwen3_vl | `12 *` → `6 *` |
| `_estimate_qwen3_vl_moe_flops` | qwen3_vl_moe | `12 *` → `6 *` |
| `_estimate_qwen2_moe_flops` | qwen2_moe, qwen3_moe | `12 *` → `6 *` |
| `_estimate_gemma3_flops` | gemma3_text | `12 *` → `6 *` |
| `_estimate_apertus_flops` | apertus | `12 *` → `6 *` |
| `_estimate_gpt_oss_flops` | gpt_oss | `12 *` → `6 *` |
| `_estimate_deepseek_v3_flops` | deepseek_v3 | `12 * q` → `3 * (q + v)` |
| `_estimate_qwen3_vit_flop` | ViT (vision) | **No change** (bidirectional) |

For causal (autoregressive) attention, only the lower-triangular portion of the attention matrix is computed:

```
Attention Matrix (causal):
[✓ · · ·]
[✓ ✓ · ·]
[✓ ✓ ✓ ·]
[✓ ✓ ✓ ✓]
```

| Model Type | Before | After | Overestimation |
|------------|--------|-------|----------------|
| Standard GQA/MHA | `12 * seq² * d` | `6 * seq² * d` | **2.0×** |
| DeepSeek V3 MLA | `12 * seq² * q` | `3 * seq² * (q+v)` | **2.4×** |

### Reference

This fix aligns with [Megatron-LM's FLOPS calculation](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training/training.py), which:

- uses `/2` for causal attention;
- separately accounts for `q_head_dim` and `v_head_dim` in MLA.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
- [ ] If your PR is related to the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.
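The corrected factor bookkeeping can be sketched as a small standalone helper (a minimal sketch, not the actual `flops_counter.py` code; the function names here are illustrative):

```python
def causal_attn_flops(batch_seqlens, head_dim, num_attention_heads, num_hidden_layers):
    """Core-attention FLOPS (fwd + bwd) under a causal mask.

    Factor breakdown: 2 flops per multiply-add, two matmuls per layer
    (Q @ K^T and attn @ V), x3 for forward + backward, x1/2 for the
    lower-triangular (causal) mask: 2 * 2 * 3 / 2 = 6.
    """
    seqlen_square_sum = sum(s * s for s in batch_seqlens)
    return 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers


def bidirectional_attn_flops(batch_seqlens, head_dim, num_attention_heads, num_hidden_layers):
    """Same accounting without the /2 causal factor (e.g. for a ViT): factor 12."""
    return 2 * causal_attn_flops(batch_seqlens, head_dim, num_attention_heads, num_hidden_layers)
```

Dividing the two recovers exactly the 2.0× overestimation reported for the standard GQA/MHA estimators.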
1 parent 07d4033 commit 9b718e1

File tree

1 file changed (+11 −8 lines changed)

verl/utils/flops_counter.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -112,7 +112,7 @@ def _estimate_qwen2_flops(config, tokens_sum, batch_seqlens, delta_time):
     seqlen_square_sum = 0
     for seqlen in batch_seqlens:
         seqlen_square_sum += seqlen * seqlen
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
 
     # all_layer & all_token fwd & bwd flops
     flops_all_token = dense_N_flops + attn_qkv_flops
@@ -149,7 +149,7 @@ def _estimate_qwen3_vl_flops(config, tokens_sum, batch_seqlens, delta_time, **ka
     seqlen_square_sum = 0
     for seqlen in batch_seqlens:
         seqlen_square_sum += seqlen * seqlen
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
 
     # vit flops
     images_seqlens = kargs.get("images_seqlens", None)
@@ -197,7 +197,7 @@ def _estimate_qwen3_vl_moe_flops(config, tokens_sum, batch_seqlens, delta_time,
     seqlen_square_sum = 0
     for seqlen in batch_seqlens:
         seqlen_square_sum += seqlen * seqlen
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
 
     # vit flops
     images_seqlens = kargs.get("images_seqlens", None)
@@ -304,7 +304,10 @@ def _estimate_deepseek_v3_flops(config, tokens_sum, batch_seqlens, delta_time):
     for seqlen in batch_seqlens:
         seqlen_square_sum += seqlen * seqlen * num_hidden_layers
 
-    attn_qkv_flops = 12 * seqlen_square_sum * q_head_dim * num_query_heads
+    # Core attention FLOPS for MLA with causal mask:
+    # Q @ K^T: 3 * 2 * seq^2 * q_head_dim * num_heads / 2 (causal)
+    # attn @ V: 3 * 2 * seq^2 * v_head_dim * num_heads / 2 (causal)
+    attn_qkv_flops = 3 * seqlen_square_sum * (q_head_dim + config.v_head_dim) * num_query_heads
     # all_layer & all_token fwd & bwk flops
     flops_all_token = dense_N_flops + attn_qkv_flops
     flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
@@ -341,7 +344,7 @@ def _estimate_qwen2_moe_flops(config, tokens_sum, batch_seqlens, delta_time):
     seqlen_square_sum = 0
     for seqlen in batch_seqlens:
         seqlen_square_sum += seqlen * seqlen
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
 
     # all_layer & all_token fwd & bwd flops
     flops_all_token = dense_N_flops + attn_qkv_flops
@@ -409,7 +412,7 @@ def _estimate_gemma3_flops(config, tokens_sum, batch_seqlens, delta_time):
         seqlen_square_sum += seqlen * seqlen
     seqlen_square_sum *= num_hidden_layers
 
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads
 
     # all_layer & all_token fwd & bwd flops
     flops_all_token = dense_N_flops + attn_qkv_flops
@@ -449,7 +452,7 @@ def _estimate_apertus_flops(config, tokens_sum, batch_seqlens, delta_time):
     seqlen_square_sum = 0
     for seqlen in batch_seqlens:
         seqlen_square_sum += seqlen * seqlen
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
 
     # all_layer & all_token fwd & bwd flops
     flops_all_token = dense_N_flops + attn_qkv_flops
@@ -520,7 +523,7 @@ def _estimate_gpt_oss_flops(config, tokens_sum, batch_seqlens, delta_time):
         seqlen_square_sum += seqlen * seqlen
     seqlen_square_sum *= num_hidden_layers
 
-    attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads
+    attn_qkv_flops = 6 * seqlen_square_sum * head_dim * num_attention_heads
 
     # Total FLOPs
     flops_all_token = dense_N_flops + attn_qkv_flops
```
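As a sanity check on the 2.4× figure for MLA, the before/after formulas can be compared directly. This is a sketch, not verl code; `q_head_dim = 192` and `v_head_dim = 128` are DeepSeek V3's head dimensions (`q_head_dim = qk_nope_head_dim + qk_rope_head_dim = 128 + 64`):

```python
def mla_flops_before(seqlen_square_sum, q_head_dim, num_query_heads):
    # Old formula: both matmuls priced at q_head_dim, no causal /2.
    return 12 * seqlen_square_sum * q_head_dim * num_query_heads


def mla_flops_after(seqlen_square_sum, q_head_dim, v_head_dim, num_query_heads):
    # New formula: 3 (fwd+bwd) * 2 (flops per MAC) / 2 (causal) = 3,
    # with Q @ K^T and attn @ V dimensioned separately.
    return 3 * seqlen_square_sum * (q_head_dim + v_head_dim) * num_query_heads


# The seqlen and head-count factors cancel in the ratio.
ratio = mla_flops_before(1, 192, 1) / mla_flops_after(1, 192, 128, 1)
```

Here `12 * 192 = 2304` versus `3 * (192 + 128) = 960`, giving the 2.4× overestimation in the summary table.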
