176 changes: 66 additions & 110 deletions src/transformers/generation/utils.py

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -43,7 +43,6 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1590,7 +1589,7 @@ def prepare_inputs_for_generation(
         if not empty_past_kv:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
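Every file below repeats the two-line change shown above for bamba: the `is_torchdynamo_compiling()` short-circuit is dropped from the Exception 3 test, so the trailing slice is taken only when the cache position actually runs past the prompt. As a minimal sketch of the trimming logic, with hypothetical shapes and a standalone function in place of each model's `prepare_inputs_for_generation` method (the default-case body is assumed from the surrounding transformers code, not shown in this diff):

import torch

def trim_input_ids(input_ids, inputs_embeds, cache_position, empty_past_kv):
    # Sketch of the post-PR slicing, following the Exception 1/2/3 comments.
    if not empty_past_kv:
        if (
            inputs_embeds is not None  # Exception 1: generating from embeddings
            or cache_position[-1] >= input_ids.shape[1]  # Exception 3: cache ran past the prompt
        ):
            input_ids = input_ids[:, -cache_position.shape[0] :]
        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (Exception 2 is the no-op "else")
            input_ids = input_ids[:, cache_position]  # body assumed from the transformers source
    return input_ids

# Decode step: 6 prompt tokens already cached, one new position to process.
input_ids = torch.arange(7).unsqueeze(0)  # shape (1, 7)
cache_position = torch.tensor([6])
print(trim_input_ids(input_ids, None, cache_position, empty_past_kv=False))  # tensor([[6]])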
3 changes: 1 addition & 2 deletions src/transformers/models/bamba/modular_bamba.py
@@ -51,7 +51,6 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1274,7 +1273,7 @@ def prepare_inputs_for_generation(
         if not empty_past_kv:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
3 changes: 1 addition & 2 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -38,7 +38,6 @@
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     is_torch_flex_attn_available,
-    is_torchdynamo_compiling,
     logging,
 )
 from .configuration_bloom import BloomConfig
@@ -919,7 +918,7 @@ def prepare_inputs_for_generation(
             inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
         elif (
             inputs_embeds is not None  # Exception 1
-            or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+            or cache_position[-1] >= input_ids.shape[1]  # Exception 3
         ):
             input_ids = input_ids[:, -cache_position.shape[0] :]
         elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
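Bloom differs only in that it trims `inputs_embeds` before falling through to the same test on `input_ids`. The Exception 3 branch covers the case where the cache has advanced beyond the end of the prompt (per the comments in transformers' generation code, this can happen with synced GPUs); a quick numeric check of the condition under assumed shapes:

import torch

# Hypothetical shapes: a 5-token prompt, but the cache has already been
# advanced to position 8 (e.g. by synced-GPU generation running ahead).
input_ids = torch.zeros(1, 5, dtype=torch.long)
cache_position = torch.tensor([8])

if cache_position[-1] >= input_ids.shape[1]:  # Exception 3 fires: 8 >= 5
    input_ids = input_ids[:, -cache_position.shape[0] :]  # keep one trailing (dummy) token
print(input_ids.shape)  # torch.Size([1, 1])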
3 changes: 1 addition & 2 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -36,7 +36,6 @@
     LossKwargs,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -931,7 +930,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
3 changes: 1 addition & 2 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -29,7 +29,6 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import (
-    is_torchdynamo_compiling,
     logging,
 )
 from ..cohere.modeling_cohere import (
@@ -603,7 +602,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
3 changes: 1 addition & 2 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -42,7 +42,6 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1571,7 +1570,7 @@ def prepare_inputs_for_generation(
         if not empty_past_kv:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
3 changes: 1 addition & 2 deletions src/transformers/models/mllama/modeling_mllama.py
@@ -33,7 +33,6 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_torch_flex_attn_available,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -2194,7 +2193,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
2 changes: 1 addition & 1 deletion src/transformers/models/moshi/modeling_moshi.py
@@ -2507,7 +2507,7 @@ def prepare_inputs_for_generation(
         if past_key_values is not None:
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
3 changes: 1 addition & 2 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -45,7 +45,6 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1325,7 +1324,7 @@ def prepare_inputs_for_generation(
             # (we can't check exception 3 while compiling)
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
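The context comment kept in zamba and zamba2, "(we can't check exception 3 while compiling)", explains why the removed guard existed: a branch on `cache_position[-1] >= input_ids.shape[1]` is data-dependent, which torch dynamo cannot resolve at trace time, so the old code forced the Exception 3 path whenever it was compiling. A hedged before/after sketch, using `torch.compiler.is_compiling()` as a stand-in for the removed transformers helper:

import torch

def exception_3_old(cache_position, input_ids):
    # Pre-PR: while compiling, skip the value check and always take the slice.
    return torch.compiler.is_compiling() or bool(cache_position[-1] >= input_ids.shape[1])

def exception_3_new(cache_position, input_ids):
    # Post-PR: the data-dependent check alone decides the branch.
    return bool(cache_position[-1] >= input_ids.shape[1])

input_ids = torch.zeros(1, 5, dtype=torch.long)
cache_position = torch.tensor([2])  # still inside the prompt
print(exception_3_old(cache_position, input_ids))  # False in eager mode
print(exception_3_new(cache_position, input_ids))  # False; under compile, only the old version forced True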
3 changes: 1 addition & 2 deletions src/transformers/models/zamba2/modeling_zamba2.py
@@ -40,7 +40,6 @@
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1762,7 +1761,7 @@ def prepare_inputs_for_generation(
             # (we can't check exception 3 while compiling)
             if (
                 inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
+                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
             ):
                 input_ids = input_ids[:, -cache_position.shape[0] :]
             elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)