Merged
34 changes: 11 additions & 23 deletions src/transformers/generation/utils.py
@@ -41,13 +41,6 @@
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
from ..utils import (
@@ -1120,26 +1113,21 @@ def _validate_model_class(self):
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
# TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
# `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
# safely call `GenerationMixin.generate`
if not is_torchdynamo_compiling() and not self.can_generate():
generate_compatible_mappings = [
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
terminations_with_generation_support = [
Comment (Contributor Author): Note: this function was touched to avoid a circular import (the `from ..models.auto` imports).
"ForCausalLM",
"ForConditionalGeneration",
"ForSpeechSeq2Seq",
"ForVision2Seq",
]
generate_compatible_classes = set()
for model_mapping in generate_compatible_mappings:
supported_models = model_mapping.get(type(self.config), default=None)
if supported_models is not None:
generate_compatible_classes.add(supported_models.__name__)
exception_message = (
raise TypeError(
f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
"it doesn't have a language model head."
"it doesn't have a language model head. Classes that support generation often end in one of these "
f"names: {terminations_with_generation_support}."
)
if generate_compatible_classes:
exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
raise TypeError(exception_message)

def _validate_assistant(self, assistant_model):
if assistant_model is None:
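A minimal sketch of the new error path in `_validate_model_class`, assuming a transformers install that includes this change; the checkpoint name is only illustrative:

```python
from transformers import AutoModelForSequenceClassification

# A sequence classification model has no language-model head, so
# `_validate_model_class` should reject `.generate()`.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

try:
    model.generate()
except TypeError as err:
    # Expected message: "... is not compatible with `.generate()`, as it doesn't have
    # a language model head. Classes that support generation often end in one of
    # these names: ['ForCausalLM', 'ForConditionalGeneration', ...]"
    print(err)
```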
36 changes: 28 additions & 8 deletions src/transformers/modeling_utils.py
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs):
setattr(torch.nn.init, name, init_func)


def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
except StopIteration:
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].device


def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
"""
Returns the first parameter dtype (can be non-floating) or asserts if none were found.
"""
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].dtype


def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
@@ -1295,6 +1295,7 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


# TODO (joao): remove `GenerationMixin` inheritance in v4.50
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.
@@ -1624,11 +1625,30 @@ def can_generate(cls) -> bool:
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
# Alternatively, the model can also have a custom `generate` function.
if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
return False
return True
# Directly inherits `GenerationMixin` -> can generate
if "GenerationMixin" in str(cls.__bases__):
return True
# Model class overwrites `generate` -> can generate
if str(cls.__name__) in str(cls.generate):
Comment (Collaborator): If it does not inherit from `GenerationMixin`, should we raise an error?

Reply (@gante, Contributor Author, Sep 6, 2024): Not quite, we have the time series models that overwrite `generate` (and, prior to this change, `can_generate()` returned True for them).
return True
# BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
# was how we detected whether a model could generate.
if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
logger.warning_once(
Comment (Contributor Author): This warning should only appear in cases that are not backwards compatible once the `GenerationMixin` inheritance is removed from `PreTrainedModel`. See the caveats section of the PR header.

f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
"overwritten. However, it doesn't directly inherit from `GenerationMixin`. From πŸ‘‰v4.50πŸ‘ˆ onwards, "
"`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
"to call `generate` and other related functions."
"\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
"model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
"\n - If you are the owner of the model architecture code, please modify your model class such that "
"it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
"\n - If you are not the owner of the model architecture class, please contact the model code owner "
"to update it."
)
return True
# Otherwise, can't generate
return False

@classmethod
def _check_and_enable_flash_attn_2(
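A toy sketch of the three detection paths in the new `can_generate`; the class names here are hypothetical, only the inheritance and override patterns matter:

```python
from transformers import GenerationMixin, PreTrainedModel


class DirectMixinModel(PreTrainedModel, GenerationMixin):
    """Path 1: directly inherits `GenerationMixin`."""


class CustomGenerateModel(PreTrainedModel):
    """Path 2: overwrites `generate` (e.g. the time series models)."""

    def generate(self, *args, **kwargs):
        ...


class LegacyPrepareModel(PreTrainedModel):
    """Path 3 (BC): overwrites `prepare_inputs_for_generation`."""

    def prepare_inputs_for_generation(self, *args, **kwargs):
        ...


print(DirectMixinModel.can_generate())     # True
print(CustomGenerateModel.can_generate())  # True
print(LegacyPrepareModel.can_generate())   # True, plus the deprecation warning above
```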
3 changes: 2 additions & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutput,
@@ -936,7 +937,7 @@ def forward(
)


class AlbertMLMHead(nn.Module):
def __init__(self, config: AlbertConfig):
super().__init__()

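The per-model edits that follow all apply the same pattern, which the `can_generate` warning above also spells out: the mixin must come after `PreTrainedModel` in the base list. A sketch with a made-up class name:

```python
from transformers import GenerationMixin, PreTrainedModel


class MyModelForCausalLM(PreTrainedModel, GenerationMixin):  # correct order
    ...


# Reversing the order raises at class-creation time while `PreTrainedModel`
# still inherits from `GenerationMixin` (no consistent MRO can be built):
#
#     class MyModelForCausalLM(GenerationMixin, PreTrainedModel):  # TypeError
#         ...
```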
35 changes: 35 additions & 0 deletions src/transformers/models/auto/auto_factory.py
@@ -29,12 +29,17 @@
extract_commit_hash,
find_adapter_config_file,
is_peft_available,
is_torch_available,
logging,
requires_backends,
)
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


if is_torch_available():
from ...generation import GenerationMixin


logger = logging.get_logger(__name__)


@@ -428,6 +433,7 @@ def from_config(cls, config, **kwargs):
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
model_class = add_generation_mixin_to_remote_model(model_class)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
@@ -549,6 +555,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
)
_ = hub_kwargs.pop("code_revision", None)
cls.register(config.__class__, model_class, exist_ok=True)
model_class = add_generation_mixin_to_remote_model(model_class)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
@@ -698,6 +705,34 @@ def getattribute_from_module(module, attr):
raise ValueError(f"Could not find {attr} in {transformers_module}!")


def add_generation_mixin_to_remote_model(model_class):
"""
Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.

This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
`PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
from the Hub may not have the `generate` method after we remove the inheritance.
"""
# 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
return model_class

# 2. If it already **directly** inherits from GenerationMixin, do nothing
if "GenerationMixin" in str(model_class.__bases__):
return model_class

# 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
# `prepare_inputs_for_generation` method.
has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
if has_custom_generate or has_custom_prepare_inputs:
model_class_with_generation_mixin = type(
model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
)
return model_class_with_generation_mixin
return model_class


class _LazyAutoMapping(OrderedDict):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
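A standalone sketch of the dynamic re-parenting trick `add_generation_mixin_to_remote_model` relies on, using toy stand-in classes. The dunder filtering below is only needed because the toys inherit straight from `object`; the real helper copies the model class `__dict__` wholesale:

```python
class Mixin:
    """Stand-in for `GenerationMixin`."""

    def generate(self):
        return "generated!"


class LegacyRemoteModel:
    """Stand-in for an old Hub class that signals generative intent."""

    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}


# Rebuild the class with `type()`, appending the mixin to its bases while
# keeping its original name and attributes.
patched = type(
    LegacyRemoteModel.__name__,
    (LegacyRemoteModel, Mixin),
    {k: v for k, v in LegacyRemoteModel.__dict__.items() if k not in ("__dict__", "__weakref__")},
)

print(patched.__name__)      # "LegacyRemoteModel"
print(patched().generate())  # "generated!"
```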
3 changes: 2 additions & 1 deletion src/transformers/models/bark/modeling_bark.py
@@ -22,6 +22,7 @@
from torch import nn
from torch.nn import functional as F

from ...generation import GenerationMixin
from ...generation.logits_process import (
AlternatingCodebooksLogitsProcessor,
BarkEosPrioritizerLogitsProcessor,
@@ -546,7 +547,7 @@ def device(self) -> torch.device:


# GPT2-like autoregressive model
class BarkCausalModel(BarkPreTrainedModel):
class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
config_class = BarkSubModelConfig

def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -25,6 +25,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
@@ -1557,7 +1558,7 @@ def forward(
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class BartForConditionalGeneration(BartPreTrainedModel):
class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -2010,7 +2011,7 @@ def forward(self, *args, **kwargs):
""",
BART_START_DOCSTRING,
)
class BartForCausalLM(BartPreTrainedModel):
class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
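A quick behavioral check (a sketch; assumes network access to download the checkpoint): with the explicit mixin, BART generation works exactly as before, and `can_generate` now returns True via the direct-inheritance path:

```python
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

assert BartForConditionalGeneration.can_generate()  # direct `GenerationMixin` inheritance

inputs = tokenizer("Hello, world!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```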
3 changes: 2 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -28,6 +28,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
@@ -1280,7 +1281,7 @@ def forward(
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
)
class BertLMHeadModel(BertPreTrainedModel):
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -23,6 +23,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
@@ -863,7 +864,7 @@ def _tie_weights(self):
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""",
BERT_GENERATION_START_DOCSTRING,
)
class BertGenerationDecoder(BertGenerationPreTrainedModel):
class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/big_bird/modeling_big_bird.py
@@ -26,6 +26,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -2495,7 +2496,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
@add_start_docstrings(
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
class BigBirdForCausalLM(BigBirdPreTrainedModel):
class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -24,6 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -2436,7 +2437,7 @@ def forward(
BIGBIRD_PEGASUS_START_DOCSTRING,
)
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -2882,7 +2883,7 @@ def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/biogpt/modeling_biogpt.py
@@ -23,6 +23,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -596,7 +597,7 @@ def forward(
@add_start_docstrings(
"""BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING
)
class BioGptForCausalLM(BioGptPreTrainedModel):
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["output_projection.weight"]

def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -26,6 +26,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -1196,7 +1197,7 @@ def forward(
@add_start_docstrings(
"The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING
)
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
@@ -1397,7 +1398,7 @@ def forward(self, *args, **kwargs):


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -24,6 +24,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -1163,7 +1164,7 @@ def forward(
"The BlenderbotSmall Model with a language modeling head. Can be used for summarization.",
BLENDERBOT_SMALL_START_DOCSTRING,
)
class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
@@ -1349,7 +1350,7 @@ def forward(self, *args, **kwargs):


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):