Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 6 additions & 23 deletions verl/models/mcore/model_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
from megatron.core import parallel_state as mpu
from torch.nested._internal.nested_tensor import NestedTensor

from verl.utils.megatron_utils import unwrap_model
from verl.workers.config import MtpConfig

from .util import (
build_vlm_attn_mask_bshd,
build_vlm_attn_mask_thd,
Comment on lines 24 to +26
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Import is_npu_available from .util so it can be used to handle NPU-specific attention mask requirements in the forward pass.

from .util import (
    build_vlm_attn_mask_bshd,
    build_vlm_attn_mask_thd,
    is_npu_available,

postprocess_bshd,
postprocess_bshd_engine,
postprocess_packed_seqs,
Expand Down Expand Up @@ -279,11 +279,7 @@ def gptmodel_forward_model_engine(
# For VLM model, need to pass bshd format `input_ids` and `attention_mask`.
attention_mask = None
if vision_model:
input_ids_rmpad = input_ids.to_padded_tensor(pad_token_id)
seqlens_in_batch = input_ids.offsets().diff()
attention_mask = torch.zeros_like(input_ids_rmpad, dtype=torch.bool)
for i, seqlen in enumerate(seqlens_in_batch):
attention_mask[i, :seqlen] = True
attention_mask = build_vlm_attn_mask_thd(input_ids, pad_token_id)

output_orig = model(
input_ids=input_ids_rmpad,
Expand Down Expand Up @@ -354,23 +350,10 @@ def gptmodel_forward_model_engine(
logits_processor_args.pop("loss_mask")

# For VLM model, need to pass bshd format `input_ids` and `attention_mask`.
attention_mask = attention_mask_bshd
if vision_model:
seqlens_in_batch = input_ids.offsets().diff()
max_seqlen = seqlens_in_batch.max().item()

# For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size.
tp_size = mpu.get_tensor_model_parallel_world_size()
cp_size = mpu.get_context_parallel_world_size()
align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size
if align_size > 1:
pad_size = (align_size - max_seqlen % align_size) % align_size
max_seqlen += pad_size

input_ids_bshd = input_ids.to_padded_tensor(pad_token_id, output_size=(batch_size, max_seqlen))
attention_mask = torch.zeros_like(input_ids_bshd, dtype=torch.bool)
for i, seqlen in enumerate(seqlens_in_batch):
attention_mask[i, :seqlen] = True
attention_mask = build_vlm_attn_mask_bshd(input_ids, batch_size, pad_token_id)
else:
attention_mask = attention_mask_bshd
Comment on lines 353 to +356
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

On NPU, the attention backend (e.g., npu_fusion_attention) typically expects None instead of a [B, S] mask for padded sequences. While we must preserve the 2D mask (attention_mask_bshd) for post-processing (to filter out padding tokens), we should pass None to the model forward call on NPU. This ensures consistency for both VLM and non-VLM models in bshd format.

        if vision_model:
            attention_mask = build_vlm_attn_mask_bshd(input_ids, batch_size, pad_token_id)
        else:
            attention_mask = attention_mask_bshd

        if is_npu_available:
            attention_mask = None


output_orig = model(
input_ids=input_ids_bshd,
Expand Down
59 changes: 44 additions & 15 deletions verl/models/mcore/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,15 +522,6 @@ def postprocess_thd_engine(
return output_new_tensor


def _build_npu_attn_mask(original_attention_mask: torch.Tensor) -> torch.Tensor:
"""Build attn_mask for torch_npu.npu_fusion_attention (B1SS / [B, 1, Sq, Skv])"""
_, seq_len = original_attention_mask.shape
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=original_attention_mask.device)).to(torch.bool)
attn_mask = original_attention_mask.unsqueeze(-1) & original_attention_mask.unsqueeze(-2)
attn_mask = attn_mask & causal_mask
return (~attn_mask).unsqueeze(1).contiguous()


def preprocess_bshd_engine(
input_ids: torch.Tensor, pre_process: bool = True, need_roll: bool = False, use_fp8_padding: bool = False
):
Expand Down Expand Up @@ -623,8 +614,8 @@ def preprocess_bshd_engine(
input_ids_bshd = torch.roll(input_ids_bshd, shifts=-1, dims=1)

if is_npu_available:
# Ascend npu_fusion_attention's attn_mask must be BNSS / B1SS / 11SS / SS; [B, S] is invalid.
attention_mask = _build_npu_attn_mask(attention_mask)
# Ascend npu_fusion_attention (sparse mode 2) requires a [2048, 2048] attn_mask; a [B, S] mask is invalid, so pass None.
attention_mask = None
Comment on lines 616 to +618
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Setting attention_mask to None here will cause a critical failure in postprocess_bshd_engine (lines 663 and 692), as the 2D mask is required to filter out padding tokens from the model output.

Instead of discarding the mask here, you should keep the 2D mask for post-processing. The requirement to pass None to the model forward call on NPU should be handled in the caller (model_forward.py) while preserving this mask for later use.

Additionally, the comment on line 617 is confusingly phrased ("must be [2048, 2048] is invalid").

Suggested change
if is_npu_available:
# Ascend npu_fusion_attention's attn_mask must be BNSS / B1SS / 11SS / SS; [B, S] is invalid.
attention_mask = _build_npu_attn_mask(attention_mask)
# Ascend npu_fusion_attention sparse-mode-2 attn_mask must be [2048, 2048] is invalid.
attention_mask = None
if is_npu_available:
# Keep the 2D attention_mask for post-processing (filtering padding).
# The caller should pass None to the model if the NPU attention backend doesn't support [B, S] masks.
pass


return input_ids_bshd, attention_mask, position_ids

Expand All @@ -643,10 +634,10 @@ def postprocess_bshd_engine(
if is_npu_available:
attention_mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
attention_mask = ~attention_mask.bool()

assert output.shape[:2] == attention_mask.shape, (
f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
)
else:
assert output.shape[:2] == attention_mask.shape, (
f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
)
Comment on lines 634 to +640
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This block is now incorrect and will cause a crash on NPU. It attempts to convert a 4D attention mask back to 2D using .diagonal(). However, since _build_npu_attn_mask (which created the 4D mask) has been removed, the mask is either None (causing a crash at line 635) or already 2D (making the diagonal operation invalid).

By keeping the 2D mask in preprocess_bshd_engine, the standard assertion and subsequent logic will work correctly for NPU without this special block.

    assert output.shape[:2] == attention_mask.shape, (
        f"output.shape: {output.shape}, attention_mask.shape: {attention_mask.shape}"
    )


cp_size = mpu.get_context_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()
Expand Down Expand Up @@ -703,3 +694,41 @@ def postprocess_bshd_engine(
output_new_tensor = torch.nested.as_nested_tensor(output_new, layout=torch.jagged)

return output_new_tensor


def build_vlm_attn_mask_thd(input_ids: torch.Tensor, pad_token_id: int = None) -> Optional[torch.Tensor]:
    """Build a [B, max_seqlen] boolean attention mask from a jagged (THD) ``input_ids``.

    Args:
        input_ids: Nested (jagged-layout) tensor of token ids, one row per sequence.
        pad_token_id: Fill value used when padding the nested tensor.

    Returns:
        Bool tensor of shape [B, max_seqlen] with True on real tokens and False on
        padding, or None on NPU (the NPU attention backend does not accept a
        [B, S] mask — see preprocess_bshd_engine).
    """
    if is_npu_available:
        return None

    seqlens_in_batch = input_ids.offsets().diff()
    padded = input_ids.to_padded_tensor(pad_token_id)
    # Vectorized fill: position < seqlen marks valid tokens. Avoids the
    # per-row Python loop over the batch dimension.
    positions = torch.arange(padded.shape[1], device=padded.device)
    return positions.unsqueeze(0) < seqlens_in_batch.unsqueeze(1)


def build_vlm_attn_mask_bshd(
    input_ids: torch.Tensor, batch_size: int, pad_token_id: int = None
) -> Optional[torch.Tensor]:
    """Build a [batch_size, max_seqlen] boolean attention mask in bshd layout.

    ``max_seqlen`` is padded so it is divisible by lcm(tp_size, 2 * cp_size)
    when context parallelism is enabled, otherwise by tp_size (sequence
    parallelism requirement).

    Args:
        input_ids: Nested (jagged-layout) tensor of token ids, one row per sequence.
        batch_size: Number of rows in the output mask; rows beyond the number of
            sequences in ``input_ids`` are all False.
        pad_token_id: Fill value used when padding the nested tensor.

    Returns:
        Bool tensor of shape [batch_size, max_seqlen] with True on real tokens
        and False on padding, or None on NPU (the NPU attention backend does not
        accept a [B, S] mask — see preprocess_bshd_engine).
    """
    if is_npu_available:
        return None

    seqlens_in_batch = input_ids.offsets().diff()
    max_seqlen = seqlens_in_batch.max().item()

    # For CP, sequence length must be divisible by (2 * cp_size), and for SP by tp_size.
    tp_size = mpu.get_tensor_model_parallel_world_size()
    cp_size = mpu.get_context_parallel_world_size()
    align_size = math.lcm(tp_size, 2 * cp_size) if cp_size > 1 else tp_size
    if align_size > 1:
        max_seqlen += (align_size - max_seqlen % align_size) % align_size

    input_ids_bshd = input_ids.to_padded_tensor(pad_token_id, output_size=(batch_size, max_seqlen))
    # Vectorized fill: row i is True for positions < seqlens_in_batch[i]. The
    # slice leaves any extra rows (batch_size > number of sequences) all False.
    attention_mask = torch.zeros_like(input_ids_bshd, dtype=torch.bool)
    positions = torch.arange(max_seqlen, device=input_ids_bshd.device)
    attention_mask[: seqlens_in_batch.shape[0]] = positions.unsqueeze(0) < seqlens_in_batch.unsqueeze(1)
    return attention_mask
Loading