26 changes: 26 additions & 0 deletions tests/test_protocol_v2_on_cpu.py
@@ -837,6 +837,32 @@ def test_chunk_tensordict():
assert torch.all(torch.eq(tensor.data["pixel_values"], expect["pixel_values"])).item()


@pytest.mark.parametrize("rope_dim", [3, 4])
def test_maybe_fix_3d_position_ids_broken_equal_length_layout(rope_dim: int):
batch_num = 2
seq_len = 7
samples = [
torch.arange(i * rope_dim * seq_len, (i + 1) * rope_dim * seq_len).view(rope_dim, seq_len)
for i in range(batch_num)
]
broken_position_ids = torch.nested.as_nested_tensor(samples, layout=torch.jagged)
td = tu.get_tensordict({"position_ids": broken_position_ids})

tu.maybe_fix_3d_position_ids(td)

fixed_position_ids = td["position_ids"]
expected_offsets = torch.arange(
0,
(batch_num + 1) * seq_len,
seq_len,
dtype=fixed_position_ids.offsets().dtype,
device=fixed_position_ids.offsets().device,
)
torch.testing.assert_close(fixed_position_ids.offsets(), expected_offsets)
for idx, sample in enumerate(samples):
torch.testing.assert_close(fixed_position_ids[idx], sample)


def test_assign_non_tensor_stack_with_nested_lists():
"""Test assign_non_tensor_stack with lists of lists."""
td = tu.get_tensordict({"obs": torch.randn(3, 4)}, non_tensor_dict={})
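A short illustration of why the equal-length layout in the new test counts as "broken": with layout=torch.jagged, the first dimension of each sample (rope_dim) becomes the ragged dimension, so the input's offsets advance by rope_dim per sample, whereas the fix is expected to produce offsets that advance by seq_len. This is a sketch reusing the test's names, with the printed values assuming rope_dim=3.

import torch

rope_dim, batch_num, seq_len = 3, 2, 7
samples = [
    torch.arange(i * rope_dim * seq_len, (i + 1) * rope_dim * seq_len).view(rope_dim, seq_len)
    for i in range(batch_num)
]

# With layout=torch.jagged the ragged dimension is dim 1, i.e. rope_dim of each sample,
# so the offsets advance by rope_dim per sample ...
broken = torch.nested.as_nested_tensor(samples, layout=torch.jagged)
print(broken.offsets())  # tensor([0, 3, 6])

# ... while the test expects the fixed tensor's offsets to advance by seq_len:
print(torch.arange(0, (batch_num + 1) * seq_len, seq_len))  # tensor([ 0,  7, 14])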
14 changes: 14 additions & 0 deletions verl/models/transformers/common/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

114 changes: 114 additions & 0 deletions verl/models/transformers/common/vision_pos_embed_utils.py
@@ -0,0 +1,114 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple

import torch


class BilinearInterpolationTensors(NamedTuple):
device: torch.device
idx_tensor: torch.Tensor
weight_tensor: torch.Tensor
grid_ts: torch.Tensor
grid_hs: torch.Tensor
grid_ws: torch.Tensor


def build_bilinear_interpolation_tensors(
grid_thw: torch.Tensor,
num_grid_per_side: int,
weight_dtype: torch.dtype,
) -> BilinearInterpolationTensors:
device = grid_thw.device
grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]

for h, w in zip(grid_hs, grid_ws, strict=False):
h_size = int(h.item())
w_size = int(w.item())
h_idxs = torch.linspace(0, num_grid_per_side - 1, h_size, device=device)
w_idxs = torch.linspace(0, num_grid_per_side - 1, w_size, device=device)

h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs_floor + 1).clip(max=num_grid_per_side - 1)
w_idxs_ceil = (w_idxs_floor + 1).clip(max=num_grid_per_side - 1)

dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor

base_h = h_idxs_floor * num_grid_per_side
base_h_ceil = h_idxs_ceil * num_grid_per_side

indices = [
(base_h[:, None] + w_idxs_floor[None, :]).flatten(),
(base_h[:, None] + w_idxs_ceil[None, :]).flatten(),
(base_h_ceil[:, None] + w_idxs_floor[None, :]).flatten(),
(base_h_ceil[:, None] + w_idxs_ceil[None, :]).flatten(),
]

weights = [
((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(),
((1 - dh)[:, None] * dw[None, :]).flatten(),
(dh[:, None] * (1 - dw)[None, :]).flatten(),
(dh[:, None] * dw[None, :]).flatten(),
]

for i in range(4):
idx_list[i].extend(indices[i].tolist())
weight_list[i].extend(weights[i].tolist())

idx_tensor = torch.as_tensor(idx_list, dtype=torch.long, device=device)
weight_tensor = torch.as_tensor(weight_list, dtype=weight_dtype, device=device)
return BilinearInterpolationTensors(
device=device,
idx_tensor=idx_tensor,
weight_tensor=weight_tensor,
grid_ts=grid_ts,
grid_hs=grid_hs,
grid_ws=grid_ws,
)


def merge_bilinear_interpolated_pos_embeds(
pos_embeds: torch.Tensor,
weight_tensor: torch.Tensor,
grid_ts: torch.Tensor,
grid_hs: torch.Tensor,
grid_ws: torch.Tensor,
merge_size: int,
) -> torch.Tensor:
pos_embeds = pos_embeds * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

split_sizes = [int(h.item()) * int(w.item()) for h, w in zip(grid_hs, grid_ws, strict=False)]
patch_pos_embeds = patch_pos_embeds.split(split_sizes)

patch_pos_embeds_permute = []
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws, strict=False):
t_size = int(t.item())
h_size = int(h.item())
w_size = int(w.item())
pos_embed = pos_embed.repeat(t_size, 1)
pos_embed = (
pos_embed.view(t_size, h_size // merge_size, merge_size, w_size // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)

return torch.cat(patch_pos_embeds_permute)
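A minimal usage sketch of the new helper, assuming only the module path added in this file: it builds a single-image grid and checks the invariant that the four bilinear weights computed for every target patch sum to one.

import torch

from verl.models.transformers.common.vision_pos_embed_utils import build_bilinear_interpolation_tensors

# One image: 1 temporal frame, a 6x8 patch grid, interpolated from a 4x4 embedding grid.
grid_thw = torch.tensor([[1, 6, 8]])
tensors = build_bilinear_interpolation_tensors(
    grid_thw=grid_thw, num_grid_per_side=4, weight_dtype=torch.float32
)

# Four neighbour indices and weights per target patch, laid out as (4, total_patches).
assert tensors.idx_tensor.shape == (4, 6 * 8)
assert int(tensors.idx_tensor.max()) < 4 * 4
# The bilinear weights (1-dh)(1-dw), (1-dh)dw, dh(1-dw), dh*dw always sum to one.
torch.testing.assert_close(tensors.weight_tensor.sum(dim=0), torch.ones(6 * 8))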
9 changes: 9 additions & 0 deletions verl/models/transformers/monkey_patch.py
@@ -424,6 +424,7 @@ def state_dict(self, *args, **kwargs):
from verl.models.transformers.qwen3_vl import (
forward_with_normal_backend,
patch_qwen3_vl_moe_sparse_moe_block_forward,
patch_qwen3_vl_vision_fast_pos_embed_interpolate,
qwen3_vl_base_forward,
)

@@ -436,6 +437,7 @@ def state_dict(self, *args, **kwargs):
# Step 1.5: patch Qwen3VLMoeTextSparseMoeBlock to fix transformers 4.57.3 bug
if model.config.model_type == "qwen3_vl_moe" and is_transformers_version_in_range(max_version="4.57.3"):
patch_qwen3_vl_moe_sparse_moe_block_forward()
patch_qwen3_vl_vision_fast_pos_embed_interpolate()

# Step 2: patch input for multimodal sequence parallelism
if ulysses_sp_size > 1:
@@ -489,11 +491,13 @@ def state_dict(self, *args, **kwargs):
from transformers.models.qwen3_5.modeling_qwen3_5 import (
Qwen3_5ForConditionalGeneration,
Qwen3_5Model,
Qwen3_5TextModel,
Qwen3_5VisionModel,
)
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeForConditionalGeneration,
Qwen3_5MoeModel,
Qwen3_5MoeTextModel,
Qwen3_5MoeVisionModel,
)

@@ -513,6 +517,11 @@ def state_dict(self, *args, **kwargs):
Qwen3_5VisionModel.fast_pos_embed_interpolate = fast_pos_embed_interpolate
Qwen3_5MoeVisionModel.fast_pos_embed_interpolate = fast_pos_embed_interpolate

# Step 3: patch input for multimodal sequence parallelism
if ulysses_sp_size > 1:
patch_vlm_for_ulysses_input_slicing(Qwen3_5TextModel)
patch_vlm_for_ulysses_input_slicing(Qwen3_5MoeTextModel)

if use_remove_padding or ulysses_sp_size > 1:
if hasattr(module, "_flash_attention_forward"): # transformers <= 4.47.1 or legacy models
module._flash_attention_forward = _ulysses_flash_attention_forward
4 changes: 2 additions & 2 deletions verl/models/transformers/qwen2_vl.py
@@ -501,7 +501,7 @@ def forward_with_torch_backend(

# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
rolled_labels = labels
elif input_ids is not None:
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
else:
@@ -535,7 +535,7 @@ def forward_with_triton_backend(

# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
rolled_labels = labels
elif input_ids is not None:
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
else:
85 changes: 21 additions & 64 deletions verl/models/transformers/qwen3_5.py
@@ -23,73 +23,30 @@
Qwen3_5ForConditionalGeneration,
)

from .common.vision_pos_embed_utils import (
build_bilinear_interpolation_tensors,
merge_bilinear_interpolated_pos_embeds,
)

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


def fast_pos_embed_interpolate(self, grid_thw):
grid_thw_list = grid_thw.tolist()
grid_ts = [row[0] for row in grid_thw_list]
grid_hs = [row[1] for row in grid_thw_list]
grid_ws = [row[2] for row in grid_thw_list]
# Modification: get the device from grid_thw to avoid self.pos_embed being on CPU when FSDP2 enables cpu_offload
device = grid_thw.device

idx_list = [[] for _ in range(4)]
weight_list = [[] for _ in range(4)]

for t, h, w in grid_thw_list:
h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

h_idxs_floor = h_idxs.int()
w_idxs_floor = w_idxs.int()
h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

dh = h_idxs - h_idxs_floor
dw = w_idxs - w_idxs_floor

base_h = h_idxs_floor * self.num_grid_per_side
base_h_ceil = h_idxs_ceil * self.num_grid_per_side

indices = [
(base_h[None].T + w_idxs_floor[None]).flatten(),
(base_h[None].T + w_idxs_ceil[None]).flatten(),
(base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
(base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
]

weights = [
((1 - dh)[None].T * (1 - dw)[None]).flatten(),
((1 - dh)[None].T * dw[None]).flatten(),
(dh[None].T * (1 - dw)[None]).flatten(),
(dh[None].T * dw[None]).flatten(),
]

for i in range(4):
idx_list[i].extend(indices[i].tolist())
weight_list[i].extend(weights[i].tolist())

idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
weight_tensor = torch.tensor(weight_list, dtype=self.pos_embed.weight.dtype, device=device)
pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None]
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws, strict=False)])

patch_pos_embeds_permute = []
merge_size = self.config.spatial_merge_size
for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws, strict=False):
pos_embed = pos_embed.repeat(t, 1)
pos_embed = (
pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
.permute(0, 1, 3, 2, 4, 5)
.flatten(0, 4)
)
patch_pos_embeds_permute.append(pos_embed)
patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
return patch_pos_embeds
interpolation_tensors = build_bilinear_interpolation_tensors(
grid_thw=grid_thw,
num_grid_per_side=self.num_grid_per_side,
weight_dtype=self.pos_embed.weight.dtype,
)
pos_embeds = self.pos_embed(interpolation_tensors.idx_tensor).to(interpolation_tensors.device)
return merge_bilinear_interpolated_pos_embeds(
pos_embeds=pos_embeds,
weight_tensor=interpolation_tensors.weight_tensor,
grid_ts=interpolation_tensors.grid_ts,
grid_hs=interpolation_tensors.grid_hs,
grid_ws=interpolation_tensors.grid_ws,
merge_size=self.config.spatial_merge_size,
)


def _get_input_embeds(
@@ -208,7 +165,7 @@ def forward_with_torch_backend(

# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
rolled_labels = labels
elif input_ids is not None:
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
else:
@@ -242,7 +199,7 @@ def forward_with_triton_backend(

# Loss calculations
if labels is not None:
rolled_labels = torch.roll(labels, shifts=-1, dims=-1)
rolled_labels = labels
elif input_ids is not None:
rolled_labels = torch.roll(input_ids, shifts=-1, dims=-1)
else:
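For a single target patch, the arithmetic shared by the removed inline version and the new helpers reduces to a standard bilinear mix of the four neighbouring grid embeddings. A standalone sketch with made-up sizes and offsets (grid_embed, dh, dw, y0, x0 are illustrative stand-ins, not names from the PR):

import torch

num_grid_per_side, dim = 4, 8
grid_embed = torch.randn(num_grid_per_side * num_grid_per_side, dim)  # stands in for the pos_embed table

y0, x0 = 1, 2            # floor indices of the target patch on the embedding grid
dh, dw = 0.5, 0.25       # fractional offsets toward the next row / column
y1 = min(y0 + 1, num_grid_per_side - 1)  # ceil indices, clipped at the grid edge
x1 = min(x0 + 1, num_grid_per_side - 1)

idx = torch.tensor([
    y0 * num_grid_per_side + x0,
    y0 * num_grid_per_side + x1,
    y1 * num_grid_per_side + x0,
    y1 * num_grid_per_side + x1,
])
weight = torch.tensor([(1 - dh) * (1 - dw), (1 - dh) * dw, dh * (1 - dw), dh * dw])
interpolated = (grid_embed[idx] * weight[:, None]).sum(dim=0)  # one (dim,) position embedding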