Skip to content

Commit 52e9d05

Browse files
authored
[Ernie 4.5 VL Moe] Post merge adjustments (#43117)
post merge fixes
1 parent 5c68832 commit 52e9d05

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

src/transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1705,6 +1705,8 @@ def prepare_inputs_for_generation(
17051705
past_key_values=None,
17061706
image_grid_thw=None,
17071707
video_grid_thw=None,
1708+
use_cache=True,
1709+
is_first_iteration=False,
17081710
# Intentionally ignore position ids to force custom cache logic
17091711
position_ids=None,
17101712
**kwargs,
@@ -1717,6 +1719,8 @@ def prepare_inputs_for_generation(
17171719
past_key_values=past_key_values,
17181720
image_grid_thw=image_grid_thw,
17191721
video_grid_thw=video_grid_thw,
1722+
use_cache=use_cache,
1723+
is_first_iteration=is_first_iteration,
17201724
**kwargs,
17211725
)
17221726

@@ -1732,7 +1736,7 @@ def prepare_inputs_for_generation(
17321736
mm_token_type_ids=model_inputs.get("mm_token_type_ids"),
17331737
)
17341738

1735-
if model_inputs["cache_position"][0] != 0:
1739+
if not is_first_iteration and use_cache:
17361740
model_inputs["pixel_values"] = None
17371741
model_inputs["pixel_values_videos"] = None
17381742
model_inputs["mm_token_type_ids"] = None

src/transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1393,6 +1393,8 @@ def prepare_inputs_for_generation(
13931393
past_key_values=None,
13941394
image_grid_thw=None,
13951395
video_grid_thw=None,
1396+
use_cache=True,
1397+
is_first_iteration=False,
13961398
# Intentionally ignore position ids to force custom cache logic
13971399
position_ids=None,
13981400
**kwargs,
@@ -1405,6 +1407,8 @@ def prepare_inputs_for_generation(
14051407
past_key_values=past_key_values,
14061408
image_grid_thw=image_grid_thw,
14071409
video_grid_thw=video_grid_thw,
1410+
use_cache=use_cache,
1411+
is_first_iteration=is_first_iteration,
14081412
**kwargs,
14091413
)
14101414

@@ -1420,7 +1424,7 @@ def prepare_inputs_for_generation(
14201424
mm_token_type_ids=model_inputs.get("mm_token_type_ids"),
14211425
)
14221426

1423-
if model_inputs["cache_position"][0] != 0:
1427+
if not is_first_iteration and use_cache:
14241428
model_inputs["pixel_values"] = None
14251429
model_inputs["pixel_values_videos"] = None
14261430
model_inputs["mm_token_type_ids"] = None

tests/models/ernie4_5_vl_moe/test_modeling_ernie4_5_vl_moe.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -313,6 +313,7 @@ def load_model(self, dtype, attn_implementation="sdpa"):
313313
device_map="auto",
314314
dtype=dtype,
315315
attn_implementation=attn_implementation,
316+
experts_implementation="eager",
316317
revision="refs/pr/10",
317318
)
318319

@@ -549,6 +550,7 @@ def load_model(self, dtype, attn_implementation="sdpa"):
549550
device_map="auto",
550551
dtype=dtype,
551552
attn_implementation=attn_implementation,
553+
experts_implementation="eager",
552554
)
553555

554556
def test_small_model_integration_test(self):

0 commit comments

Comments
 (0)