17 changes: 11 additions & 6 deletions src/transformers/generation/utils.py
@@ -406,23 +406,28 @@ def prepare_inputs_for_generation(
model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)

# 4. Create missing `position_ids` on the fly
attention_mask = (
kwargs.pop("decoder_attention_mask", None) if self.config.is_encoder_decoder else attention_mask
)
attention_mask_key = "decoder_attention_mask" if self.config.is_encoder_decoder else "attention_mask"
position_ids_key = "decoder_position_ids" if self.config.is_encoder_decoder else "position_ids"
if (
attention_mask is not None
and kwargs.get("position_ids") is None
and "position_ids" in set(inspect.signature(self.forward).parameters.keys())
and kwargs.get(position_ids_key) is None
and position_ids_key in set(inspect.signature(self.forward).parameters.keys())
):
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
kwargs["position_ids"] = position_ids # placed in kwargs for further processing (see below)
kwargs[position_ids_key] = position_ids # placed in kwargs for further processing (see below)

# 5. Slice model inputs if it's an input that should have the same length as `input_ids`
for model_input_name in ["position_ids", "token_type_ids"]:
for model_input_name in ["position_ids", "token_type_ids", "decoder_position_ids"]:
model_input = kwargs.get(model_input_name)
if model_input is not None:
if past_key_values is not None:
current_input_length = (
model_inputs["inputs_embeds"].shape[1]
if model_inputs["inputs_embeds"] is not None
if model_inputs.get("inputs_embeds") is not None
else model_inputs[input_ids_key].shape[1]
)
model_input = model_input[:, -current_input_length:]
@@ -469,7 +474,7 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
)
if attention_mask is not None:
model_inputs["attention_mask"] = attention_mask
model_inputs[attention_mask_key] = attention_mask

# 7. Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
for key, value in kwargs.items():
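For reference, a minimal standalone sketch (not part of the PR) of the cumsum pattern used above to turn a left-padded attention mask into position ids; the example tensors are made up:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# Running positions over real tokens; padded slots get a dummy value of 1,
# exactly as in the masked_fill_ call in the diff above.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])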
2 changes: 1 addition & 1 deletion src/transformers/models/whisper/generation_whisper.py
@@ -1234,7 +1234,7 @@ def _expand_variables_for_generation(
def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs):
set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs")
extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)}
set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs})
set_inputs({"inputs": segment_input, "input_ids": decoder_input_ids, **extra_kwargs})

@staticmethod
def _retrieve_total_input_frames(input_features, input_stride, kwargs):
84 changes: 1 addition & 83 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -1255,7 +1255,7 @@ def forward(
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1)

# embed positions
if input_ids is not None:
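A quick sketch (not from the PR) of why the added .repeat matters: cache_position is a 1-D tensor of decoding positions, so the old unsqueeze(0) produced position_ids of shape (1, seq_len), while the new line expands it to one row per batch element. Batch size and length below are made up.

import torch

cache_position = torch.arange(4)                            # decoding positions 0..3
batch_size = 2                                              # assumed batch size
old_position_ids = cache_position.unsqueeze(0)                         # shape (1, 4)
new_position_ids = cache_position.unsqueeze(0).repeat(batch_size, 1)   # shape (2, 4)
print(old_position_ids.shape, new_position_ids.shape)
# torch.Size([1, 4]) torch.Size([2, 4])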
@@ -1806,88 +1806,6 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

def prepare_inputs_for_generation(
self,
decoder_input_ids,
past_key_values=None,
use_cache=None,
encoder_outputs=None,
attention_mask=None,
decoder_attention_mask=None,
cache_position=None,
**kwargs,
):
# Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general function. Next time
# this function needs to be touched, let's try to sort out the commonalities between the two and remove the
# overwrite.

decoder_position_ids = None
if decoder_attention_mask is not None:
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
else:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if decoder_input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = decoder_input_ids.shape[1] - 1

decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

if decoder_position_ids is not None:
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format)

if cache_position is None:
cache_position = torch.arange(
past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
)
elif use_cache:
cache_position = cache_position[-decoder_input_ids.shape[1] :]

# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
decoder_input_ids = decoder_input_ids.contiguous()

if (
isinstance(past_key_values, EncoderDecoderCache)
and (
isinstance(past_key_values.self_attention_cache, StaticCache)
or isinstance(past_key_values.cross_attention_cache, StaticCache)
)
and decoder_attention_mask is not None
and decoder_attention_mask.ndim == 2
):
batch_size, sequence_length = decoder_input_ids.shape

decoder_attention_mask = self.get_decoder()._prepare_4d_causal_attention_mask_with_cache_position(
decoder_attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.self_attention_cache.get_max_cache_shape(),
dtype=self.proj_out.weight.dtype,
device=decoder_input_ids.device,
cache_position=cache_position,
batch_size=batch_size,
)

return {
"encoder_outputs": encoder_outputs,
"past_key_values": past_key_values,
"decoder_input_ids": decoder_input_ids,
"use_cache": use_cache,
"decoder_attention_mask": decoder_attention_mask,
"decoder_position_ids": decoder_position_ids,
"cache_position": cache_position,
}


class WhisperDecoderWrapper(WhisperPreTrainedModel):
"""
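The removed override above mostly duplicated the shared GenerationMixin logic (as its own comment noted). For context, a small sketch (not from the PR, token ids are made up) of the prefix-trimming behaviour it implemented: only tokens not yet covered by the cache are kept.

import torch

decoder_input_ids = torch.tensor([[50258, 50259, 50359, 50363, 123]])  # made-up token ids
past_length = 4  # tokens already processed and sitting in the KV cache
if decoder_input_ids.shape[1] > past_length:
    remove_prefix_length = past_length
else:
    # default to old behavior: keep only the final id
    remove_prefix_length = decoder_input_ids.shape[1] - 1
print(decoder_input_ids[:, remove_prefix_length:])  # tensor([[123]])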
5 changes: 2 additions & 3 deletions tests/models/whisper/test_modeling_whisper.py
@@ -3323,8 +3323,8 @@ def test_tiny_static_generation(self):
input_features = input_features.to(torch_device)
eager_generated_ids = model.generate(input_features, max_new_tokens=64)

# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
Suggested change (review comment):
# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
# Using static cache compiles forward for each decoding step, so we don't have to manually compile

model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# compile the forward pass and assert equivalence
static_generated_ids = model.generate(input_features, max_new_tokens=64)
@@ -3379,9 +3379,8 @@ def test_tiny_static_generation_long_form(self):
set_seed(42)
eager_generated_ids = model.generate(**inputs, **gen_kwargs)

# compile the forward pass and assert equivalence
# Using statiic cache compiles forward for each decoding step, so we don't have to manually compile
model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

set_seed(42)
static_generated_ids = model.generate(**inputs, **gen_kwargs)
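A usage sketch mirroring the test pattern above (the checkpoint name and the dummy audio are assumptions, not from the PR): per the test comment, with cache_implementation set to "static", generate() handles compiling the decoding forward pass, so the manual torch.compile call the test previously made is no longer needed.

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Assumed checkpoint; any Whisper checkpoint would do for this sketch.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# One second of silence at 16 kHz stands in for real audio.
dummy_audio = torch.zeros(16000).numpy()
input_features = processor(dummy_audio, sampling_rate=16000, return_tensors="pt").input_features

model.generation_config.cache_implementation = "static"
generated_ids = model.generate(input_features, max_new_tokens=64)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))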