mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
* Reload transformers fix form cache * add imports * add test fn for clearing import cache * ruff fix to core import logic * ruff fix to test file * fixup for imports * fixup for test * lru restore * test check * fix style changes * added documentation for usecase * fixing --------- Co-authored-by: sambhavnoobcoder <indosambahv@gmail.com>
This commit is contained in:
parent
1cc7ca3295
commit
d6897b46bd
@ -24,7 +24,37 @@ You'll learn how to:
|
|||||||
- Modify a model's architecture by changing its attention mechanism.
|
- Modify a model's architecture by changing its attention mechanism.
|
||||||
- Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.
|
- Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.
|
||||||
|
|
||||||
We encourage you to contribute your own hacks and share them here with the community1
|
We encourage you to contribute your own hacks and share them here with the community!
|
||||||
|
|
||||||
|
## Efficient Development Workflow
|
||||||
|
|
||||||
|
When modifying model code, you'll often need to test your changes without restarting your Python session. The `clear_import_cache()` utility helps with this workflow, especially during model development and contribution when you need to frequently test and compare model outputs:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import AutoModel
|
||||||
|
model = AutoModel.from_pretrained("bert-base-uncased")
|
||||||
|
|
||||||
|
# Make modifications to the transformers code...
|
||||||
|
|
||||||
|
# Clear the cache to reload the modified code
|
||||||
|
from transformers.utils.import_utils import clear_import_cache
|
||||||
|
clear_import_cache()
|
||||||
|
|
||||||
|
# Reimport to get the changes
|
||||||
|
from transformers import AutoModel
|
||||||
|
model = AutoModel.from_pretrained("bert-base-uncased") # Will use updated code
|
||||||
|
```
|
||||||
|
|
||||||
|
This is particularly useful when:
|
||||||
|
- Iteratively modifying model architectures
|
||||||
|
- Debugging model implementations
|
||||||
|
- Testing changes during model development
|
||||||
|
- Comparing outputs between original and modified versions
|
||||||
|
- Working on model contributions
|
||||||
|
|
||||||
|
The `clear_import_cache()` function removes all cached Transformers modules and allows Python to reload the modified code. This enables rapid development cycles without constantly restarting your environment.
|
||||||
|
|
||||||
|
This workflow is especially valuable when implementing new models, where you need to frequently compare outputs between the original implementation and your Transformers version (as described in the [Add New Model](https://huggingface.co/docs/transformers/add_new_model) guide).
|
||||||
|
|
||||||
## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)
|
## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)
|
||||||
|
|
||||||
|
@ -2271,3 +2271,28 @@ def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T:
|
|||||||
"""
|
"""
|
||||||
import_structure = create_import_structure_from_path(module_path)
|
import_structure = create_import_structure_from_path(module_path)
|
||||||
return spread_import_structure(import_structure)
|
return spread_import_structure(import_structure)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_import_cache():
|
||||||
|
"""
|
||||||
|
Clear cached Transformers modules to allow reloading modified code.
|
||||||
|
|
||||||
|
This is useful when actively developing/modifying Transformers code.
|
||||||
|
"""
|
||||||
|
# Get all transformers modules
|
||||||
|
transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")]
|
||||||
|
|
||||||
|
# Remove them from sys.modules
|
||||||
|
for mod_name in transformers_modules:
|
||||||
|
module = sys.modules[mod_name]
|
||||||
|
# Clear _LazyModule caches if applicable
|
||||||
|
if isinstance(module, _LazyModule):
|
||||||
|
module._objects = {} # Clear cached objects
|
||||||
|
del sys.modules[mod_name]
|
||||||
|
|
||||||
|
# Force reload main transformers module
|
||||||
|
if "transformers" in sys.modules:
|
||||||
|
main_module = sys.modules["transformers"]
|
||||||
|
if isinstance(main_module, _LazyModule):
|
||||||
|
main_module._objects = {} # Clear cached objects
|
||||||
|
importlib.reload(main_module)
|
||||||
|
23
tests/utils/test_import_utils.py
Normal file
23
tests/utils/test_import_utils.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
from transformers.utils.import_utils import clear_import_cache
|
||||||
|
|
||||||
|
|
||||||
|
def test_clear_import_cache():
|
||||||
|
# Import some transformers modules
|
||||||
|
|
||||||
|
# Get initial module count
|
||||||
|
initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
|
||||||
|
|
||||||
|
# Verify we have some modules loaded
|
||||||
|
assert len(initial_modules) > 0
|
||||||
|
|
||||||
|
# Clear cache
|
||||||
|
clear_import_cache()
|
||||||
|
|
||||||
|
# Check modules were removed
|
||||||
|
remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")}
|
||||||
|
assert len(remaining_modules) < len(initial_modules)
|
||||||
|
|
||||||
|
# Verify we can reimport
|
||||||
|
assert "transformers.models.auto.modeling_auto" in sys.modules
|
Loading…
Reference in New Issue
Block a user