Skip to content
27 changes: 16 additions & 11 deletions tests/utils/test_import_utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import importlib
import sys

from transformers.utils.import_utils import clear_import_cache
from transformers.utils.import_utils import _LazyModule, clear_import_cache


def test_clear_import_cache():
# Import some transformers modules

# Get initial module count
# Save initial state
initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}

# Verify we have some modules loaded
assert len(initial_modules) > 0

# Clear cache
# Run the test
clear_import_cache()

# Check modules were removed
# Verify modules were removed
remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
assert len(remaining_modules) < len(initial_modules)

# Verify we can reimport
assert "transformers" in sys.modules
# Import and verify module exists
from transformers.models.auto import modeling_auto

assert modeling_auto.__name__ == "transformers.models.auto.modeling_auto"

# Restore initial state
for name, module in initial_modules.items():
sys.modules[name] = module
if isinstance(module, _LazyModule):
# Re-initialize lazy module cache
importlib.reload(module)