Merged
24 changes: 23 additions & 1 deletion src/transformers/generation/utils.py
@@ -2597,6 +2597,15 @@ def _contrastive_search(
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

# Create cosine_matrix_mask based on the attention_mask
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
if self.config.is_encoder_decoder:
if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
else:
cosine_matrix_mask = model_kwargs["attention_mask"]
cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)

Comment on lines +2600 to +2608
Contributor (Author):
This code initializes a default mask and then updates it based on the model type.

  • For encoder-decoder models, if model_kwargs contains a decoder_attention_mask (and it is not None), cosine_matrix_mask is set to this mask. If decoder_attention_mask is missing, it falls back to the default mask.
  • For decoder-only models, cosine_matrix_mask is set to the attention_mask from model_kwargs.

Please let me know if there are any additional logic checks that need to be added.
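
To make the shape handling concrete, here is a minimal self-contained sketch of the decoder-only branch (hypothetical tensors, outside of the real `generate()` internals): the default all-ones mask is swapped for the caller's `attention_mask`, then repeated `top_k` times so each candidate stream carries its own copy.

```python
import torch

batch_size, seq_len, top_k = 2, 5, 4
input_ids = torch.randint(0, 100, (batch_size, seq_len))

# Default: attend to every prompt position
cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)

# Decoder-only case: reuse the padding-aware attention mask when the caller passed one
attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded first row
                               [1, 1, 1, 1, 1]])
cosine_matrix_mask = attention_mask

# Each input row is expanded into top_k candidate rows, so the mask must follow
cosine_matrix_mask = cosine_matrix_mask.repeat_interleave(top_k, dim=0)
print(cosine_matrix_mask.shape)  # torch.Size([8, 5]) == [batch_size * top_k, seq_len]
```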

this_peer_finished = False

while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
@@ -2764,7 +2773,12 @@ def _contrastive_search(
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
# model confidence. Keeping `selected_idx` on CPU enables multi-device contrastive search and doesn't
# introduce (noticeable) slowdowns on single-device runs.
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
selected_idx = _ranking_fast(
context_hidden, next_hidden, top_k_probs, cosine_matrix_mask, penalty_alpha, top_k
)
cosine_matrix_mask = torch.cat(
[cosine_matrix_mask, cosine_matrix_mask.new_ones((cosine_matrix_mask.shape[0], 1))], dim=-1
)
selected_idx = selected_idx.to("cpu")

# This will be used instead of the previous inefficient torch.stack(torch.split())
@@ -4276,6 +4290,7 @@ def _ranking_fast(
context_hidden: torch.FloatTensor,
next_hidden: torch.FloatTensor,
next_top_k_probs: torch.FloatTensor,
cosine_matrix_mask: torch.LongTensor,
alpha: float,
beam_width: int,
) -> torch.FloatTensor:
@@ -4287,6 +4302,13 @@
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1) # [B*K, S]

# Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
# Using a large negative value for masked positions
cosine_matrix_mask = cosine_matrix_mask.to(dtype=cosine_matrix.dtype)
cosine_matrix_mask = (1 - cosine_matrix_mask) * torch.finfo(cosine_matrix.dtype).min
cosine_matrix = cosine_matrix + cosine_matrix_mask

degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
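
For reference, a small illustration of the masking step added above (toy numbers, not actual model hidden states): padded context positions are pushed to the dtype minimum before the row-wise max, so padding can never dominate the degeneration penalty.

```python
import torch

# Toy cosine similarities between 2 candidate continuations and 4 context positions;
# the first position of row 0 is padding.
cosine_matrix = torch.tensor([[0.9, 0.2, 0.1, 0.3],
                              [0.4, 0.5, 0.6, 0.2]])
cosine_matrix_mask = torch.tensor([[0, 1, 1, 1],
                                   [1, 1, 1, 1]])

# Masked positions are set to the smallest representable value, so the max that
# defines the degeneration penalty ignores them.
mask = (1 - cosine_matrix_mask.to(cosine_matrix.dtype)) * torch.finfo(cosine_matrix.dtype).min
degeneration_penalty, _ = (cosine_matrix + mask).max(dim=-1)
print(degeneration_penalty)  # tensor([0.3000, 0.6000]) -- the 0.9 at the padded slot is ignored
```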
135 changes: 135 additions & 0 deletions tests/generation/test_utils.py
@@ -44,6 +44,7 @@

if is_torch_available():
import torch
import torch.nn.functional as F

from transformers import (
AutoModelForCausalLM,
@@ -59,6 +60,7 @@
GPT2Tokenizer,
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.generation import (
@@ -3529,6 +3531,139 @@ def test_init_static_cache_multi_gpu(self):
value_cache_1 = results.past_key_values.value_cache[1]
self.assertTrue(key_cache_1.device == value_cache_1.device == torch.device(1))

@slow
def test_padding_input_contrastive_search_gpt2(self):
# Load the pre-trained GPT-2 model and tokenizer
model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", clean_up_tokenization_spaces=True)

# Set the tokenizer to left-pad the sequences
tokenizer.padding_side = "left"

# Define the PAD token as the EOS token
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Define the input prompt
prompt_text = "The whispered legends of the haunted mansion spoke"

# Tokenize the input prompt
encoded_prompt = tokenizer(prompt_text, return_tensors="pt", padding=True)
input_ids = encoded_prompt.input_ids.to(torch_device)
attention_mask = encoded_prompt.attention_mask.to(torch_device)

# Define the contrastive search params
penalty_alpha = 0.6
top_k = 4

# Define the padding length to add to the input IDs and attention mask
padding_length = 10

# Generate text without padding
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Pad the input IDs and attention mask on the left
padded_input_ids = F.pad(
input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
)
padded_attention_mask = F.pad(attention_mask, (padding_length, 0), "constant", value=0)

# Generate text with padded inputs
outputs_with_padding = model.generate(
input_ids=padded_input_ids,
attention_mask=padded_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)

# Assert that the generated texts are identical for padded and non-padded inputs
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(
generated_text_with_padding,
'The whispered legends of the haunted mansion spoke of the "souls of the dead" who were "falling '
'out of the sky" and "falling into the sea."\n\nThe ghostly apparitions were said to have been '
'created by the spirits of the dead, who were "falling out of the sky" and "falling into the sea',
)

@slow
def test_padding_input_contrastive_search_t5(self):
# Load the pre-trained T5 model and tokenizer
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
model.to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", clean_up_tokenization_spaces=True)

# Define the input prompt
prompt_text = "translate English to German: I need to finish this task before the end of the day."

# Tokenize the input prompt
encoded_prompt = tokenizer(prompt_text, return_tensors="pt")
input_ids = encoded_prompt.input_ids.to(torch_device)
attention_mask = encoded_prompt.attention_mask.to(torch_device)

# Define the decoder prompt
decoder_prompt_text = "Ich muss diese Aufgabe"
encoded_decoder_prompt = tokenizer(decoder_prompt_text, add_special_tokens=False, return_tensors="pt")
decoder_input_ids = encoded_decoder_prompt.input_ids.to(torch_device)
decoder_attention_mask = encoded_decoder_prompt.attention_mask.to(torch_device)

# Define the contrastive search params
penalty_alpha = 0.6
top_k = 4

# Generate text without padding
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_no_padding = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Define the padding length to add to the input IDs and attention mask
padding_length = 10

# Pad the decoder input IDs and attention mask on the left
padded_decoder_input_ids = F.pad(
decoder_input_ids, (padding_length, 0), "constant", value=model.generation_config.pad_token_id
)
padded_decoder_attention_mask = F.pad(decoder_attention_mask, (padding_length, 0), "constant", value=0)
# Since the decoder_start_token_id is the same as the pad_token_id,
# the last padded token represents the decoder start token.
# Set the attention mask for the decoder_start_token_id to True (1).
padded_decoder_attention_mask[:, padding_length - 1] = 1
# Generate text with padded inputs
outputs_with_padding = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=padded_decoder_input_ids,
decoder_attention_mask=padded_decoder_attention_mask,
do_sample=False,
penalty_alpha=penalty_alpha,
top_k=top_k,
max_new_tokens=64,
)
generated_text_with_padding = tokenizer.decode(outputs_with_padding[0], skip_special_tokens=True)

# Assert that the generated texts are identical for padded and non-padded inputs
self.assertEqual(generated_text_no_padding, generated_text_with_padding)
self.assertEqual(generated_text_no_padding, "Ich muss diese Aufgabe vor Ende des Tages beenden.")


@require_torch
class TokenHealingTestCase(unittest.TestCase):