[megatron] fix: Adjust the attention mask shape for VLM with Megatron on NPU #5904
ZLiao097 wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request refactors VLM attention mask generation by centralizing the logic into utility functions and modifies NPU-specific mask handling. The review highlights critical issues where discarding the attention mask in preprocess_bshd_engine will break post-processing logic that relies on it to filter padding tokens. Furthermore, the NPU-specific code in postprocess_bshd_engine is now broken because it expects a 4D mask that is no longer being created. The feedback suggests preserving the 2D mask for utility purposes while passing None to the model forward call on NPU devices.
```python
if is_npu_available:
    # Ascend npu_fusion_attention's attn_mask must be BNSS / B1SS / 11SS / SS; [B, S] is invalid.
    attention_mask = _build_npu_attn_mask(attention_mask)
    # Ascend npu_fusion_attention sparse-mode-2 attn_mask must be [2048, 2048] is invalid.
    attention_mask = None
```
Setting attention_mask to None here will cause a critical failure in postprocess_bshd_engine (lines 663 and 692), as the 2D mask is required to filter out padding tokens from the model output.
Instead of discarding the mask here, you should keep the 2D mask for post-processing. The requirement to pass None to the model forward call on NPU should be handled in the caller (model_forward.py) while preserving this mask for later use.
Additionally, the comment on line 617 is confusingly phrased ("must be [2048, 2048] is invalid").
Suggested change:

```diff
 if is_npu_available:
-    # Ascend npu_fusion_attention's attn_mask must be BNSS / B1SS / 11SS / SS; [B, S] is invalid.
-    attention_mask = _build_npu_attn_mask(attention_mask)
-    # Ascend npu_fusion_attention sparse-mode-2 attn_mask must be [2048, 2048] is invalid.
-    attention_mask = None
+    # Keep the 2D attention_mask for post-processing (filtering padding).
+    # The caller should pass None to the model if the NPU attention backend doesn't support [B, S] masks.
+    pass
```
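The caller-side handling the comment recommends can be sketched in a few lines; `forward_with_npu_mask_policy` and its arguments are hypothetical names for illustration, not verl API:

```python
def forward_with_npu_mask_policy(model, input_ids, attention_mask, is_npu_available):
    """Hypothetical sketch of the recommended caller-side policy: pass None
    to the model forward on NPU (where npu_fusion_attention rejects [B, S]
    masks) while returning the 2D mask so post-processing can still filter
    out padding tokens."""
    model_mask = None if is_npu_available else attention_mask
    output = model(input_ids, attention_mask=model_mask)
    return output, attention_mask  # the 2D mask survives for post-processing
```

On NPU the model never sees the `[B, S]` mask, yet the mask remains available to the padding-filtering logic in `postprocess_bshd_engine`.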
```python
if is_npu_available:
    attention_mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
    attention_mask = ~attention_mask.bool()

    assert output.shape[:2] == attention_mask.shape, (
        f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
    )
else:
    assert output.shape[:2] == attention_mask.shape, (
        f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
    )
```
This block is now incorrect and will cause a crash on NPU. It attempts to convert a 4D attention mask back to 2D using .diagonal(). However, since _build_npu_attn_mask (which created the 4D mask) has been removed, the mask is either None (causing a crash at line 635) or already 2D (making the diagonal operation invalid).
By keeping the 2D mask in preprocess_bshd_engine, the standard assertion and subsequent logic will work correctly for NPU without this special block.
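To make the 4D-to-2D recovery concrete, here is a small NumPy sketch (NumPy's `diagonal(axis1=-2, axis2=-1)` mirrors torch's `diagonal(dim1=-2, dim2=-1)`). How `_build_npu_attn_mask` built the 4D inverted mask is an assumption here; the point is that the diagonal trick presupposes a `[B, 1, S, S]` input and cannot work once the mask is `None` or already 2D:

```python
import numpy as np

# 2D padding mask: True = real token, False = padding.
mask_2d = np.array([[True, True, False],
                    [True, False, False]])

# NPU-style 4D inverted mask [B, 1, S, S]: True marks positions to block.
# Building it by broadcasting the inverted 2D mask is an assumption about
# what _build_npu_attn_mask did; the real helper may differ.
inv = ~mask_2d                                           # [B, S]
mask_4d = inv[:, None, None, :] | inv[:, None, :, None]  # [B, 1, S, S]

# The postprocess code recovers the 2D mask from the diagonal:
recovered = ~mask_4d.diagonal(axis1=-2, axis2=-1).squeeze(1)
assert (recovered == mask_2d).all()
```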
```python
assert output.shape[:2] == attention_mask.shape, (
    f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
)
```

```python
from .util import (
    build_vlm_attn_mask_bshd,
    build_vlm_attn_mask_thd,
```
```python
if vision_model:
    seqlens_in_batch = input_ids.offsets().diff()
    max_seqlen = seqlens_in_batch.max().item()

    # For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size.
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size
    if align_size > 1:
        pad_size = (align_size - max_seqlen % align_size) % align_size
        max_seqlen += pad_size

    input_ids_bshd = input_ids.to_padded_tensor(pad_token_id, output_size=(batch_size, max_seqlen))
    attention_mask = torch.zeros_like(input_ids_bshd, dtype=torch.bool)
    for i, seqlen in enumerate(seqlens_in_batch):
        attention_mask[i, :seqlen] = True
    attention_mask = build_vlm_attn_mask_bshd(input_ids, batch_size, pad_token_id)
else:
    attention_mask = attention_mask_bshd
```
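For reference, a rough pure-Python stand-in for the shape logic that this call site suggests `build_vlm_attn_mask_bshd` now encapsulates (the real utility operates on torch nested tensors and takes `input_ids`, `batch_size`, `pad_token_id`; `build_padding_mask_bshd` below is an illustrative simplification, not the actual function):

```python
import math

def build_padding_mask_bshd(seqlens, tp_size=1, cp_size=1):
    """Illustrative simplification: pad all sequences to a shared max_seqlen
    aligned for TP/CP, marking real tokens True and padding False."""
    max_seqlen = max(seqlens)
    # For CP, sequence length must be divisible by (2 * cp_size); for SP, by tp_size.
    align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size
    if align_size > 1:
        max_seqlen += (align_size - max_seqlen % align_size) % align_size
    return [[pos < seqlen for pos in range(max_seqlen)] for seqlen in seqlens]
```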
On NPU, the attention backend (e.g., npu_fusion_attention) typically expects None instead of a [B, S] mask for padded sequences. While we must preserve the 2D mask (attention_mask_bshd) for post-processing (to filter out padding tokens), we should pass None to the model forward call on NPU. This ensures consistency for both VLM and non-VLM models in bshd format.
```python
if vision_model:
    attention_mask = build_vlm_attn_mask_bshd(input_ids, batch_size, pad_token_id)
else:
    attention_mask = attention_mask_bshd
if is_npu_available:
    attention_mask = None
```
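Why the 2D mask must outlive the model call: `postprocess_bshd_engine` uses it to drop padding rows from the padded `[B, S, ...]` output. A minimal NumPy sketch of that filtering step (`filter_padding` is an illustrative name, not the actual verl code):

```python
import numpy as np

def filter_padding(output, attention_mask_2d):
    """Drop padding positions from a padded [B, S, ...] output using the 2D
    mask, mirroring the check output.shape[:2] == attention_mask.shape."""
    output = np.asarray(output)
    mask = np.asarray(attention_mask_2d, dtype=bool)
    assert output.shape[:2] == mask.shape
    return output[mask]  # -> [total_real_tokens, ...]
```

If the 2D mask has already been replaced by `None` in preprocessing, this step has nothing to index with, which is exactly the failure mode the review describes.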
What does this PR do?
Details are in issue #5878.
In short, the VLM + Megatron pipeline has seen little usage on NPUs, so it had not previously been adapted to the NPU environment. This PR adds that adaptation.
Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, e.g. `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR is breaking, prepend `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example

```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` to lint the code.
- Request CI through the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.