Add kwargs for timm.create_model in TimmWrapper (#38860)

* Add init kwargs for timm wrapper

* model_init_kwargs -> model_args

* add save-load test

* fixup
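
A minimal usage sketch of the new `model_args` option, mirroring the test added in this commit (the checkpoint name and `depth=3` are taken from that test; assumes `timm` is installed and the `TimmWrapper*` classes are exported from `transformers`):

```python
from transformers import TimmWrapperConfig, TimmWrapperModel

# Extra kwargs in `model_args` are forwarded verbatim to `timm.create_model`,
# e.g. `depth=3` builds a ViT backbone with only 3 transformer blocks.
config = TimmWrapperConfig.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
model = TimmWrapperModel(config)  # randomly initialized, 3 blocks
print(len(model.timm_model.blocks))  # 3
```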
Pavel Iakubovskii 2025-06-20 13:00:09 +01:00 committed by GitHub
parent ff95974bc6
commit 9120567b02
3 changed files with 36 additions and 4 deletions


@@ -15,7 +15,7 @@
 """Configuration for TimmWrapper models"""

-from typing import Any
+from typing import Any, Optional

 from ...configuration_utils import PretrainedConfig
 from ...utils import is_timm_available, logging, requires_backends
@@ -45,6 +45,9 @@ class TimmWrapperConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         do_pooling (`bool`, *optional*, defaults to `True`):
             Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
+        model_args (`dict[str, Any]`, *optional*):
+            Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}`
+            for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`.

    Example:
    ```python
@@ -60,9 +63,16 @@ class TimmWrapperConfig(PretrainedConfig):
     model_type = "timm_wrapper"

-    def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs):
+    def __init__(
+        self,
+        initializer_range: float = 0.02,
+        do_pooling: bool = True,
+        model_args: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ):
         self.initializer_range = initializer_range
         self.do_pooling = do_pooling
+        self.model_args = model_args  # named "model_args" for BC with timm
         super().__init__(**kwargs)

     @classmethod
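
Because `model_args` is stored as a plain attribute on the config, it should also survive a config save/load round trip (the model-level test added below checks the same behaviour); a small sketch under that assumption:

```python
import tempfile

from transformers import TimmWrapperConfig

config = TimmWrapperConfig.from_pretrained(
    "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    model_args={"depth": 3},
)
with tempfile.TemporaryDirectory() as tmpdir:
    config.save_pretrained(tmpdir)  # model_args is written into config.json
    restored = TimmWrapperConfig.from_pretrained(tmpdir)
    assert restored.model_args == {"depth": 3}
```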


@@ -116,7 +116,8 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
     def __init__(self, config: TimmWrapperConfig):
         super().__init__(config)
         # using num_classes=0 to avoid creating classification head
-        self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
+        extra_init_kwargs = config.model_args or {}
+        self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **extra_init_kwargs)
         self.post_init()

     @auto_docstring
@@ -233,7 +234,10 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
                 "or use `TimmWrapperModel` for feature extraction."
             )

-        self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
+        extra_init_kwargs = config.model_args or {}
+        self.timm_model = timm.create_model(
+            config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs
+        )
         self.num_labels = config.num_labels
         self.post_init()
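
For reference, the forwarding above should be roughly equivalent to calling `timm` directly with the extra kwargs; a sketch assuming `config.model_args = {"depth": 3}` and the architecture from the test checkpoint:

```python
import timm

# What TimmWrapperModel.__init__ effectively does when model_args={"depth": 3}:
backbone = timm.create_model(
    "vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
    pretrained=False,
    num_classes=0,  # no classification head, as in TimmWrapperModel
    depth=3,        # extra kwarg forwarded from config.model_args
)
print(len(backbone.blocks))  # 3
```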


@@ -237,6 +237,24 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         self.assertEqual(config.id2label, restored_config.id2label)
         self.assertEqual(config.label2id, restored_config.label2id)

+    def test_model_init_args(self):
+        # test init from config
+        config = TimmWrapperConfig.from_pretrained(
+            "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
+            model_args={"depth": 3},
+        )
+        model = TimmWrapperModel(config)
+        self.assertEqual(len(model.timm_model.blocks), 3)
+
+        cls_model = TimmWrapperForImageClassification(config)
+        self.assertEqual(len(cls_model.timm_model.blocks), 3)
+
+        # test save load
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_pretrained(tmpdirname)
+            restored_model = TimmWrapperModel.from_pretrained(tmpdirname)
+            self.assertEqual(len(restored_model.timm_model.blocks), 3)
+

 # We will verify our results on an image of cute cats
 def prepare_img():