fix: flatten multi-component position_ids to 1D for nested tensor compatibility#5886
yifannnwu wants to merge 1 commit into verl-project:main from
Conversation
Models with multi-dimensional RoPE (Qwen3.5, Qwen2-VL) produce position_ids of shape (num_components, seq_len). When passed to torch.nested.as_nested_tensor with jagged layout, dim 0 is treated as the ragged dimension, causing num_components (e.g. 4) to be ragged instead of seq_len. Additionally, 3D jagged nested tensors have broken unbind() and to_padded_tensor() in PyTorch (pytorch/pytorch#153238).

Fix: flatten to 1D (seq_len * num_components) before creating the nested tensor, store num_pos_components as metadata, and reshape back in prepare_model_inputs. Fixes the FSDP engine_workers path (use_legacy_worker_impl=disable) for multi-component RoPE models.

Related: verl-project#5772 (same bug class on the Megatron path).
Code Review
This pull request adds support for multi-component position_ids by flattening them to 1D to work around PyTorch nested-tensor limitations, and adds logic to reconstruct the original shapes during model input preparation. Review feedback suggests optimizing the flattening in `padding.py` and using inferred dimensions in `view` calls in `transformer_impl.py` for more idiomatic code.
```python
# 3D jagged nested tensors have broken unbind() and to_padded_tensor() in PyTorch
# (see pytorch/pytorch#153238), so we flatten to 1D and reshape back in prepare_model_inputs
num_pos_components = curr_pos_ids.shape[0]
valid_ids = curr_pos_ids[:, curr_mask].T.contiguous().flatten()  # (valid_len * num_components,)
```
The current implementation performs multiple intermediate operations and an explicit contiguous copy. Using `curr_pos_ids.T[curr_mask].flatten()` is more efficient and idiomatic. Since boolean indexing with a mask already creates a new contiguous tensor, this approach avoids the extra overhead of `contiguous()` and simplifies the operation in this hot path.
```diff
- valid_ids = curr_pos_ids[:, curr_mask].T.contiguous().flatten()  # (valid_len * num_components,)
+ valid_ids = curr_pos_ids.T[curr_mask].flatten()  # (valid_len * num_components,)
```
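A quick sketch (with toy shapes, not the actual verl tensors) showing the two expressions are equivalent and that boolean mask indexing already returns a contiguous tensor:

```python
import torch

# Toy stand-ins: curr_pos_ids is (num_components, seq_len),
# curr_mask marks the valid (non-padded) positions.
num_components, seq_len = 4, 6
curr_pos_ids = torch.arange(num_components * seq_len).view(num_components, seq_len)
curr_mask = torch.tensor([True, True, False, True, False, True])

original = curr_pos_ids[:, curr_mask].T.contiguous().flatten()
suggested = curr_pos_ids.T[curr_mask].flatten()

assert torch.equal(original, suggested)
# Boolean indexing already yields a fresh contiguous tensor,
# so the explicit .contiguous() is redundant:
assert curr_pos_ids.T[curr_mask].is_contiguous()
```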
```python
total_nnz = flat_pos.shape[0] // num_pos_components
position_ids_rmpad = flat_pos.view(total_nnz, num_pos_components).T.unsqueeze(1)
```
The explicit calculation of `total_nnz` is redundant here. Using `-1` in the `view` call is more idiomatic and robust, as it lets PyTorch infer the dimension size while still checking that the total number of elements is divisible by `num_pos_components`. This also makes the code slightly cleaner by removing an unnecessary intermediate variable.
```diff
- total_nnz = flat_pos.shape[0] // num_pos_components
- position_ids_rmpad = flat_pos.view(total_nnz, num_pos_components).T.unsqueeze(1)
+ position_ids_rmpad = flat_pos.view(-1, num_pos_components).T.unsqueeze(1)
```
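A minimal check (toy shapes, not the actual verl code) that the inferred-dimension form produces the same result:

```python
import torch

num_pos_components = 4
flat_pos = torch.arange(24)  # total_nnz = 6 positions × 4 components

# Explicit size vs. letting view() infer it with -1:
explicit = flat_pos.view(flat_pos.shape[0] // num_pos_components, num_pos_components).T.unsqueeze(1)
inferred = flat_pos.view(-1, num_pos_components).T.unsqueeze(1)

assert torch.equal(explicit, inferred)
assert tuple(inferred.shape) == (num_pos_components, 1, 6)
```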
Summary
Fixes a crash when training models with multi-dimensional RoPE position_ids (e.g. Qwen3.5's 4-component (text, height, width, temporal), Qwen2-VL) on the FSDP `engine_workers.py` path (`use_legacy_worker_impl=disable`).

Related: #5772 (same class of bug on the Megatron/mbridge path)
Problem
`left_right_2_no_padding()` in `padding.py` creates per-sample position_ids of shape `(num_components, valid_len)` and passes them directly to `torch.nested.as_nested_tensor(..., layout=torch.jagged)`.

Jagged layout treats dim 0 as the ragged dimension, so it interprets `num_components` (e.g. 4) as ragged instead of `seq_len`. The resulting nested tensor has shape `[batch, *(ragged=4), seq_len]` instead of the intended `[batch, *(ragged=seq_len), num_components]`.

Downstream in `transformer_impl.py`, `to_padded_tensor()` produces incorrect shapes, causing a crash in `apply_rotary_pos_emb`.

Transposing to `(valid_len, num_components)` before nesting doesn't work either, because 3D jagged nested tensors have broken `unbind()` and `to_padded_tensor()` in PyTorch (see pytorch/pytorch#153238).

Fix
Flatten multi-component position_ids to 1D `(seq_len * num_components,)` before creating the nested tensor, keeping it purely 2D jagged. Store `num_pos_components` as non-tensor metadata. Reshape back to `(num_components, batch, seq_len)` in `prepare_model_inputs`.

Changes
- `verl/workers/utils/padding.py` — `left_right_2_no_padding()`:
  - `(num_components, valid_len)` → `(valid_len * num_components,)` via `.T.contiguous().flatten()`
  - `num_pos_components` stored in non-tensor metadata via `tu.assign_non_tensor_data`
- `verl/workers/engine/fsdp/transformer_impl.py` — `prepare_model_inputs()`:
  - `.view(total_nnz, num_components).T.unsqueeze(1)` → `(num_components, 1, total_nnz)`
  - `to_padded_tensor(..., output_size=(batch, max_seq_len * num_components)).view(batch, max_seq_len, num_components).permute(2, 0, 1)` → `(num_components, batch, max_seq_len)`

Affected models
Any model with multi-dimensional RoPE on the FSDP `engine_workers.py` path (e.g. Qwen3.5, Qwen2-VL).

Test plan
With `use_legacy_worker_impl=disable`: 3 GRPO training steps plus a validation pass complete successfully. Previously this crashed with a RoPE shape mismatch.