Merged
Changes from 3 commits
149 changes: 53 additions & 96 deletions src/transformers/generation/utils.py
@@ -47,7 +47,6 @@
is_accelerate_available,
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
@@ -393,7 +392,6 @@ def prepare_inputs_for_generation(
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
# generate the first token for each sequence. Later use the generated Input ids for continuation.
if past_key_values is not None:
@@ -402,7 +400,7 @@
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
Member: I remember this code overridden in some models. If that's still the case, we'll need to clean up there also

Contributor Author: Good point, I'll do a scan and replace the pattern in other points!

or cache_position[-1] >= input_ids.shape[1] # Exception 3
):
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
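
A minimal standalone sketch (hypothetical values, not part of the diff) of the `cache_position` slicing kept above: only the trailing `cache_position.shape[0]` tokens, i.e. the ones not yet covered by the KV cache, are fed to the model.

import torch

input_ids = torch.tensor([[101, 7, 8, 9, 42]])  # full sequence, 5 tokens
cache_position = torch.tensor([4])              # only position 4 is unprocessed

# Same slice as in `prepare_inputs_for_generation`: keep the new tokens only.
new_tokens = input_ids[:, -cache_position.shape[0] :]
print(new_tokens)  # tensor([[42]])
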
@@ -1323,7 +1321,7 @@ def _validate_model_class(self):
# TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
# `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
# safely call `GenerationMixin.generate`
if not is_torchdynamo_compiling() and not self.can_generate():
if not self.can_generate():
terminations_with_generation_support = [
"ForCausalLM",
"ForConditionalGeneration",
@@ -1424,11 +1422,6 @@ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):

def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
"""Performs validation related to the resulting generated length"""

# Can't throw warnings/exceptions during compilation
if is_torchdynamo_compiling():
return

# 1. Max length warnings related to poor parameterization
if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
# 20 is the default max_length of the generation config
@@ -1547,10 +1540,8 @@ def _prepare_generation_config(
# 2) the generation config must have seen no modification since its creation (the hash is the same);
# 3) there are non-default generation parameters in the model config.
# 4) the user must have set new generation parameters in the model config.
# NOTE: `torch.compile` can't compile `hash`, this legacy support is disabled with compilation.
if (
not is_torchdynamo_compiling()
and self.generation_config._from_model_config # 1)
self.generation_config._from_model_config # 1)
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
and len(self.config._get_non_default_generation_parameters()) > 0 # 3)
):
@@ -1568,24 +1559,18 @@
generation_config = self.generation_config
using_model_generation_config = True

# `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
# will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
# exception will be raised in `_validate_model_kwargs`
if not is_torchdynamo_compiling():
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
else:
model_kwargs = kwargs
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id

return generation_config, model_kwargs
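
A hedged sketch of the deepcopy-then-update flow that now runs unconditionally above; the values are made up. `GenerationConfig.update` absorbs recognized generation parameters into the copy and returns the unrecognized kwargs so they can be forwarded to the model.

import copy
from transformers import GenerationConfig

generation_config = copy.deepcopy(GenerationConfig())
model_kwargs = generation_config.update(max_new_tokens=32, attention_mask=None)

print(generation_config.max_new_tokens)  # 32, recognized generation parameter
print(model_kwargs)                      # {'attention_mask': None}, left for the model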

@@ -1610,10 +1595,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()

# TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
# end-to-end compilation will yield bad results because `cache_position` will be incorrect.
if not is_torchdynamo_compiling():
cache_position = cache_position[past_length:]
cache_position = cache_position[past_length:]

model_kwargs["cache_position"] = cache_position
return model_kwargs
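
A small illustration with assumed numbers of the trimming above: positions already present in the cache, as reported by `get_seq_length()`, are dropped from `cache_position` so only the new positions remain.

import torch

past_length = 4                    # e.g. what `cache.get_seq_length()` reports
cache_position = torch.arange(6)   # positions 0..5 for the current input
cache_position = cache_position[past_length:]
print(cache_position)              # tensor([4, 5])
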
@@ -1690,13 +1672,7 @@ def _get_cache(
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
if not is_torchdynamo_compiling():
cache_dtype = self.dtype
else:
# NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
# Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
# models. May cause troubles with non-text modalities.
cache_dtype = self.get_output_embeddings().weight.dtype
cache_dtype = self.dtype

layer_device_map = self._get_layer_device_map_for_cache_init()
cache_kwargs = {
@@ -1896,12 +1872,11 @@ def _tensor_or_none(token, device=None):

# Set pad token if unset (and there are conditions to do so)
if pad_token_tensor is None and eos_token_tensor is not None:
if not is_torchdynamo_compiling():
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
pad_token_tensor = eos_token_tensor[0]
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
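
A hedged sketch, with made-up token ids, of the fallback above: when no pad token is configured but an eos token is, padding reuses the first eos id.

import torch

pad_token_tensor = None
eos_token_tensor = torch.tensor([2, 32000])  # hypothetical eos ids

if pad_token_tensor is None and eos_token_tensor is not None:
    pad_token_tensor = eos_token_tensor[0]
print(pad_token_tensor)  # tensor(2)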

@@ -1910,24 +1885,23 @@
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow
if (
eos_token_tensor is not None
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as "
"eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
"`attention_mask` to obtain reliable results."
)
if eos_token_tensor is not None and (
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation "
"will not stop until the maximum length is reached. Depending on other flags, it may even crash."
if (
eos_token_tensor is not None
and isin_mps_friendly(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as "
"eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
"`attention_mask` to obtain reliable results."
)
if eos_token_tensor is not None and (
torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any()
):
logger.warning(
f"`eos_token_id` should consist of positive integers, but is {eos_token_tensor}. Your generation "
"will not stop until the maximum length is reached. Depending on other flags, it may even crash."
)

# Update generation config with the updated special tokens tensors
# NOTE: this must be written into a different attribute name than the one holding the original special tokens
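
The two warnings kept above depend on tensor-valued checks; here is a standalone sketch with assumed ids, using plain `torch.isin` in place of the library's `isin_mps_friendly` helper.

import torch

eos_token_tensor = torch.tensor([2])
pad_token_tensor = torch.tensor([2])  # pad deliberately set equal to eos

# pad == eos: the attention mask can no longer be inferred from padding alone.
if torch.isin(eos_token_tensor, pad_token_tensor).any():
    print("pad token equals eos token; pass `attention_mask` explicitly")

# eos ids must be non-negative integers, otherwise generation never stops early.
if torch.is_floating_point(eos_token_tensor) or (eos_token_tensor < 0).any():
    print("invalid eos_token_id")
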
@@ -2067,7 +2041,7 @@ def generate(
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# decoder-only models must use left-padding for batched generation.
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
if not self.config.is_encoder_decoder:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
@@ -2164,7 +2138,7 @@ def generate(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)

if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
if self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
@@ -2427,43 +2401,29 @@ def typeerror():
# Convert to legacy cache format if requested
if (
generation_config.return_legacy_cache is True
and not is_torchdynamo_compiling()
and hasattr(result, "past_key_values")
and getattr(result.past_key_values, "to_legacy_cache") is not None
):
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
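
An assumed, minimal illustration of the conversion above: a `DynamicCache` holding one layer of keys and values is turned back into the legacy tuple-of-tuples format when `return_legacy_cache=True` is requested.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Shapes are (batch, num_heads, seq_len, head_dim); the values are placeholders.
cache.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx=0)

legacy = cache.to_legacy_cache()  # one (key, value) pair per layer
print(len(legacy), legacy[0][0].shape)  # 1 torch.Size([1, 2, 5, 8])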

def _has_unfinished_sequences(
self,
this_peer_finished: bool,
synced_gpus: bool,
device: torch.device,
cur_len: Optional[int] = None,
max_length: Optional[int] = None,
) -> bool:
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
"""
Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
fed through `this_peer_finished`. ZeRO stage 3-friendly.
"""
# torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile,
# although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria)
# TODO (joao): remove this when torch's support for control flow is not experimental (https://pytorch.org/docs/stable/generated/torch.cond.html)
if is_torchdynamo_compiling():
return cur_len < max_length
else:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
return False
elif this_peer_finished:
if synced_gpus:
# Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# The following logic allows an early break if all peers finished generating their sequence
this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
# send 0.0 if we finished, 1.0 otherwise
dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# did all peers finish? the reduced sum will be 0.0 then
if this_peer_finished_flag.item() == 0.0:
return False
return True
elif this_peer_finished:
return False
return True
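
A single-process sketch (simulated flags, no real `torch.distributed` call) of the synced-GPUs logic retained above: each rank contributes 1.0 while it still has unfinished sequences, so the reduced sum is 0.0 only once every rank has finished.

import torch

per_rank_still_generating = [False, False, True]  # pretend flags from three ranks
flags = torch.tensor([1.0 if busy else 0.0 for busy in per_rank_still_generating])
reduced = flags.sum()  # stands in for dist.all_reduce(flag, op=dist.ReduceOp.SUM)
print(reduced.item() != 0.0)  # True: one rank is unfinished, so all keep stepping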

def heal_tokens(
self, input_ids: torch.LongTensor, tokenizer: Optional["PreTrainedTokenizerBase"] = None
@@ -3225,7 +3185,6 @@ def _sample(
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
max_length = generation_config.max_length
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample

@@ -3260,9 +3219,7 @@ def _sample(
model_forward = self.get_compiled_call(generation_config.compile_config)

is_prefill = True
while self._has_unfinished_sequences(
this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
):
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
