Commit 8018f4f

hiyouga authored and masoudhashemi committed
[model] fix: refactor qwen2vl patches & support no-image input for fsdp (verl-project#3496)
### What does this PR do?

This PR tries to fix verl-project#3491

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

Tested with [latest transformers](https://github.com/huggingface/transformers/tree/6e50a8afb2540ac1acaa4b62cf1dd5f1170f6511)

<img width="2448" height="540" alt="image" src="https://github.com/user-attachments/assets/06d40f40-572c-4454-8e08-115857f61f21" />
<img width="2796" height="1394" alt="image" src="https://github.com/user-attachments/assets/17489b9c-e376-46e3-80d8-71106d304077" />
<img width="2098" height="744" alt="image" src="https://github.com/user-attachments/assets/8c7f736d-bf09-4ba9-9cf4-0d56e367c526" />

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

#### ⚠️ Breaking

We adopt a new format for Qwen2VL's position ids: (4, batch size, seq len).

Assuming the vision position ids (mrope) have a shape of (3, batch size, seq len) and the text position ids (normal rope) have a shape of (1, batch size, seq len), we concatenate both to obtain the final position ids (a toy sketch is included at the end of this description). This aligns with the implementation in Transformers >= 4.54.0 🤗
https://github.com/huggingface/transformers/blob/v4.54.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1469

#### 🎤 New

We have refactored the Qwen2VL and Qwen2.5VL patches, supporting no-image input for FSDP by introducing fake ViT inputs. We have also removed some redundant code for better maintainability.

#### 🚨 Changes

We move the Ulysses logic into the attention function, so the position ids are scattered before the language model part.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
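As a reference for the breaking change above, here is a minimal sketch of the new position-id layout. This is illustrative only: the dummy tensors and the row ordering (text rope first, following the transformers >= 4.54 convention linked in the description) are assumptions for the example, not code from this PR.

```python
import torch

# Hypothetical example, not code from this PR: build the (4, batch_size, seq_len)
# position ids by concatenating text (normal rope) and vision (mrope) position ids.
batch_size, seq_len = 2, 16

# (1, batch_size, seq_len): normal rope positions for the text stream
text_position_ids = torch.arange(seq_len).view(1, 1, seq_len).expand(1, batch_size, seq_len)

# (3, batch_size, seq_len): mrope positions (temporal, height, width); dummy zeros here
mrope_position_ids = torch.zeros(3, batch_size, seq_len, dtype=torch.long)

# Concatenate along dim 0 to obtain the (4, batch_size, seq_len) layout
position_ids = torch.cat([text_position_ids, mrope_position_ids], dim=0)
assert position_ids.shape == (4, batch_size, seq_len)

# Because the sequence dimension is now last, Ulysses input slicing cuts
# position_ids along dim=-1 while inputs_embeds is cut along dim=1
# (see patch_vlm_for_ulysses_input_slicing in the monkey_patch.py diff below).
```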
1 parent 85eaefa commit 8018f4f

11 files changed: 334 additions & 792 deletions

examples/grpo_trainer/run_qwen2_5_vl-7b.sh

Lines changed: 1 addition & 0 deletions
```diff
@@ -14,6 +14,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-7B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.use_fused_kernels=True \
     actor_rollout_ref.actor.ppo_mini_batch_size=128 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=10 \
     actor_rollout_ref.actor.use_kl_loss=True \
```

verl/models/transformers/monkey_patch.py

Lines changed: 59 additions & 124 deletions
```diff
@@ -15,17 +15,15 @@
 Apply monkey-patch function to models
 """
 
-import importlib.metadata
 import sys
-from functools import lru_cache
 from typing import Optional
 
 import torch
-from packaging import version
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_utils import PreTrainedModel
 
 from verl.utils.import_utils import is_trl_available
+from verl.utils.transformers_compat import is_transformers_version_in_range
 from verl.utils.ulysses import (
     gather_heads_scatter_seq,
     gather_seq_scatter_heads,
@@ -51,13 +49,19 @@ def _ulysses_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    query_length: int,
     *args,
     position_ids: Optional[torch.Tensor] = None,
     **kwargs,
 ):
     """Insert all-to-all before and after flash attention.
     DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509
 
+    For transformers>=4.55, the flash attention api has changed,
+    we need to pass the query_length after doing ulysses all2all.
+    See https://github.com/huggingface/transformers/issues/40399
+
     Args:
         query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)
         key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)
@@ -66,64 +70,7 @@ def _ulysses_flash_attention_forward(
 
     Returns:
         torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)
-    """
-    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
-
-    ########## AlltoAll for Ulysses ##########
-    if ulysses_sp_size > 1:
-        assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
-
-        # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k,
-        # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.
-        # For example:
-        # - nheads_k=4, sp=8, repeats=2
-        # - nheads_k=8, sp=8, repeats=1
-        # - nheads_k=16, sp=8, repeats=1
-        repeats = max(ulysses_sp_size // key_states.size(2), 1)
-        key_states = repeat_kv(key_states, repeats)
-        value_states = repeat_kv(value_states, repeats)
-
-        # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)
-        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
-        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
-        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
-
-        # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate
-        # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.
-        # https://github.com/huggingface/transformers/pull/33932
-
-        # (bsz, seq_len/n) -> (bsz, seq_len)
-        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
-        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
-        position_ids = torch.concat(position_ids_list, dim=-1)
-
-    # (bsz, seq_len, n_head/n, head_dim)
-    attn_output = _flash_attention_forward(
-        query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs
-    )
-
-    ########## AlltoAll for Ulysses ##########
-    if ulysses_sp_size > 1:
-        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
-        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
 
-    return attn_output
-
-
-def _ulysses_flash_attention_forward_transformers_4_55(
-    query_states: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    query_length: int,
-    *args,
-    position_ids: Optional[torch.Tensor] = None,
-    **kwargs,
-):
-    """For transformers>=4.55, the flash attention api has changed,
-    we need to pass the query_length after doing ulysses alltoall.
-
-    See https://github.com/huggingface/transformers/issues/40399
     """
     ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
 
@@ -178,6 +125,7 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
     def _create_ulysses_wrapped_decoder_forward(original_forward):
         def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
             inputs_embeds = kwargs.get("inputs_embeds")
+            position_ids = kwargs.get("position_ids")
             call_kwargs = kwargs.copy()
 
             current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -189,6 +137,7 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
             )
             if slice_now:
                 call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
+                call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
                 self._needs_initial_slice = False
             try:
                 return original_forward(self, *args, **call_kwargs)
@@ -225,12 +174,7 @@ def patch_forward_with_backends(
 
     forward_with_torch_backend_function = model.__class__.forward
     forward_with_triton_backend_function = model.__class__.forward
-    if model.config.model_type == "qwen2_5_vl":
-        from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend
-
-        forward_with_torch_backend_function = forward_with_torch_backend
-        forward_with_triton_backend_function = forward_with_triton_backend
-    elif model.config.model_type == "qwen2_vl":
+    if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]:
         from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend
 
         forward_with_torch_backend_function = forward_with_torch_backend
@@ -296,50 +240,70 @@ def state_dict(self, *args, **kwargs):
 
     # TODO: VLM models only, unify monkey patch to LLM models.
     if model.config.model_type == "qwen2_5_vl":
-        if is_transformers_version_in_range(min_version="4.53.0"):
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
+        if is_transformers_version_in_range(min_version="4.52.0"):
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VLAttention,
+                Qwen2_5_VLForConditionalGeneration,
+                Qwen2_5_VLModel,
+                Qwen2_5_VLTextModel,
+            )
+
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
+
+            Qwen2_5_VLModel.forward = qwen2_vl_base_forward
+            Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
         else:
             from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
                 Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
            )
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VLForConditionalGeneration,
+            )
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
+
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend
+
+            Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
 
         if use_remove_padding or ulysses_sp_size > 1:
-            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
+            from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
 
-            Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward
-            print("Monkey patch FlashAttention2.forward in Qwen2.5VL")
+            Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
+            print("Monkey patch Qwen2.5VL attention layer")
 
         if ulysses_sp_size > 1:
-            if is_transformers_version_in_range(min_version="4.52.0"):
-                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
+            patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
 
-                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
-            else:
-                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    elif model.config.model_type == "qwen2_vl":
+        if is_transformers_version_in_range(min_version="4.52.0"):
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+                Qwen2VLAttention,
+                Qwen2VLForConditionalGeneration,
+                Qwen2VLModel,
+                Qwen2VLTextModel,
+            )
 
-                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
 
-    elif model.config.model_type == "qwen2_vl":
-        if is_transformers_version_in_range(min_version="4.53.0"):
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
+            Qwen2VLModel.forward = qwen2_vl_base_forward
+            Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
         else:
             from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel
 
-        if use_remove_padding or ulysses_sp_size > 1:
-            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend
 
-            Qwen2VLAttention.forward = ulysses_flash_attn_forward
-            print("Monkey patch FlashAttention2.forward in Qwen2VL")
+            Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
 
-        if ulysses_sp_size > 1:
-            if is_transformers_version_in_range(min_version="4.52.0"):
-                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
+        if use_remove_padding or ulysses_sp_size > 1:
+            from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
 
-                patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
-            else:
-                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+            Qwen2VLAttention.forward = qwen2_vl_attn_forward
+            print("Monkey patch Qwen2VL attention layer")
 
-                patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)
+        if ulysses_sp_size > 1:
+            patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
 
     elif model.config.model_type == "kimi_vl":
         if use_remove_padding or ulysses_sp_size > 1:
@@ -357,43 +321,14 @@ def state_dict(self, *args, **kwargs):
 
         return
 
-    # transformers<=4.47.1
     if use_remove_padding or ulysses_sp_size > 1:
-        if hasattr(module, "_flash_attention_forward"):
+        if hasattr(module, "_flash_attention_forward"):  # transformers <= 4.47.1 or legacy models
             module._flash_attention_forward = _ulysses_flash_attention_forward
             print(f"Monkey patch _flash_attention_forward in {model.__module__}")
         else:
-            if is_transformers_version_in_range(min_version="4.55.0"):
-                from transformers.integrations import flash_attention
-
-                flash_attention._flash_attention_forward = _ulysses_flash_attention_forward_transformers_4_55
-                print(f"Monkey patch _flash_attention_forward in {model.__module__} for new api")
-            else:
-                # 4.48.0 <= transformers <= 4.54.1, Vision attention
-                from transformers.integrations import flash_attention
+            from transformers.integrations import flash_attention
 
-                flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
-                print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
+            flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
+            print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
 
     patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend)
-
-
-@lru_cache
-def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
-    try:
-        # Get the installed version of the transformers library
-        transformers_version_str = importlib.metadata.version("transformers")
-    except importlib.metadata.PackageNotFoundError as e:
-        raise ModuleNotFoundError("The `transformers` package is not installed.") from e
-
-    transformers_version = version.parse(transformers_version_str)
-
-    lower_bound_check = True
-    if min_version is not None:
-        lower_bound_check = version.parse(min_version) <= transformers_version
-
-    upper_bound_check = True
-    if max_version is not None:
-        upper_bound_check = transformers_version <= version.parse(max_version)
-
-    return lower_bound_check and upper_bound_check
```
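One detail worth calling out from the deleted helper above is the KV-head repeat rule used before the Ulysses all-to-all. Below is a minimal, self-contained sketch of that arithmetic; it is illustrative only, and the helper name `kv_repeats` is made up for this example.

```python
# Illustrative sketch of the head-repeat rule documented in the removed
# comments above: before the Ulysses all-to-all, KV heads are repeated just
# enough to be divisible by the sequence-parallel size. Because flash
# attention supports MQA/GQA, repeating sp_size // nheads_k times suffices
# instead of matching the full query head count (nheads_q // nheads_k).
def kv_repeats(nheads_k: int, sp_size: int) -> int:
    return max(sp_size // nheads_k, 1)

# The examples from the original comment block:
assert kv_repeats(nheads_k=4, sp_size=8) == 2   # repeat twice to reach sp_size
assert kv_repeats(nheads_k=8, sp_size=8) == 1   # already divisible
assert kv_repeats(nheads_k=16, sp_size=8) == 1  # more KV heads than ranks
```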
