Add utility for Reload Transformers imports cache for development workflow #35508 (#35858)

* Reload transformers fix form cache

* add imports

* add test fn for clearing import cache

* ruff fix to core import logic

* ruff fix to test file

* fixup for imports

* fixup for test

* lru restore

* test check

* fix style changes

* added documentation for usecase

* fixing

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
Sambhav Dixit 2025-02-12 17:15:11 +05:30 committed by GitHub
parent 1cc7ca3295
commit d6897b46bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 1 deletions

View File

@ -24,7 +24,37 @@ You'll learn how to:
- Modify a model's architecture by changing its attention mechanism.
- Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.
We encourage you to contribute your own hacks and share them here with the community1
We encourage you to contribute your own hacks and share them here with the community!
## Efficient Development Workflow
When modifying model code, you'll often need to test your changes without restarting your Python session. The `clear_import_cache()` utility helps with this workflow, especially during model development and contribution when you need to frequently test and compare model outputs:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-uncased")
# Make modifications to the transformers code...
# Clear the cache to reload the modified code
from transformers.utils.import_utils import clear_import_cache
clear_import_cache()
# Reimport to get the changes
from transformers import AutoModel
model = AutoModel.from_pretrained("bert-base-uncased") # Will use updated code
```
This is particularly useful when:
- Iteratively modifying model architectures
- Debugging model implementations
- Testing changes during model development
- Comparing outputs between original and modified versions
- Working on model contributions
The `clear_import_cache()` function removes all cached Transformers modules and allows Python to reload the modified code. This enables rapid development cycles without constantly restarting your environment.
This workflow is especially valuable when implementing new models, where you need to frequently compare outputs between the original implementation and your Transformers version (as described in the [Add New Model](https://huggingface.co/docs/transformers/add_new_model) guide).
## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)

View File

@ -2271,3 +2271,28 @@ def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T:
"""
import_structure = create_import_structure_from_path(module_path)
return spread_import_structure(import_structure)
def clear_import_cache():
"""
Clear cached Transformers modules to allow reloading modified code.
This is useful when actively developing/modifying Transformers code.
"""
# Get all transformers modules
transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]
# Remove them from sys.modules
for mod_name in transformers_modules:
module = sys.modules[mod_name]
# Clear _LazyModule caches if applicable
if isinstance(module, _LazyModule):
module._objects = {} # Clear cached objects
del sys.modules[mod_name]
# Force reload main transformers module
if "transformers" in sys.modules:
main_module = sys.modules["transformers"]
if isinstance(main_module, _LazyModule):
main_module._objects = {} # Clear cached objects
importlib.reload(main_module)

View File

@ -0,0 +1,23 @@
import sys
from transformers.utils.import_utils import clear_import_cache
def test_clear_import_cache():
# Import some transformers modules
# Get initial module count
initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
# Verify we have some modules loaded
assert len(initial_modules) > 0
# Clear cache
clear_import_cache()
# Check modules were removed
remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
assert len(remaining_modules) < len(initial_modules)
# Verify we can reimport
assert "transformers.models.auto.modeling_auto" in sys.modules