Skip to content

Commit d6897b4

Browse files
Add utility for Reload Transformers imports cache for development workflow #35508 (#35858)
* Reload transformers fix form cache * add imports * add test fn for clearing import cache * ruff fix to core import logic * ruff fix to test file * fixup for imports * fixup for test * lru restore * test check * fix style changes * added documentation for usecase * fixing --------- Co-authored-by: sambhavnoobcoder <[email protected]>
1 parent 1cc7ca3 commit d6897b4

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

docs/source/en/how_to_hack_models.md

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,37 @@ You'll learn how to:
2424
- Modify a model's architecture by changing its attention mechanism.
2525
- Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.
2626

27-
We encourage you to contribute your own hacks and share them here with the community1
27+
We encourage you to contribute your own hacks and share them here with the community!
28+
29+
## Efficient Development Workflow
30+
31+
When modifying model code, you'll often need to test your changes without restarting your Python session. The `clear_import_cache()` utility helps with this workflow, especially during model development and contribution when you need to frequently test and compare model outputs:
32+
33+
```python
34+
from transformers import AutoModel
35+
model = AutoModel.from_pretrained("bert-base-uncased")
36+
37+
# Make modifications to the transformers code...
38+
39+
# Clear the cache to reload the modified code
40+
from transformers.utils.import_utils import clear_import_cache
41+
clear_import_cache()
42+
43+
# Reimport to get the changes
44+
from transformers import AutoModel
45+
model = AutoModel.from_pretrained("bert-base-uncased") # Will use updated code
46+
```
47+
48+
This is particularly useful when:
49+
- Iteratively modifying model architectures
50+
- Debugging model implementations
51+
- Testing changes during model development
52+
- Comparing outputs between original and modified versions
53+
- Working on model contributions
54+
55+
The `clear_import_cache()` function removes all cached Transformers modules and allows Python to reload the modified code. This enables rapid development cycles without constantly restarting your environment.
56+
57+
This workflow is especially valuable when implementing new models, where you need to frequently compare outputs between the original implementation and your Transformers version (as described in the [Add New Model](https://huggingface.co/docs/transformers/add_new_model) guide).
2858

2959
## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)
3060

src/transformers/utils/import_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,3 +2271,28 @@ def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T:
22712271
"""
22722272
import_structure = create_import_structure_from_path(module_path)
22732273
return spread_import_structure(import_structure)
2274+
2275+
2276+
def clear_import_cache():
2277+
"""
2278+
Clear cached Transformers modules to allow reloading modified code.
2279+
2280+
This is useful when actively developing/modifying Transformers code.
2281+
"""
2282+
# Get all transformers modules
2283+
transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]
2284+
2285+
# Remove them from sys.modules
2286+
for mod_name in transformers_modules:
2287+
module = sys.modules[mod_name]
2288+
# Clear _LazyModule caches if applicable
2289+
if isinstance(module, _LazyModule):
2290+
module._objects = {} # Clear cached objects
2291+
del sys.modules[mod_name]
2292+
2293+
# Force reload main transformers module
2294+
if "transformers" in sys.modules:
2295+
main_module = sys.modules["transformers"]
2296+
if isinstance(main_module, _LazyModule):
2297+
main_module._objects = {} # Clear cached objects
2298+
importlib.reload(main_module)

tests/utils/test_import_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import sys
2+
3+
from transformers.utils.import_utils import clear_import_cache
4+
5+
6+
def test_clear_import_cache():
7+
# Import some transformers modules
8+
9+
# Get initial module count
10+
initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
11+
12+
# Verify we have some modules loaded
13+
assert len(initial_modules) > 0
14+
15+
# Clear cache
16+
clear_import_cache()
17+
18+
# Check modules were removed
19+
remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
20+
assert len(remaining_modules) < len(initial_modules)
21+
22+
# Verify we can reimport
23+
assert "transformers.models.auto.modeling_auto" in sys.modules

0 commit comments

Comments
 (0)