
fix: flatten multi-component position_ids to 1D for nested tensor compatibility#5886

Open
yifannnwu wants to merge 1 commit into verl-project:main from yifannnwu:fix/3d-position-ids-nested-tensor

Conversation

@yifannnwu
Contributor

Summary

Fixes a crash when training models with multi-dimensional RoPE position_ids (e.g. Qwen3.5's 4-component (text, height, width, temporal) layout, or Qwen2-VL's mRoPE) on the FSDP engine_workers.py path (use_legacy_worker_impl=disable).

Related: #5772 (same class of bug on the Megatron/mbridge path)

Problem

left_right_2_no_padding() in padding.py creates per-sample position_ids of shape (num_components, valid_len) and passes them directly to torch.nested.as_nested_tensor(..., layout=torch.jagged).

Jagged layout treats dim 0 as the ragged dimension, so it interprets num_components (e.g. 4) as ragged instead of seq_len. The resulting nested tensor has shape [batch, *(ragged=4), seq_len] instead of the intended [batch, *(ragged=seq_len), num_components].

Downstream in transformer_impl.py, to_padded_tensor() produces incorrect shapes → crash in apply_rotary_pos_emb.

Transposing to (valid_len, num_components) before nesting doesn't work either, because 3D jagged nested tensors have broken unbind() and to_padded_tensor() in PyTorch (see pytorch/pytorch#153238).

Fix

Flatten multi-component position_ids to 1D (seq_len * num_components,) before creating the nested tensor, keeping it purely 2D jagged. Store num_pos_components as non-tensor metadata. Reshape back to (num_components, batch, seq_len) in prepare_model_inputs.
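The flatten/restore round trip described above can be sketched with stock PyTorch (shapes and variable names here are illustrative, not verl's actual code):

```python
import torch

num_components, valid_len = 4, 6
pos_ids = torch.arange(num_components * valid_len).reshape(num_components, valid_len)

# padding.py side: interleave components per token, then flatten to 1D
flat = pos_ids.T.contiguous().flatten()  # (valid_len * num_components,)

# prepare_model_inputs side: restore (num_components, 1, total_nnz)
total_nnz = flat.shape[0] // num_components
restored = flat.view(total_nnz, num_components).T.unsqueeze(1)

assert restored.shape == (num_components, 1, valid_len)
assert torch.equal(restored.squeeze(1), pos_ids)
```

Because the flattened tensor is 1D per sample, the nested tensor stays a plain 2D jagged layout and never hits the broken 3D code paths.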

Changes

verl/workers/utils/padding.py, left_right_2_no_padding():

  • Flatten (num_components, valid_len) → (valid_len * num_components,) via .T.contiguous().flatten()
  • Store num_pos_components in non-tensor metadata via tu.assign_non_tensor_data

verl/workers/engine/fsdp/transformer_impl.py, prepare_model_inputs():

  • NO_PADDING path: .view(total_nnz, num_components).T.unsqueeze(1) → (num_components, 1, total_nnz)
  • Padded path: to_padded_tensor(..., output_size=(batch, max_seq_len * num_components)).view(batch, max_seq_len, num_components).permute(2, 0, 1) → (num_components, batch, max_seq_len)
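The padded-path reshape can be sketched as follows, using a simulated to_padded_tensor output (batch, max_seq_len, and num_components here are illustrative values, not verl's):

```python
import torch

batch, max_seq_len, num_components = 2, 5, 4

# Simulate the to_padded_tensor output: each row holds token-major
# flattened position ids of length max_seq_len * num_components
padded = torch.arange(batch * max_seq_len * num_components).reshape(
    batch, max_seq_len * num_components
)

# Unflatten per token, then move the component axis to the front
position_ids = padded.view(batch, max_seq_len, num_components).permute(2, 0, 1)

assert position_ids.shape == (num_components, batch, max_seq_len)
```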

Affected models

Any model with multi-dimensional RoPE on the FSDP engine_workers.py path:

  • Qwen3.5 (4-component: text + 3D spatial)
  • Qwen2-VL (3-component mRoPE)
  • Other VLMs with multi-component position_ids

Test plan

  • Qwen3.5-0.8B, FSDP2, use_legacy_worker_impl=disable: 3 GRPO training steps + validation pass. Previously crashed with RoPE shape mismatch.
  • Qwen2-VL: not tested, but it uses the same position_ids format, so it should work

fix: flatten multi-component position_ids to 1D for nested tensor compatibility

Models with multi-dimensional RoPE (Qwen3.5, Qwen2-VL) produce position_ids
of shape (num_components, seq_len). When passed to torch.nested.as_nested_tensor
with jagged layout, dim 0 is treated as the ragged dimension, causing
num_components (e.g. 4) to be ragged instead of seq_len.

Additionally, 3D jagged nested tensors have broken unbind() and to_padded_tensor()
in PyTorch (pytorch/pytorch#153238).

Fix: flatten to 1D (seq_len * num_components) before creating the nested tensor,
store num_pos_components as metadata, and reshape back in prepare_model_inputs.

Fixes the FSDP engine_workers path (use_legacy_worker_impl=disable) for
multi-component RoPE models. Related: verl-project#5772 (same bug class on Megatron path).

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for multi-component position_ids by flattening them into 1D tensors to circumvent PyTorch nested tensor limitations, while adding logic to reconstruct the original shapes during model input preparation. Review feedback suggests optimizing the flattening process in padding.py for better performance and using inferred dimensions in view calls within transformer_impl.py for more idiomatic code.

# 3D jagged nested tensors have broken unbind() and to_padded_tensor() in PyTorch
# (see pytorch/pytorch#153238), so we flatten to 1D and reshape back in prepare_model_inputs
num_pos_components = curr_pos_ids.shape[0]
valid_ids = curr_pos_ids[:, curr_mask].T.contiguous().flatten() # (valid_len * num_components,)

high

The current implementation performs multiple intermediate operations and an explicit contiguous copy. Using curr_pos_ids.T[curr_mask].flatten() is more efficient and idiomatic. Since boolean indexing with a mask already creates a new contiguous tensor, this approach avoids the extra overhead of contiguous() and simplifies the operation in this hot path.

Suggested change
valid_ids = curr_pos_ids[:, curr_mask].T.contiguous().flatten() # (valid_len * num_components,)
valid_ids = curr_pos_ids.T[curr_mask].flatten() # (valid_len * num_components,)
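A quick check (illustrative shapes, not verl's actual variables) that the suggested form produces the same result as the original:

```python
import torch

num_components, seq_len = 4, 8
curr_pos_ids = torch.arange(num_components * seq_len).reshape(num_components, seq_len)
curr_mask = torch.tensor([True, True, False, True, False, True, True, False])

# Original: mask columns, transpose, force contiguity, flatten
original = curr_pos_ids[:, curr_mask].T.contiguous().flatten()

# Suggested: transpose first, then mask rows; boolean indexing
# already yields a contiguous tensor, so no explicit copy is needed
suggested = curr_pos_ids.T[curr_mask].flatten()

assert torch.equal(original, suggested)
```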

Comment on lines +912 to +913
total_nnz = flat_pos.shape[0] // num_pos_components
position_ids_rmpad = flat_pos.view(total_nnz, num_pos_components).T.unsqueeze(1)

high

The explicit calculation of total_nnz is redundant here. Using -1 in the view method is more idiomatic and robust, as it allows PyTorch to automatically infer the dimension size while ensuring the total number of elements is compatible with num_pos_components. This also makes the code slightly cleaner by removing an unnecessary intermediate variable.

Suggested change
total_nnz = flat_pos.shape[0] // num_pos_components
position_ids_rmpad = flat_pos.view(total_nnz, num_pos_components).T.unsqueeze(1)
position_ids_rmpad = flat_pos.view(-1, num_pos_components).T.unsqueeze(1)
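A small sanity check (illustrative values) that inferring the dimension with -1 is equivalent to computing total_nnz explicitly:

```python
import torch

num_pos_components = 4
flat_pos = torch.arange(24)  # implies total_nnz = 6

# Explicit intermediate variable vs. letting view infer the dimension
explicit = flat_pos.view(flat_pos.shape[0] // num_pos_components,
                         num_pos_components).T.unsqueeze(1)
inferred = flat_pos.view(-1, num_pos_components).T.unsqueeze(1)

assert torch.equal(explicit, inferred)
```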

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.
