From 9120567b02a551d198337e21bee8c1465f389ab2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Fri, 20 Jun 2025 13:00:09 +0100 Subject: [PATCH] Add kwargs for timm.create_model in TimmWrapper (#38860) * Add init kwargs for timm wrapper * model_init_kwargs -> model_args * add save-load test * fixup --- .../timm_wrapper/configuration_timm_wrapper.py | 14 ++++++++++++-- .../timm_wrapper/modeling_timm_wrapper.py | 8 ++++++-- .../timm_wrapper/test_modeling_timm_wrapper.py | 18 ++++++++++++++++++ 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py index 3d542de6aa7..39ed2098d68 100644 --- a/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/configuration_timm_wrapper.py @@ -15,7 +15,7 @@ """Configuration for TimmWrapper models""" -from typing import Any +from typing import Any, Optional from ...configuration_utils import PretrainedConfig from ...utils import is_timm_available, logging, requires_backends @@ -45,6 +45,9 @@ class TimmWrapperConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. do_pooling (`bool`, *optional*, defaults to `True`): Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not. + model_args (`dict[str, Any]`, *optional*): + Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}` + for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`. Example: ```python @@ -60,9 +63,16 @@ class TimmWrapperConfig(PretrainedConfig): model_type = "timm_wrapper" - def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs): + def __init__( + self, + initializer_range: float = 0.02, + do_pooling: bool = True, + model_args: Optional[dict[str, Any]] = None, + **kwargs, + ): self.initializer_range = initializer_range self.do_pooling = do_pooling + self.model_args = model_args # named "model_args" for BC with timm super().__init__(**kwargs) @classmethod diff --git a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py index 57c96aef27b..7c2fe021232 100644 --- a/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py +++ b/src/transformers/models/timm_wrapper/modeling_timm_wrapper.py @@ -116,7 +116,8 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel): def __init__(self, config: TimmWrapperConfig): super().__init__(config) # using num_classes=0 to avoid creating classification head - self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0) + extra_init_kwargs = config.model_args or {} + self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **extra_init_kwargs) self.post_init() @auto_docstring @@ -233,7 +234,10 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel): "or use `TimmWrapperModel` for feature extraction." ) - self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels) + extra_init_kwargs = config.model_args or {} + self.timm_model = timm.create_model( + config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs + ) self.num_labels = config.num_labels self.post_init() diff --git a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py index 821ade8dbcb..f7f374ed574 100644 --- a/tests/models/timm_wrapper/test_modeling_timm_wrapper.py +++ b/tests/models/timm_wrapper/test_modeling_timm_wrapper.py @@ -237,6 +237,24 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC self.assertEqual(config.id2label, restored_config.id2label) self.assertEqual(config.label2id, restored_config.label2id) + def test_model_init_args(self): + # test init from config + config = TimmWrapperConfig.from_pretrained( + "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k", + model_args={"depth": 3}, + ) + model = TimmWrapperModel(config) + self.assertEqual(len(model.timm_model.blocks), 3) + + cls_model = TimmWrapperForImageClassification(config) + self.assertEqual(len(cls_model.timm_model.blocks), 3) + + # test save load + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + restored_model = TimmWrapperModel.from_pretrained(tmpdirname) + self.assertEqual(len(restored_model.timm_model.blocks), 3) + # We will verify our results on an image of cute cats def prepare_img():