-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[megatron] fix: Adjust the attention mask shape for VLM with Megatron on NPU #5904
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,17 +13,17 @@ | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import math | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from megatron.core import parallel_state as mpu | ||
| from torch.nested._internal.nested_tensor import NestedTensor | ||
|
|
||
| from verl.utils.megatron_utils import unwrap_model | ||
| from verl.workers.config import MtpConfig | ||
|
|
||
| from .util import ( | ||
| build_vlm_attn_mask_bshd, | ||
| build_vlm_attn_mask_thd, | ||
| postprocess_bshd, | ||
| postprocess_bshd_engine, | ||
| postprocess_packed_seqs, | ||
|
|
@@ -279,11 +279,7 @@ def gptmodel_forward_model_engine( | |
| # For VLM model, need to pass bshd format `input_ids` and `attention_mask`. | ||
| attention_mask = None | ||
| if vision_model: | ||
| input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id) | ||
| seqlens_in_batch = input_ids.offsets().diff() | ||
| attention_mask = torch.zeros_like(input_ids_rmpad, dtype=torch.bool) | ||
| for i, seqlen in enumerate(seqlens_in_batch): | ||
| attention_mask[i, :seqlen] = True | ||
| attention_mask = build_vlm_attn_mask_thd(input_ids, pad_token_id) | ||
|
|
||
| output_orig = model( | ||
| input_ids=input_ids_rmpad, | ||
|
|
@@ -354,23 +350,10 @@ def gptmodel_forward_model_engine( | |
| logits_processor_args.pop("loss_mask") | ||
|
|
||
| # For VLM model, need to pass bshd format `input_ids` and `attention_mask`. | ||
| attention_mask = attention_mask_bshd | ||
| if vision_model: | ||
| seqlens_in_batch = input_ids.offsets().diff() | ||
| max_seqlen = seqlens_in_batch.max().item() | ||
|
|
||
| # For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size. | ||
| tp_size = mpu.get_tensor_model_parallel_world_size() | ||
| cp_size = mpu.get_context_parallel_world_size() | ||
| align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size | ||
| if align_size > 1: | ||
| pad_size = (align_size - max_seqlen % align_size) % align_size | ||
| max_seqlen += pad_size | ||
|
|
||
| input_ids_bshd = input_ids.to_padded_tensor(pad_token_id, output_size=(batch_size, max_seqlen)) | ||
| attention_mask = torch.zeros_like(input_ids_bshd, dtype=torch.bool) | ||
| for i, seqlen in enumerate(seqlens_in_batch): | ||
| attention_mask[i, :seqlen] = True | ||
| attention_mask = build_vlm_attn_mask_bshd(input_ids, batch_size, pad_token_id) | ||
| else: | ||
| attention_mask = attention_mask_bshd | ||
|
Comment on lines
353
to
+356
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. On NPU, the attention backend requires `attention_mask = None` rather than a bshd mask (see the suggested change below): if vision_model:
attention_mask = build_vlm_attn_mask_bshd(input_ids, batch_size, pad_token_id)
else:
attention_mask = attention_mask_bshd
if is_npu_available:
attention_mask = None |
||
|
|
||
| output_orig = model( | ||
| input_ids=input_ids_bshd, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -522,15 +522,6 @@ def postprocess_thd_engine( | |||||||||||||||||||
| return output_new_tensor | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def _build_npu_attn_mask(original_attention_mask: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||
| """Build attn_mask for torch_npu.npu_fusion_attention (B1SS / [B, 1, Sq, Skv])""" | ||||||||||||||||||||
| _, seq_len = original_attention_mask.shape | ||||||||||||||||||||
| causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=original_attention_mask.device)).to(torch.bool) | ||||||||||||||||||||
| attn_mask = original_attention_mask.unsqueeze(-1) & original_attention_mask.unsqueeze(-2) | ||||||||||||||||||||
| attn_mask = attn_mask & causal_mask | ||||||||||||||||||||
| return (~attn_mask).unsqueeze(1).contiguous() | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def preprocess_bshd_engine( | ||||||||||||||||||||
| input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False, use_fp8_padding: bool = False | ||||||||||||||||||||
| ): | ||||||||||||||||||||
|
|
@@ -623,8 +614,8 @@ def preprocess_bshd_engine( | |||||||||||||||||||
| input_ids_bshd = torch.roll(input_ids_bshd, shifts=-1, dims=1) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if is_npu_available: | ||||||||||||||||||||
| # Ascend npu_fusion_attention's attn_mask must be BNSS / B1SS / 11SS / SS; [B, S] is invalid. | ||||||||||||||||||||
| attention_mask = _build_npu_attn_mask(attention_mask) | ||||||||||||||||||||
| # Ascend npu_fusion_attention sparse-mode-2 attn_mask must be [2048, 2048] is invalid. | ||||||||||||||||||||
| attention_mask = None | ||||||||||||||||||||
|
Comment on lines
616
to
+618
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Setting `attention_mask = None` here discards information that the post-processing step still needs. Instead of discarding the mask here, you should keep the 2D mask for post-processing; the requirement to pass `None` applies only at the attention call itself. Additionally, the comment on line 617 is confusingly phrased ("must be [2048, 2048] is invalid").
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| return input_ids_bshd, attention_mask, position_ids | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -643,10 +634,10 @@ def postprocess_bshd_engine( | |||||||||||||||||||
| if is_npu_available: | ||||||||||||||||||||
| attention_mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1) | ||||||||||||||||||||
| attention_mask = ~attention_mask.bool() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| assert output.shape[:2] == attention_mask.shape, ( | ||||||||||||||||||||
| f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| assert output.shape[:2] == attention_mask.shape, ( | ||||||||||||||||||||
| f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
Comment on lines
634
to
+640
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This block is now incorrect and will cause a crash on NPU. It attempts to convert a 4D attention mask back to 2D using `diagonal(dim1=-2, dim2=-1)`, but with this change the mask arriving on NPU is `None`, so the call will fail. By keeping the 2D mask in `preprocess_bshd_engine`, this NPU-specific branch becomes unnecessary and both paths can share the same assertion: assert output.shape[:2] == attention_mask.shape, (
f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
) |
||||||||||||||||||||
|
|
||||||||||||||||||||
| cp_size = mpu.get_context_parallel_world_size() | ||||||||||||||||||||
| cp_rank = mpu.get_context_parallel_rank() | ||||||||||||||||||||
|
|
@@ -703,3 +694,41 @@ def postprocess_bshd_engine( | |||||||||||||||||||
| output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return output_new_tensor | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def build_vlm_attn_mask_thd(input_ids: torch.Tensor, pad_token_id: int = None) -> Optional[torch.Tensor]: | ||||||||||||||||||||
| if is_npu_available: | ||||||||||||||||||||
| return None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id) | ||||||||||||||||||||
| seqlens_in_batch = input_ids.offsets().diff() | ||||||||||||||||||||
| attention_mask = torch.zeros_like(input_ids_rmpad, dtype=torch.bool) | ||||||||||||||||||||
| for i, seqlen in enumerate(seqlens_in_batch): | ||||||||||||||||||||
| attention_mask[i, :seqlen] = True | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return attention_mask | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def build_vlm_attn_mask_bshd( | ||||||||||||||||||||
| input_ids: torch.Tensor, batch_size: int, pad_token_id: int = None | ||||||||||||||||||||
| ) -> Optional[torch.Tensor]: | ||||||||||||||||||||
| if is_npu_available: | ||||||||||||||||||||
| return None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| seqlens_in_batch = input_ids.offsets().diff() | ||||||||||||||||||||
| max_seqlen = seqlens_in_batch.max().item() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size. | ||||||||||||||||||||
| tp_size = mpu.get_tensor_model_parallel_world_size() | ||||||||||||||||||||
| cp_size = mpu.get_context_parallel_world_size() | ||||||||||||||||||||
| align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size | ||||||||||||||||||||
| if align_size > 1: | ||||||||||||||||||||
| pad_size = (align_size - max_seqlen % align_size) % align_size | ||||||||||||||||||||
| max_seqlen += pad_size | ||||||||||||||||||||
|
|
||||||||||||||||||||
| input_ids_bshd = input_ids.to_padded_tensor(pad_token_id, output_size=(batch_size, max_seqlen)) | ||||||||||||||||||||
| attention_mask = torch.zeros_like(input_ids_bshd, dtype=torch.bool) | ||||||||||||||||||||
| for i, seqlen in enumerate(seqlens_in_batch): | ||||||||||||||||||||
| attention_mask[i, :seqlen] = True | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return attention_mask | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Import `is_npu_available` from `.util` so it can be used to handle NPU-specific attention mask requirements in the forward pass.