Merged
315 changes: 244 additions & 71 deletions docs/source/en/generation_strategies.md

Large diffs are not rendered by default.

103 changes: 93 additions & 10 deletions src/transformers/dynamic_module_utils.py
@@ -17,6 +17,7 @@
import filecmp
import hashlib
import importlib
import importlib.metadata
import importlib.util
import os
import re
@@ -30,6 +31,7 @@
from typing import Any, Optional, Union

from huggingface_hub import try_to_load_from_cache
from packaging import version

from .utils import (
HF_MODULES_CACHE,
@@ -39,6 +41,7 @@
is_offline_mode,
logging,
)
from .utils.import_utils import VersionComparison, split_package_version


logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -383,7 +386,7 @@ def get_cached_module_file(
new_files.append(module_file)

except OSError:
logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
Contributor Author
We were throwing an exception and a warning πŸ€”

Lowered the severity; otherwise we were throwing warnings in most models' `from_pretrained` calls (see changes in modeling_utils.py)

raise

# Check we have all the requirements in our environment
@@ -417,7 +420,8 @@ def get_cached_module_file(
# benefit of versioning.
submodule_path = submodule_path / commit_hash
full_submodule = full_submodule + os.path.sep + commit_hash
create_dynamic_module(full_submodule)
full_submodule_module_file_path = os.path.join(full_submodule, module_file)
create_dynamic_module(Path(full_submodule_module_file_path).parent)
Comment on lines +423 to +424
Contributor Author
Previously: we could only load root level custom modules
With this change: we can load custom modules in any folder

Member
nice!


if not (submodule_path / module_file).exists():
shutil.copy(resolved_module_file, submodule_path / module_file)
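A minimal sketch (hypothetical repo name, commit hash, and file paths) of what the change on lines +423 to +424 enables: the dynamic-module folder is now derived from the module file's parent, so nested files such as custom_generate/generate.py get their subfolder created as well.

# Illustrative sketch only; the repo, commit hash, and paths below are hypothetical.
import os
from pathlib import Path

full_submodule = os.path.join("transformers_modules", "some-org", "some-repo", "abc123")
module_file = "custom_generate/generate.py"  # a nested (non-root-level) custom module

full_submodule_module_file_path = os.path.join(full_submodule, module_file)
# Previously, create_dynamic_module(full_submodule) only prepared the repo root folder.
# Taking the parent of the full file path also covers the "custom_generate" subfolder.
print(Path(full_submodule_module_file_path).parent)
# On a POSIX system: transformers_modules/some-org/some-repo/abc123/custom_generate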
@@ -663,7 +667,33 @@ def _raise_timeout_error(signum, frame):
TIME_OUT_REMOTE_CODE = 15


def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None):
"""
Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
it.

Args:
trust_remote_code (`bool` or `None`):
User-defined `trust_remote_code` value.
model_name (`str`):
The name of the model repository on huggingface.co.
has_local_code (`bool`):
Whether the model has local code.
has_remote_code (`bool`):
Whether the model has remote code.
error_message (`str`, *optional*):
Custom error message to display if there is remote code to load and the user didn't opt in. If unset, a default
error message about loading a model with custom code will be used.

Returns:
The resolved `trust_remote_code` value.
"""
# Originally, `trust_remote_code` was used to load models with custom code.
error_message = (
error_message
or f"The repository `{model_name}` contains custom code which must be executed to correctly load the model."
)

if trust_remote_code is None:
if has_local_code:
trust_remote_code = False
@@ -674,8 +704,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
@@ -687,8 +716,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
@@ -701,9 +729,64 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has

if has_remote_code and not has_local_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)

return trust_remote_code
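A hedged sketch of the new error_message argument, using a hypothetical repository name; with trust_remote_code=True the call simply returns True, while passing None would trigger the interactive prompt (or the ValueError above) carrying the custom message.

# Sketch under assumptions: the repo id is hypothetical and the user opts in explicitly.
from transformers.dynamic_module_utils import resolve_trust_remote_code

trust_remote_code = resolve_trust_remote_code(
    trust_remote_code=True,  # None would prompt interactively, using error_message as the prompt text
    model_name="some-org/custom-generate-demo",
    has_local_code=False,
    has_remote_code=True,
    error_message=(
        "The repository `some-org/custom-generate-demo` contains custom generation code "
        "that will override the default `generate` method."
    ),
)
assert trust_remote_code is True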


def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
Contributor Author
Logic for the requirements exception is as below:

ValueError: Missing requirements for joaogante/test_generate_from_hub_bad_requirements:
foo (installed: None)
bar==0.0.0 (installed: None)
torch>=99.0 (installed: 2.6.0+cu126)

"""
Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
python dependencies installed.

Args:
path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
- a path to a *directory* potentially containing the file.
kwargs (`Dict[str, Any]`, *optional*):
Additional arguments to pass to `cached_file`.
"""
failed = [] # error messages regarding requirements
try:
requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
with open(requirements, "r") as f:
requirements = f.readlines()

for requirement in requirements:
requirement = requirement.strip()
if not requirement or requirement.startswith("#"): # skip empty lines and comments
continue

try:
# e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
package_name, delimiter, version_number = split_package_version(requirement)
except ValueError: # e.g. "torch", as opposed to "torch>2.6.0"
package_name = requirement
delimiter, version_number = None, None

try:
local_package_version = importlib.metadata.version(package_name)
except importlib.metadata.PackageNotFoundError:
failed.append(f"{requirement} (installed: None)")
continue

if delimiter is not None and version_number is not None:
is_satisfied = VersionComparison.from_string(delimiter)(
version.parse(local_package_version), version.parse(version_number)
)
else:
is_satisfied = True

if not is_satisfied:
failed.append(f"{requirement} (installed: {local_package_version})")

except OSError: # no requirements.txt
pass

if failed:
raise ImportError(
f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
)
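A hedged usage sketch of check_python_requirements with a hypothetical repo id: if the repo ships a custom_generate/requirements.txt, every unmet entry is collected and raised in a single ImportError, each line annotated with the locally installed version (or None), much like the report quoted in the review comment above.

# Sketch only; the repo id below is hypothetical and must ship custom_generate/requirements.txt.
from transformers.dynamic_module_utils import check_python_requirements

try:
    check_python_requirements(
        "some-org/custom-generate-demo",
        requirements_file="custom_generate/requirements.txt",
    )
except ImportError as exc:
    # e.g. "Missing requirements in your local environment for `...`:\nfoo (installed: None)"
    print(exc)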
97 changes: 95 additions & 2 deletions src/transformers/generation/utils.py
@@ -23,12 +23,11 @@
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import file_exists
from packaging import version
from torch import nn
from torch.nn import functional as F

from transformers.generation.candidate_generator import AssistantVocabTranslatorCache

from ..cache_utils import (
Cache,
DynamicCache,
@@ -39,6 +38,12 @@
QuantizedCacheConfig,
)
from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import (
check_python_requirements,
get_cached_module_file,
get_class_in_module,
resolve_trust_remote_code,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..integrations.fsdp import is_fsdp_managed_module
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
@@ -55,6 +60,7 @@
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
AssistantVocabTranslatorCache,
AssistedCandidateGenerator,
AssistedCandidateGeneratorDifferentTokenizers,
CandidateGenerator,
@@ -376,6 +382,73 @@ class GenerationMixin:
To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""

def load_custom_generate(
self,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
trust_remote_code: Optional[bool] = None,
**kwargs,
) -> Callable:
"""
Loads and returns a custom generate function, given a model repo.

Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
trust_remote_code (`bool`, *optional*):
Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
execute code present on the Hub on your local machine.
**kwargs:
Additional keyword arguments for remote code loading.

Raises:
OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.

Returns:
A callable that can be used to generate text.
"""
# Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
is_local_code = os.path.exists(pretrained_model_name_or_path)
has_custom_generate_folder = True
if is_local_code:
if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
has_custom_generate_folder = False
else:
if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
has_custom_generate_folder = False

if not has_custom_generate_folder:
raise OSError(
f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
"`generate.py` file, can't load the custom generate function."
)

# Handle opt-in `trust_remote_code` and related exceptions
error_message = (
f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
"the default `generate` method."
)
resolve_trust_remote_code(
trust_remote_code,
pretrained_model_name_or_path,
has_local_code=is_local_code,
has_remote_code=not is_local_code,
error_message=error_message,
)

# Load the custom generate function
check_python_requirements(
pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
)
module = get_cached_module_file(
pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
)
custom_generate_function = get_class_in_module("generate", module)
return custom_generate_function
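A hedged sketch of calling load_custom_generate directly; the repo id is hypothetical and must contain custom_generate/generate.py, and the returned callable takes the model through the model keyword plus whatever generation arguments the repo's generate function accepts.

# Sketch only; "some-org/custom-generate-demo" is a hypothetical repo shipping custom_generate/generate.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

custom_generate = model.load_custom_generate(
    "some-org/custom-generate-demo", trust_remote_code=True
)

inputs = tokenizer("Hello world", return_tensors="pt")
# The custom function receives the model explicitly instead of `self`.
outputs = custom_generate(model=model, **inputs)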

def _cache_dependant_input_preparation(
self,
input_ids: torch.LongTensor,
@@ -2158,6 +2231,7 @@ def generate(
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
use_model_defaults: Optional[bool] = None,
custom_generate: Optional[str] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
@@ -2227,6 +2301,11 @@
generation configuration (`model.generation_config`), as opposed to the global defaults
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
`True`.
custom_generate (`str`, *optional*):
A string containing the name of a huggingface.co repository. If provided, the custom `generate`
function defined in that repository's `custom_generate/generate.py` file will be executed instead of the
standard `generate` method. Note that the logic for generation is entirely defined in that
repository, and the return type may be different from the standard `generate` method.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -2248,6 +2327,20 @@
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
if custom_generate is not None:
trust_remote_code = kwargs.pop("trust_remote_code", None)
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
# they receive the same inputs as `generate`, only with `model` instead of `self`. They can access
# methods from `GenerationMixin` through `model`.
global_keys_to_exclude = {"self", "kwargs"}
generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
generate_arguments.update(kwargs)

Comment on lines +2332 to +2339
Collaborator
all of this can be done in load_custom_generate, which would return the generate arguments!

Contributor Author
It needs locals() from generate to redirect non-keyword arguments :')

custom_generate_function = self.load_custom_generate(
custom_generate, trust_remote_code=trust_remote_code, **kwargs
)
return custom_generate_function(model=self, **generate_arguments)

# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
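To tie the pieces of this file together, a hedged end-to-end sketch: first the kind of minimal custom_generate/generate.py a Hub repo might ship (the greedy loop is purely illustrative), then the generate(custom_generate=...) call that dispatches to it. The repo id and the recipe are hypothetical.

# custom_generate/generate.py: illustrative content a Hub repo might ship.
# It receives `model` instead of `self`, plus the remaining `generate` arguments as kwargs.
import torch


def generate(model, input_ids, max_new_tokens=20, **kwargs):
    # Toy greedy loop; a real recipe defines its own decoding logic and return type.
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids

Calling it from user code (hypothetical repo id):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

outputs = model.generate(
    **inputs,
    custom_generate="some-org/custom-generate-demo",
    trust_remote_code=True,
)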
34 changes: 24 additions & 10 deletions src/transformers/modeling_utils.py
@@ -4104,6 +4104,7 @@ def from_pretrained(
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)

# Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
if any(allowed_name in cls.__name__.lower() for allowed_name in VLMS):
@@ -4113,7 +4114,6 @@

# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
_ = kwargs.pop("trust_remote_code", None)
_ = kwargs.pop("mirror", None)
_ = kwargs.pop("_fast_init", True)
_ = kwargs.pop("low_cpu_mem_usage", None)
@@ -4591,30 +4591,44 @@ def _assign_original_dtype(module):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()

# If it is a model with generation capabilities, attempt to load the generation config
# If it is a model with generation capabilities, attempt to load generation files (generation config,
# custom generate function)
if model.can_generate() and generation_config is not None:
logger.info("The user-defined `generation_config` will be used to override the default generation config.")
model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
elif model.can_generate() and pretrained_model_name_or_path is not None:
repo_loading_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"subfolder": subfolder,
**kwargs,
}
# Load generation config
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
**repo_loading_kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(model, "load_custom_generate"):
try:
custom_generate = model.load_custom_generate(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
)
model.generate = functools.partial(custom_generate, model=model)
except OSError: # there is no custom generate function
pass

# Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
# harm performances)
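A hedged sketch of what the from_pretrained change above means for users, with a hypothetical checkpoint that ships custom_generate/generate.py: after loading, model.generate is a functools.partial of the custom function bound to the model, so plain generate calls run the repo's recipe.

# Sketch only; the repo id is hypothetical and must ship custom_generate/generate.py.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "some-org/model-with-custom-generate", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained("some-org/model-with-custom-generate")

inputs = tokenizer("Hello world", return_tensors="pt")
# model.generate is now functools.partial(custom_generate, model=model),
# so this call runs the repo's recipe instead of GenerationMixin.generate.
outputs = model.generate(**inputs)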