* Reload transformers fix form cache
* add imports
* add test fn for clearing import cache
* ruff fix to core import logic
* ruff fix to test file
* fixup for imports
* fixup for test
* lru restore
* test check
* fix style changes
* added documentation for usecase
* fixing

---------

Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
parent 1cc7ca3295
commit d6897b46bd
@@ -24,7 +24,37 @@ You'll learn how to:

- Modify a model's architecture by changing its attention mechanism.
- Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.

-We encourage you to contribute your own hacks and share them here with the community1
+We encourage you to contribute your own hacks and share them here with the community!

## Efficient Development Workflow

When modifying model code, you'll often need to test your changes without restarting your Python session. The `clear_import_cache()` utility helps with this workflow, especially during model development and contribution when you need to frequently test and compare model outputs:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

# Make modifications to the transformers code...

# Clear the cache to reload the modified code
from transformers.utils.import_utils import clear_import_cache

clear_import_cache()

# Reimport to get the changes
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # Will use updated code
```

This is particularly useful when:
- Iteratively modifying model architectures
- Debugging model implementations
- Testing changes during model development
- Comparing outputs between original and modified versions
- Working on model contributions

The `clear_import_cache()` function removes all cached Transformers modules and allows Python to reload the modified code. This enables rapid development cycles without constantly restarting your environment.

This workflow is especially valuable when implementing new models, where you need to frequently compare outputs between the original implementation and your Transformers version (as described in the [Add New Model](https://huggingface.co/docs/transformers/add_new_model) guide).
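
As an illustration of that compare-outputs loop, here is a minimal sketch (not part of the committed guide; the checkpoint, input text, and `run_model` helper are placeholders):

```python
import torch
from transformers import AutoModel, AutoTokenizer


def run_model(text="a short smoke-test sentence"):
    # Placeholder checkpoint; any small model is enough for a quick comparison
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased").eval()
    with torch.no_grad():
        return model(**tokenizer(text, return_tensors="pt")).last_hidden_state


reference = run_model()  # output produced by the unmodified code

# ...edit the modeling source, then clear the cache and rerun...
from transformers.utils.import_utils import clear_import_cache

clear_import_cache()

from transformers import AutoModel, AutoTokenizer  # rebind to the reloaded modules

modified = run_model()  # output produced by the edited code
print("max abs diff:", (reference - modified).abs().max().item())
```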

## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)

src/transformers/utils/import_utils.py
@@ -2271,3 +2271,28 @@ def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T:
    """
    import_structure = create_import_structure_from_path(module_path)
    return spread_import_structure(import_structure)


def clear_import_cache():
    """
    Clear cached Transformers modules to allow reloading modified code.

    This is useful when actively developing/modifying Transformers code.
    """
    # Get all transformers modules
    transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]

    # Remove them from sys.modules
    for mod_name in transformers_modules:
        module = sys.modules[mod_name]
        # Clear _LazyModule caches if applicable
        if isinstance(module, _LazyModule):
            module._objects = {}  # Clear cached objects
        del sys.modules[mod_name]

    # Force reload main transformers module
    if "transformers" in sys.modules:
        main_module = sys.modules["transformers"]
        if isinstance(main_module, _LazyModule):
            main_module._objects = {}  # Clear cached objects
        importlib.reload(main_module)
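
One behavior worth noting about `clear_import_cache()` (a hedged sketch, not part of this commit): clearing the cache does not rebind names already imported into your namespace, so a fresh `from transformers import ...` is needed before edited code takes effect. `AutoConfig` below is just an example class.

```python
from transformers import AutoConfig
from transformers.utils.import_utils import clear_import_cache

before = AutoConfig  # still bound to the class object loaded from the original code

clear_import_cache()

# The old binding is untouched by the cache clear; only a new import resolves
# against the freshly reloaded modules.
from transformers import AutoConfig

print(before is AutoConfig)  # typically False once the modules have been reloaded
```
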
tests/utils/test_import_utils.py (new file)

import sys

from transformers.utils.import_utils import clear_import_cache


def test_clear_import_cache():
    # Import some transformers modules
    from transformers import AutoModel  # noqa: F401

    # Get initial module count
    initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}

    # Verify we have some modules loaded
    assert len(initial_modules) > 0

    # Clear cache
    clear_import_cache()

    # Check modules were removed
    remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
    assert len(remaining_modules) < len(initial_modules)

    # Verify we can reimport
    from transformers.models.auto import modeling_auto  # noqa: F401

    assert "transformers.models.auto.modeling_auto" in sys.modules