 Apply monkey-patch function to models
 """
 
-import importlib.metadata
 import sys
-from functools import lru_cache
 from typing import Optional
 
 import torch
-from packaging import version
 from transformers.modeling_flash_attention_utils import _flash_attention_forward
 from transformers.modeling_utils import PreTrainedModel
 
 from verl.utils.import_utils import is_trl_available
+from verl.utils.transformers_compat import is_transformers_version_in_range
 from verl.utils.ulysses import (
     gather_heads_scatter_seq,
     gather_seq_scatter_heads,
@@ -51,13 +49,19 @@ def _ulysses_flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
     value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    query_length: int,
     *args,
     position_ids: Optional[torch.Tensor] = None,
     **kwargs,
 ):
     """Insert all-to-all before and after flash attention.
     DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509

+    For transformers>=4.55 the flash attention API has changed, so we need to pass
+    the query_length computed after the Ulysses all-to-all.
+    See https://github.com/huggingface/transformers/issues/40399
+
     Args:
         query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)
         key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)
@@ -66,64 +70,7 @@ def _ulysses_flash_attention_forward(

     Returns:
         torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)
-    """
-    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
-
-    ########## AlltoAll for Ulysses ##########
-    if ulysses_sp_size > 1:
-        assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
-
-        # NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k,
-        # we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.
-        # For example:
-        # - nheads_k=4, sp=8, repeats=2
-        # - nheads_k=8, sp=8, repeats=1
-        # - nheads_k=16, sp=8, repeats=1
-        repeats = max(ulysses_sp_size // key_states.size(2), 1)
-        key_states = repeat_kv(key_states, repeats)
-        value_states = repeat_kv(value_states, repeats)
-
-        # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)
-        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
-        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
-        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
-
-        # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it, we can eliminate
-        # this all_gather by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.
-        # https://github.com/huggingface/transformers/pull/33932
-
-        # (bsz, seq_len/n) -> (bsz, seq_len)
-        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
-        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
-        position_ids = torch.concat(position_ids_list, dim=-1)
-
-    # (bsz, seq_len, n_head/n, head_dim)
-    attn_output = _flash_attention_forward(
-        query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs
-    )
-
-    ########## AlltoAll for Ulysses ##########
-    if ulysses_sp_size > 1:
-        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
-        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)

-    return attn_output
-
-
-def _ulysses_flash_attention_forward_transformers_4_55(
-    query_states: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    query_length: int,
-    *args,
-    position_ids: Optional[torch.Tensor] = None,
-    **kwargs,
-):
-    """For transformers>=4.55, the flash attention api has changed,
-    we need to pass the query_length after doing ulysses alltoall.
-
-    See https://github.com/huggingface/transformers/issues/40399
     """
     ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

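The shape bookkeeping described in the docstring above can be sanity-checked without any distributed setup. The sketch below is illustrative only: it fakes the all-to-all results with plain tensors (the real code uses gather_seq_scatter_heads / gather_heads_scatter_seq over the Ulysses process group), and the variable names are assumptions, not verl's API.

```python
import torch

# Hypothetical sizes for illustration.
sp_size, bsz, seqlen, nheads, head_dim = 4, 2, 32, 8, 16

# Before attention, each rank holds a sequence shard: (bsz, seqlen/sp, nheads, head_dim).
q_shard = torch.randn(bsz, seqlen // sp_size, nheads, head_dim)

# After the first all-to-all the rank holds the full sequence but only a head shard:
# (bsz, seqlen, nheads/sp, head_dim). Faked here with a fresh tensor for shape checking.
q_gathered = torch.randn(bsz, seqlen, nheads // sp_size, head_dim)

# For transformers>=4.55 the patched forward must report the post-all-to-all length,
# i.e. the gathered sequence length rather than the shard length it was called with.
query_length_local = q_shard.size(1)      # seqlen / sp_size
query_length_after = q_gathered.size(1)   # seqlen, what flash attention should see
assert query_length_after == query_length_local * sp_size
```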
@@ -178,6 +125,7 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
     def _create_ulysses_wrapped_decoder_forward(original_forward):
         def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
             inputs_embeds = kwargs.get("inputs_embeds")
+            position_ids = kwargs.get("position_ids")
             call_kwargs = kwargs.copy()

             current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -189,6 +137,7 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
             )
             if slice_now:
                 call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
+                call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
                 self._needs_initial_slice = False
             try:
                 return original_forward(self, *args, **call_kwargs)
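The wrapper above keeps inputs_embeds and position_ids consistent by handing each sequence-parallel rank only its own shard. A standalone sketch of that slicing idea follows; verl's slice_input_tensor additionally handles padding and reads the rank from the Ulysses process group, so the helper here is a simplified assumption for illustration.

```python
import torch

def slice_for_rank(x: torch.Tensor, dim: int, rank: int, sp_size: int) -> torch.Tensor:
    """Return the rank-th of sp_size equal chunks of x along dim (no padding handling)."""
    assert x.size(dim) % sp_size == 0, "sequence length must be divisible by sp_size"
    return x.chunk(sp_size, dim=dim)[rank]

inputs_embeds = torch.randn(2, 1024, 4096)       # (bsz, seqlen, hidden)
position_ids = torch.arange(1024).expand(2, -1)  # (bsz, seqlen)

local_embeds = slice_for_rank(inputs_embeds, dim=1, rank=0, sp_size=4)  # (2, 256, 4096)
local_pos = slice_for_rank(position_ids, dim=-1, rank=0, sp_size=4)     # (2, 256)
```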
@@ -225,12 +174,7 @@ def patch_forward_with_backends(

     forward_with_torch_backend_function = model.__class__.forward
     forward_with_triton_backend_function = model.__class__.forward
-    if model.config.model_type == "qwen2_5_vl":
-        from verl.models.transformers.qwen2_5_vl import forward_with_torch_backend, forward_with_triton_backend
-
-        forward_with_torch_backend_function = forward_with_torch_backend
-        forward_with_triton_backend_function = forward_with_triton_backend
-    elif model.config.model_type == "qwen2_vl":
+    if model.config.model_type in ["qwen2_5_vl", "qwen2_vl"]:
         from verl.models.transformers.qwen2_vl import forward_with_torch_backend, forward_with_triton_backend

         forward_with_torch_backend_function = forward_with_torch_backend
@@ -296,50 +240,70 @@ def state_dict(self, *args, **kwargs):

     # TODO: VLM models only, unify monkey patch to LLM models.
     if model.config.model_type == "qwen2_5_vl":
-        if is_transformers_version_in_range(min_version="4.53.0"):
-            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
+        if is_transformers_version_in_range(min_version="4.52.0"):
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VLAttention,
+                Qwen2_5_VLForConditionalGeneration,
+                Qwen2_5_VLModel,
+                Qwen2_5_VLTextModel,
+            )
+
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward
+
+            Qwen2_5_VLModel.forward = qwen2_vl_base_forward
+            Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend
         else:
             from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
                 Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
             )
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+                Qwen2_5_VLForConditionalGeneration,
+            )
+            from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel
+
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend
+
+            Qwen2_5_VLForConditionalGeneration.forward = forward_with_normal_backend

         if use_remove_padding or ulysses_sp_size > 1:
-            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
+            from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward

-            Qwen2_5_VLAttention.forward = ulysses_flash_attn_forward
-            print("Monkey patch FlashAttention2.forward in Qwen2.5VL")
+            Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward
+            print("Monkey patch Qwen2.5VL attention layer")

         if ulysses_sp_size > 1:
-            if is_transformers_version_in_range(min_version="4.52.0"):
-                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLTextModel
+            patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)

-                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLTextModel)
-            else:
-                from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLModel
+    elif model.config.model_type == "qwen2_vl":
+        if is_transformers_version_in_range(min_version="4.52.0"):
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+                Qwen2VLAttention,
+                Qwen2VLForConditionalGeneration,
+                Qwen2VLModel,
+                Qwen2VLTextModel,
+            )

-                patch_vlm_for_ulysses_input_slicing(Qwen2_5_VLModel)
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend, qwen2_vl_base_forward

-    elif model.config.model_type == "qwen2_vl":
-        if is_transformers_version_in_range(min_version="4.53.0"):
-            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
+            Qwen2VLModel.forward = qwen2_vl_base_forward
+            Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend
         else:
             from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration
+            from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel as Qwen2VLTextModel

-        if use_remove_padding or ulysses_sp_size > 1:
-            from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward
+            from verl.models.transformers.qwen2_vl import forward_with_normal_backend

-            Qwen2VLAttention.forward = ulysses_flash_attn_forward
-            print("Monkey patch FlashAttention2.forward in Qwen2VL")
+            Qwen2VLForConditionalGeneration.forward = forward_with_normal_backend

-        if ulysses_sp_size > 1:
-            if is_transformers_version_in_range(min_version="4.52.0"):
-                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLTextModel
+        if use_remove_padding or ulysses_sp_size > 1:
+            from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward

-                patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)
-            else:
-                from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
+            Qwen2VLAttention.forward = qwen2_vl_attn_forward
+            print("Monkey patch Qwen2VL attention layer")

-                patch_vlm_for_ulysses_input_slicing(Qwen2VLModel)
+        if ulysses_sp_size > 1:
+            patch_vlm_for_ulysses_input_slicing(Qwen2VLTextModel)

     elif model.config.model_type == "kimi_vl":
         if use_remove_padding or ulysses_sp_size > 1:
@@ -357,43 +321,14 @@ def state_dict(self, *args, **kwargs):

         return

-    # transformers<=4.47.1
     if use_remove_padding or ulysses_sp_size > 1:
-        if hasattr(module, "_flash_attention_forward"):
+        if hasattr(module, "_flash_attention_forward"):  # transformers <= 4.47.1 or legacy models
             module._flash_attention_forward = _ulysses_flash_attention_forward
             print(f"Monkey patch _flash_attention_forward in {model.__module__}")
         else:
-            if is_transformers_version_in_range(min_version="4.55.0"):
-                from transformers.integrations import flash_attention
-
-                flash_attention._flash_attention_forward = _ulysses_flash_attention_forward_transformers_4_55
-                print(f"Monkey patch _flash_attention_forward in {model.__module__} for new api")
-            else:
-                # 4.48.0 <= transformers <= 4.54.1, Vision attention
-                from transformers.integrations import flash_attention
+            from transformers.integrations import flash_attention

-                flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
-                print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")
+            flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
+            print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")

     patch_forward_with_backends(model, use_fused_kernels=use_fused_kernels, fused_kernels_backend=fused_kernels_backend)
-
-
-@lru_cache
-def is_transformers_version_in_range(min_version: Optional[str] = None, max_version: Optional[str] = None) -> bool:
-    try:
-        # Get the installed version of the transformers library
-        transformers_version_str = importlib.metadata.version("transformers")
-    except importlib.metadata.PackageNotFoundError as e:
-        raise ModuleNotFoundError("The `transformers` package is not installed.") from e
-
-    transformers_version = version.parse(transformers_version_str)
-
-    lower_bound_check = True
-    if min_version is not None:
-        lower_bound_check = version.parse(min_version) <= transformers_version
-
-    upper_bound_check = True
-    if max_version is not None:
-        upper_bound_check = transformers_version <= version.parse(max_version)
-
-    return lower_bound_check and upper_bound_check
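The helper removed here now lives in verl.utils.transformers_compat (see the import added near the top of the diff). Its use follows the version-gating pattern already visible above; a hedged sketch, assuming the helper keeps the same signature:

```python
from verl.utils.transformers_compat import is_transformers_version_in_range

# Mirrors the qwen2_5_vl branch above: pick the attention class by transformers version.
if is_transformers_version_in_range(min_version="4.52.0"):
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention
else:
    from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
        Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
    )
```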