-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Generation: deprecate PreTrainedModel inheriting from GenerationMixin
#33203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
1160c0f
1b2cd4e
8e018c8
898315f
03e05b6
00aca81
275b631
3709bf8
9620273
01c66a1
7aa36fb
9337cd4
d57cefa
ed22155
8b40781
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs): | |
| setattr(torch.nn.init, name, init_func) | ||
|
|
||
|
|
||
| def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): | ||
| def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]): | ||
| try: | ||
| return next(parameter.parameters()).device | ||
| except StopIteration: | ||
|
|
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | |
| return first_tuple[1].device | ||
|
|
||
|
|
||
| def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): | ||
| def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): | ||
| """ | ||
| Returns the first parameter dtype (can be non-floating) or asserts if none were found. | ||
| """ | ||
|
|
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | |
| return first_tuple[1].dtype | ||
|
|
||
|
|
||
| def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): | ||
| def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): | ||
| """ | ||
| Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. | ||
| """ | ||
|
|
@@ -1295,6 +1295,7 @@ def floating_point_ops( | |
| return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) | ||
|
|
||
|
|
||
| # TODO (joao): remove `GenerationMixin` inheritance in v4.50 | ||
| class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): | ||
| r""" | ||
| Base class for all models. | ||
|
|
@@ -1624,11 +1625,30 @@ def can_generate(cls) -> bool: | |
| Returns: | ||
| `bool`: Whether this model can generate sequences with `.generate()`. | ||
| """ | ||
| # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. | ||
| # Alternatively, the model can also have a custom `generate` function. | ||
| if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): | ||
| return False | ||
| return True | ||
| # Directly inherits `GenerationMixin` -> can generate | ||
| if "GenerationMixin" in str(cls.__bases__): | ||
| return True | ||
| # Model class overwrites `generate` -> can generate | ||
| if str(cls.__class__.__name__) in str(cls.generate): | ||
|
||
| return True | ||
| # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this | ||
| # was how we detected whether a model could generate. | ||
| if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): | ||
| logger.warning_once( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This warning should only appear in the case that is not BC after removing the `GenerationMixin` inheritance. |
||
| f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " | ||
| "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From v4.50 onwards, " | ||
| "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " | ||
| "to call `generate` and other related functions." | ||
| "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " | ||
| "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes" | ||
| "\n - If you are the owner of the model architecture code, please modify your model class such that " | ||
| "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)." | ||
| "\n - If you are not the owner of the model architecture class, please contact the model code owner " | ||
| "to update it." | ||
| ) | ||
| return True | ||
| # Otherwise, can't generate | ||
| return False | ||
|
|
||
| @classmethod | ||
| def _check_and_enable_flash_attn_2( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This function was touched to avoid a circular import (the
`from ..models.auto` imports)