Merged
17 changes: 16 additions & 1 deletion src/transformers/generation/utils.py
@@ -344,7 +344,22 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):

class GenerationMixin:
"""
A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
A class containing all functions for auto-regressive text generation, to be used as a mixin in model classes.
Inheriting from this class causes the model to have special generation-related behavior, such as loading a
`GenerationConfig` at initialization time or ensuring `generate`-related tests are run in `transformers` CI.

A model class should inherit from `GenerationMixin` to enable calling methods like `generate`, or when it
has defined a custom `generate` method that relies on `GenerationMixin`, directly or indirectly, and
approximately shares the same interface as `GenerationMixin.generate`. Three examples:
- `LlamaForCausalLM` should inherit from `GenerationMixin` to enable calling `generate` and other public
methods in the mixin;
- `BlipForQuestionAnswering` has a custom `generate` method that approximately shares the same interface as
`GenerationMixin.generate` (it has a few extra arguments, and the same output). That function also calls
`GenerationMixin.generate` indirectly, through an inner model. As such, `BlipForQuestionAnswering` should
inherit from `GenerationMixin` to benefit from all generation-related automation in our codebase;
- `BarkModel` has a custom `generate` method and one of its inner models calls `GenerationMixin.generate`.
However, its `generate` does not share the same interface as `GenerationMixin.generate`. In this case,
`BarkModel` should NOT inherit from `GenerationMixin`, as it breaks the `generate` interface.

The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
- *greedy decoding* if `num_beams=1` and `do_sample=False`
3 changes: 3 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1739,6 +1739,9 @@ def can_generate(cls) -> bool:
"""
Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.

Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
for instance, the model instance will have a populated `generation_config` attribute.

Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
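
The relationship the two docstrings describe (inherit from the mixin, `can_generate()` returns `True`, `generation_config` gets populated) can be illustrated with a toy sketch. Every class below (`ToyGenerationConfig`, `ToyGenerationMixin`, `ToyPreTrainedModel`, `ToyCausalLM`, `ToyEncoderOnly`, `ToyCompositeModel`) is hypothetical and written only for illustration. In particular, deciding `can_generate()` via `issubclass` is an assumption of this sketch, and the actual check in `transformers` may differ.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyGenerationConfig:
    """Hypothetical stand-in for a generation config (not the transformers class)."""
    max_new_tokens: int = 3


class ToyGenerationMixin:
    """Hypothetical mixin exposing a shared `generate(prompt)` interface."""

    def generate(self, prompt: str) -> str:
        # Placeholder "generation": append one dummy token per allowed step.
        return prompt + " tok" * self.generation_config.max_new_tokens


class ToyPreTrainedModel:
    """Hypothetical base class mirroring only the behavior the docstrings describe."""

    def __init__(self) -> None:
        # Generation-specific setup happens only when can_generate() is True.
        self.generation_config: Optional[ToyGenerationConfig] = (
            ToyGenerationConfig() if self.can_generate() else None
        )

    @classmethod
    def can_generate(cls) -> bool:
        # Assumption of this sketch: inheriting from the mixin is the signal.
        return issubclass(cls, ToyGenerationMixin)


class ToyCausalLM(ToyPreTrainedModel, ToyGenerationMixin):
    """Analogous to the first example: inherits the mixin to get `generate`."""


class ToyEncoderOnly(ToyPreTrainedModel):
    """A model with no generation support at all."""


class ToyCompositeModel(ToyPreTrainedModel):
    """Analogous to the third example: a custom `generate` with a different
    interface, so it deliberately does NOT inherit from the mixin."""

    def __init__(self) -> None:
        super().__init__()
        self.inner = ToyCausalLM()

    def generate(self, text: str, voice: str) -> str:
        # Delegates to an inner model, but the signature differs from the mixin's.
        return f"[{voice}] " + self.inner.generate(text)


if __name__ == "__main__":
    for cls in (ToyCausalLM, ToyEncoderOnly, ToyCompositeModel):
        print(cls.__name__, cls.can_generate(), cls().generation_config)
```

In this sketch the composite model can still produce output through its inner model, but it gets no `generation_config` of its own because it opts out of the shared interface.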