Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions verl/workers/engine/fsdp/transformer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,8 +904,13 @@ def prepare_model_inputs(self, micro_batch: TensorDict):

if pad_mode == DatasetPadMode.NO_PADDING:
input_ids_rmpad = input_ids.values().unsqueeze(0) # (1, total_nnz)
if position_ids.dim() == 3:
position_ids_rmpad = position_ids.values().unsqueeze(1) # (4, 1, total_nnz)
num_pos_components = tu.get_non_tensor_data(data=micro_batch, key="num_pos_components", default=0)
if num_pos_components > 0:
# position_ids stored as flattened 1D nested tensor: (total_nnz * num_components,)
# reshape to (num_components, 1, total_nnz)
flat_pos = position_ids.values() # (total_nnz * num_components,)
total_nnz = flat_pos.shape[0] // num_pos_components
position_ids_rmpad = flat_pos.view(total_nnz, num_pos_components).T.unsqueeze(1)
Comment on lines +912 to +913
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The explicit calculation of total_nnz is redundant here. Using -1 in the view method is more idiomatic and robust, as it allows PyTorch to automatically infer the dimension size while ensuring the total number of elements is compatible with num_pos_components. This also makes the code slightly cleaner by removing an unnecessary intermediate variable.

Suggested change
total_nnz = flat_pos.shape[0] // num_pos_components
position_ids_rmpad = flat_pos.view(total_nnz, num_pos_components).T.unsqueeze(1)
position_ids_rmpad = flat_pos.view(-1, num_pos_components).T.unsqueeze(1)

else:
position_ids_rmpad = position_ids.values().unsqueeze(0) # (1, total_nnz)
else:
Expand Down Expand Up @@ -976,10 +981,13 @@ def prepare_model_inputs(self, micro_batch: TensorDict):
input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)
)

if position_ids.dim() == 3:
num_pos_components = tu.get_non_tensor_data(data=micro_batch, key="num_pos_components", default=0)
if num_pos_components > 0:
# position_ids stored as flattened 1D nested: each sample has (seq_len * num_components,)
# pad to (batch, max_seq_len * num_components), then reshape to (num_components, batch, max_seq_len)
position_ids = torch.nested.to_padded_tensor(
position_ids, padding=0, output_size=(batch_size, 4, max_seq_len)
).transpose(0, 1) # (4, batch_size, max_seq_len)
position_ids, padding=0, output_size=(batch_size, max_seq_len * num_pos_components)
).view(batch_size, max_seq_len, num_pos_components).permute(2, 0, 1)
else:
position_ids = torch.nested.to_padded_tensor(
position_ids, padding=0, output_size=(batch_size, max_seq_len)
Expand Down
10 changes: 8 additions & 2 deletions verl/workers/utils/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,21 @@ def left_right_2_no_padding(data: TensorDict) -> TensorDict:
input_ids_nested = torch.nested.nested_tensor_from_jagged(input_ids_rmpad.squeeze(-1), offsets=cu_seqlens)

position_ids_list = []
num_pos_components = 0 # 0 means 1D position_ids, >0 means multi-component (e.g. 4 for Qwen3.5/Qwen2-VL)
for i in range(attention_mask.shape[0]):
curr_mask = attention_mask[i].bool()
curr_pos_ids = position_ids[i]
if curr_pos_ids.dim() == 1: # (seq_len,)
valid_ids = curr_pos_ids[curr_mask]
else: # (4, seq_len)
valid_ids = curr_pos_ids[:, curr_mask]
else: # (num_components, seq_len) — flatten to 1D for nested tensor compatibility
# 3D jagged nested tensors have broken unbind() and to_padded_tensor() in PyTorch
# (see pytorch/pytorch#153238), so we flatten to 1D and reshape back in prepare_model_inputs
num_pos_components = curr_pos_ids.shape[0]
valid_ids = curr_pos_ids[:, curr_mask].T.contiguous().flatten() # (valid_len * num_components,)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation performs multiple intermediate operations and an explicit contiguous copy. Using curr_pos_ids.T[curr_mask].flatten() is more efficient and idiomatic. Since boolean indexing with a mask already creates a new contiguous tensor, this approach avoids the extra overhead of contiguous() and simplifies the operation in this hot path.

Suggested change
valid_ids = curr_pos_ids[:, curr_mask].T.contiguous().flatten() # (valid_len * num_components,)
valid_ids = curr_pos_ids.T[curr_mask].flatten() # (valid_len * num_components,)

position_ids_list.append(valid_ids)
position_ids_nested = torch.nested.as_nested_tensor(position_ids_list, layout=torch.jagged)
if num_pos_components > 0:
tu.assign_non_tensor_data(data, "num_pos_components", num_pos_components)

data["input_ids"] = input_ids_nested
data["position_ids"] = position_ids_nested
Expand Down