[generate] Run custom generation code from the Hub #36405
Changes from all commits
File: src/transformers/dynamic_module_utils.py

```diff
@@ -17,6 +17,7 @@
 import filecmp
 import hashlib
 import importlib
+import importlib.metadata
 import importlib.util
 import os
 import re
```
```diff
@@ -30,6 +31,7 @@
 from typing import Any, Optional, Union

 from huggingface_hub import try_to_load_from_cache
+from packaging import version

 from .utils import (
     HF_MODULES_CACHE,
```
```diff
@@ -39,6 +41,7 @@
     is_offline_mode,
     logging,
 )
+from .utils.import_utils import VersionComparison, split_package_version


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
```
```diff
@@ -383,7 +386,7 @@ def get_cached_module_file(
             new_files.append(module_file)

     except OSError:
-        logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+        logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
         raise

     # Check we have all the requirements in our environment
```
```diff
@@ -417,7 +420,8 @@ def get_cached_module_file(
         # benefit of versioning.
         submodule_path = submodule_path / commit_hash
         full_submodule = full_submodule + os.path.sep + commit_hash
-        create_dynamic_module(full_submodule)
+        full_submodule_module_file_path = os.path.join(full_submodule, module_file)
+        create_dynamic_module(Path(full_submodule_module_file_path).parent)
```
Comment on lines +423 to +424

Contributor (Author): Previously, we could only load root-level custom modules.

Member: nice!
```diff
         if not (submodule_path / module_file).exists():
             shutil.copy(resolved_module_file, submodule_path / module_file)
```
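A quick sketch of what the two added lines change: `create_dynamic_module` now receives the parent directory of the module file rather than the submodule root, so nested files like `custom_generate/generate.py` get a package path created for them (the paths below are hypothetical):

```python
from pathlib import Path

# Hypothetical cached path for a module file nested below the repo root.
full_submodule_module_file_path = "transformers_modules/user/repo/abc123/custom_generate/generate.py"

# The parent directory is what gets created as a dynamic module, so nested
# packages such as `custom_generate` now work, not just root-level files.
print(Path(full_submodule_module_file_path).parent)
# -> transformers_modules/user/repo/abc123/custom_generate
```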
```diff
@@ -663,7 +667,33 @@ def _raise_timeout_error(signum, frame):
 TIME_OUT_REMOTE_CODE = 15


-def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
+def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None):
+    """
+    Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
+    it.
+
+    Args:
+        trust_remote_code (`bool` or `None`):
+            User-defined `trust_remote_code` value.
+        model_name (`str`):
+            The name of the model repository in huggingface.co.
+        has_local_code (`bool`):
+            Whether the model has local code.
+        has_remote_code (`bool`):
+            Whether the model has remote code.
+        error_message (`str`, *optional*):
+            Custom error message to display if there is remote code to load and the user didn't opt-in. If unset,
+            the error message will be regarding loading a model with custom code.
+
+    Returns:
+        The resolved `trust_remote_code` value.
+    """
+    # Originally, `trust_remote_code` was used to load models with custom code.
+    error_message = (
+        error_message
+        or f"The repository `{model_name}` contains custom code which must be executed to correctly load the model."
+    )
+
     if trust_remote_code is None:
         if has_local_code:
             trust_remote_code = False
```
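For illustration, a minimal sketch of calling the extended function with the new `error_message` argument (the repo id is hypothetical):

```python
from transformers.dynamic_module_utils import resolve_trust_remote_code

# With no user opt-in, remote code present, and no local fallback, this prompts
# (or raises) using the custom message instead of the default model-loading one.
trust_remote_code = resolve_trust_remote_code(
    None,  # trust_remote_code was not set by the user
    "someuser/some-repo",  # hypothetical repo id
    has_local_code=False,
    has_remote_code=True,
    error_message="The repository `someuser/some-repo` contains custom generation code.",
)
```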
```diff
@@ -674,8 +704,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
         signal.alarm(TIME_OUT_REMOTE_CODE)
         while trust_remote_code is None:
             answer = input(
-                f"The repository for {model_name} contains custom code which must be executed to correctly "
-                f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
+                f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
                 f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
                 f"Do you wish to run the custom code? [y/N] "
             )
```
```diff
@@ -687,8 +716,7 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
         except Exception:
             # OS which does not support signal.SIGALRM
             raise ValueError(
-                f"The repository for {model_name} contains custom code which must be executed to correctly "
-                f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
+                f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
                 f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
             )
         finally:
```
```diff
@@ -701,9 +729,64 @@ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
     if has_remote_code and not has_local_code and not trust_remote_code:
         raise ValueError(
-            f"Loading {model_name} requires you to execute the configuration file in that"
-            " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
-            " set the option `trust_remote_code=True` to remove this error."
+            f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
+            f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
         )

     return trust_remote_code


+def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
```
Contributor (Author): Logic for the requirements exception, as below.
```diff
+    """
+    Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
+    python dependencies installed.
+
+    Args:
+        path_or_repo_id (`str` or `os.PathLike`):
+            This can be either:
+            - a string, the *model id* of a model repo on huggingface.co.
+            - a path to a *directory* potentially containing the file.
+        kwargs (`Dict[str, Any]`, *optional*):
+            Additional arguments to pass to `cached_file`.
+    """
+    failed = []  # error messages regarding requirements
+    try:
+        requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
+        with open(requirements, "r") as f:
+            requirements = f.readlines()
+
+        for requirement in requirements:
+            requirement = requirement.strip()
+            if not requirement or requirement.startswith("#"):  # skip empty lines and comments
+                continue
+
+            try:
+                # e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
+                package_name, delimiter, version_number = split_package_version(requirement)
+            except ValueError:  # e.g. "torch", as opposed to "torch>2.6.0"
+                package_name = requirement
+                delimiter, version_number = None, None
+
+            try:
+                local_package_version = importlib.metadata.version(package_name)
+            except importlib.metadata.PackageNotFoundError:
+                failed.append(f"{requirement} (installed: None)")
+                continue
+
+            if delimiter is not None and version_number is not None:
+                is_satisfied = VersionComparison.from_string(delimiter)(
+                    version.parse(local_package_version), version.parse(version_number)
+                )
+            else:
+                is_satisfied = True
+
+            if not is_satisfied:
+                failed.append(f"{requirement} (installed: {local_package_version})")
+
+    except OSError:  # no requirements.txt
+        pass
+
+    if failed:
+        raise ImportError(
+            f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
+        )
```
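For reference, a usage sketch of the new helper (the repo id and requirements file contents are hypothetical):

```python
from transformers.dynamic_module_utils import check_python_requirements

# Assumes the repo ships custom_generate/requirements.txt with pins such as
# "torch>2.6.0". Raises ImportError listing each unmet requirement, e.g.
# "torch>2.6.0 (installed: 2.5.1)"; passes silently if the file is absent.
check_python_requirements(
    "someuser/custom-generate-repo",
    requirements_file="custom_generate/requirements.txt",
)
```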
File: src/transformers/generation/utils.py
```diff
@@ -23,12 +23,11 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+from huggingface_hub import file_exists
 from packaging import version
 from torch import nn
 from torch.nn import functional as F

-from transformers.generation.candidate_generator import AssistantVocabTranslatorCache
-
 from ..cache_utils import (
     Cache,
     DynamicCache,
```
```diff
@@ -39,6 +38,12 @@
     QuantizedCacheConfig,
 )
 from ..configuration_utils import PretrainedConfig
+from ..dynamic_module_utils import (
+    check_python_requirements,
+    get_cached_module_file,
+    get_class_in_module,
+    resolve_trust_remote_code,
+)
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..integrations.fsdp import is_fsdp_managed_module
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
```
```diff
@@ -55,6 +60,7 @@
 from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
 from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
 from .candidate_generator import (
+    AssistantVocabTranslatorCache,
     AssistedCandidateGenerator,
     AssistedCandidateGeneratorDifferentTokenizers,
     CandidateGenerator,
```
```diff
@@ -376,6 +382,73 @@ class GenerationMixin:
     To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
     """

+    def load_custom_generate(
+        self,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+        trust_remote_code: Optional[bool] = None,
+        **kwargs,
+    ) -> Callable:
+        """
+        Loads and returns a custom generate function, given a model repo.
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+                - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing model weights saved using
+                  [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+            trust_remote_code (`bool`, *optional*):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This
+                option should only be set to `True` for repositories you trust and in which you have read the code,
+                as it will execute code present on the Hub on your local machine.
+            **kwargs:
+                Additional keyword arguments for remote code loading.
+
+        Raises:
+            OSError: If `pretrained_model_name_or_path` does not contain a `custom_generate` subdirectory.
+
+        Returns:
+            A callable that can be used to generate text.
+        """
+        # Does `pretrained_model_name_or_path` have a `custom_generate` subdirectory? If not -> OSError
+        is_local_code = os.path.exists(pretrained_model_name_or_path)
+        has_custom_generate_folder = True
+        if is_local_code:
+            if not os.path.exists(os.path.join(pretrained_model_name_or_path, "custom_generate/generate.py")):
+                has_custom_generate_folder = False
+        else:
+            if not file_exists(pretrained_model_name_or_path, "custom_generate/generate.py"):
+                has_custom_generate_folder = False
+
+        if not has_custom_generate_folder:
+            raise OSError(
+                f"`{pretrained_model_name_or_path}` does not contain a `custom_generate` subdirectory with a "
+                "`generate.py` file, can't load the custom generate function."
+            )
+
+        # Handle opt-in `trust_remote_code` and related exceptions
+        error_message = (
+            f"The repository `{pretrained_model_name_or_path}` contains custom generation code that will override "
+            "the default `generate` method."
+        )
+        resolve_trust_remote_code(
+            trust_remote_code,
+            pretrained_model_name_or_path,
+            has_local_code=is_local_code,
+            has_remote_code=not is_local_code,
+            error_message=error_message,
+        )
+
+        # Load the custom generate function
+        check_python_requirements(
+            pretrained_model_name_or_path, requirements_file="custom_generate/requirements.txt", **kwargs
+        )
+        module = get_cached_module_file(
+            pretrained_model_name_or_path, module_file="custom_generate/generate.py", **kwargs
+        )
+        custom_generate_function = get_class_in_module("generate", module)
+        return custom_generate_function
+
     def _cache_dependant_input_preparation(
         self,
         input_ids: torch.LongTensor,
```
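A usage sketch of the new method (the custom-generate repo id is hypothetical; any generation-capable model works):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Returns the `generate` function defined in the repo's custom_generate/generate.py.
custom_generate_fn = model.load_custom_generate(
    "someuser/custom-generate-repo", trust_remote_code=True
)
# The returned callable expects `model` plus the usual `generate` arguments,
# e.g. custom_generate_fn(model=model, input_ids=...).
```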
```diff
@@ -2158,6 +2231,7 @@
         negative_prompt_ids: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         use_model_defaults: Optional[bool] = None,
+        custom_generate: Optional[str] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
```
```diff
@@ -2227,6 +2301,11 @@
             generation configuration (`model.generation_config`), as opposed to the global defaults
             (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
             `True`.
+        custom_generate (`str`, *optional*):
+            A string containing the name of a huggingface.co repository. If provided, the custom `generate`
+            function defined in that repository's `custom_generate/generate.py` file will be executed instead of
+            the standard `generate` method. Note that the generation logic is entirely defined in that
+            repository, and the return type may be different from the standard `generate` method.
         kwargs (`Dict[str, Any]`, *optional*):
             Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
             forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
```
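End to end, the new argument is used like this (the repo id is hypothetical):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello,", return_tensors="pt")

# The repo's custom_generate/generate.py runs instead of the standard method.
output = model.generate(
    **inputs, custom_generate="someuser/custom-generate-repo", trust_remote_code=True
)
```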
```diff
@@ -2248,6 +2327,20 @@
             - [`~generation.GenerateEncoderDecoderOutput`],
             - [`~generation.GenerateBeamEncoderDecoderOutput`]
         """
+        # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
+        if custom_generate is not None:
+            trust_remote_code = kwargs.pop("trust_remote_code", None)
+            # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
+            # they receive the same inputs as `generate`, only with `model` instead of `self`. They can access
+            # methods from `GenerationMixin` through `model`.
+            global_keys_to_exclude = {"self", "kwargs"}
+            generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
+            generate_arguments.update(kwargs)
```
Comment on lines +2332 to +2339

Collaborator: all of this can be done in `load_custom_generate`, which would return the generate arguments!

Contributor (Author): It needs
```diff
+            custom_generate_function = self.load_custom_generate(
+                custom_generate, trust_remote_code=trust_remote_code, **kwargs
+            )
+            return custom_generate_function(model=self, **generate_arguments)

         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
```
Review comment: We were throwing an exception and a warning 🤗. Lowered the severity, otherwise we were throwing warnings in most model `from_pretrained` calls (see changes in `modeling_utils.py`).
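For completeness, a minimal sketch of what a Hub recipe's `custom_generate/generate.py` could look like, per the contract above (the entry point must be named `generate` and receives `model` instead of `self`; this toy greedy loop is illustrative only, not code from the PR):

```python
# custom_generate/generate.py (hypothetical recipe file)
import torch


def generate(model, input_ids, generation_config=None, **kwargs):
    # Receives everything `model.generate` received, with `model` instead of `self`.
    generation_config = generation_config if generation_config is not None else model.generation_config
    max_new_tokens = generation_config.max_new_tokens or 20

    cur_ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(input_ids=cur_ids).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        cur_ids = torch.cat([cur_ids, next_token], dim=-1)
    return cur_ids
```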