diff --git a/docs/source/en/how_to_hack_models.md b/docs/source/en/how_to_hack_models.md
index 411539e104b..5e2aa8297bc 100644
--- a/docs/source/en/how_to_hack_models.md
+++ b/docs/source/en/how_to_hack_models.md
@@ -24,7 +24,37 @@ You'll learn how to:
 
 - Modify a model's architecture by changing its attention mechanism.
 - Apply techniques like Low-Rank Adaptation (LoRA) to specific model components.
 
-We encourage you to contribute your own hacks and share them here with the community1
+We encourage you to contribute your own hacks and share them here with the community!
+
+## Efficient Development Workflow
+
+When modifying model code, you'll often need to test your changes without restarting your Python session. The `clear_import_cache()` utility helps with this workflow, especially during model development and contributions, when you need to repeatedly test and compare model outputs:
+
+```python
+from transformers import AutoModel
+model = AutoModel.from_pretrained("bert-base-uncased")
+
+# Make modifications to the transformers code...
+
+# Clear the cache to reload the modified code
+from transformers.utils.import_utils import clear_import_cache
+clear_import_cache()
+
+# Reimport to get the changes
+from transformers import AutoModel
+model = AutoModel.from_pretrained("bert-base-uncased")  # Will use updated code
+```
+
+This is particularly useful when:
+- Iteratively modifying model architectures
+- Debugging model implementations
+- Testing changes during model development
+- Comparing outputs between original and modified versions
+- Working on model contributions
+
+The `clear_import_cache()` function removes all cached Transformers modules and allows Python to reload the modified code. This enables rapid development cycles without constantly restarting your environment.
+
+This workflow is especially valuable when implementing new models, where you need to frequently compare outputs between the original implementation and your Transformers version (as described in the [Add New Model](https://huggingface.co/docs/transformers/add_new_model) guide).
 
 ## Example: Modifying the Attention Mechanism in the Segment Anything Model (SAM)
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 5a6dd937519..41065a5d11c 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -2271,3 +2271,37 @@ def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T:
     """
     import_structure = create_import_structure_from_path(module_path)
     return spread_import_structure(import_structure)
+
+
+def clear_import_cache():
+    """
+    Clear cached Transformers modules to allow reloading modified code.
+
+    This is useful when actively developing or modifying Transformers code.
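+
+    Example (a minimal sketch of the intended workflow):
+
+    ```python
+    from transformers.utils.import_utils import clear_import_cache
+
+    # After editing Transformers source files, drop the stale modules
+    clear_import_cache()
+    ```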
+ """ + # Get all transformers modules + transformers_modules = [mod_name for mod_name in sys.modules if mod_name.startswith("transformers.")] + + # Remove them from sys.modules + for mod_name in transformers_modules: + module = sys.modules[mod_name] + # Clear _LazyModule caches if applicable + if isinstance(module, _LazyModule): + module._objects = {} # Clear cached objects + del sys.modules[mod_name] + + # Force reload main transformers module + if "transformers" in sys.modules: + main_module = sys.modules["transformers"] + if isinstance(main_module, _LazyModule): + main_module._objects = {} # Clear cached objects + importlib.reload(main_module) diff --git a/tests/utils/test_import_utils.py b/tests/utils/test_import_utils.py new file mode 100644 index 00000000000..3d846174aca --- /dev/null +++ b/tests/utils/test_import_utils.py @@ -0,0 +1,23 @@ +import sys + +from transformers.utils.import_utils import clear_import_cache + + +def test_clear_import_cache(): + # Import some transformers modules + + # Get initial module count + initial_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")} + + # Verify we have some modules loaded + assert len(initial_modules) > 0 + + # Clear cache + clear_import_cache() + + # Check modules were removed + remaining_modules = {name: mod for name, mod in sys.modules.items() if name.startswith("transformers.")} + assert len(remaining_modules) < len(initial_modules) + + # Verify we can reimport + assert "transformers.models.auto.modeling_auto" in sys.modules