mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-02 04:10:06 +06:00
Add kwargs for timm.create_model in TimmWrapper (#38860)
* Add init kwargs for timm wrapper
* model_init_kwargs -> model_args
* add save-load test
* fixup
This commit is contained in:
parent
ff95974bc6
commit
9120567b02
@ -15,7 +15,7 @@
|
||||
|
||||
"""Configuration for TimmWrapper models"""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...utils import is_timm_available, logging, requires_backends
|
||||
@ -45,6 +45,9 @@ class TimmWrapperConfig(PretrainedConfig):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
do_pooling (`bool`, *optional*, defaults to `True`):
|
||||
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
|
||||
model_args (`dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}`
|
||||
for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@ -60,9 +63,16 @@ class TimmWrapperConfig(PretrainedConfig):
|
||||
|
||||
model_type = "timm_wrapper"
|
||||
|
||||
def __init__(
    self,
    initializer_range: float = 0.02,
    do_pooling: bool = True,
    model_args: Optional[dict[str, Any]] = None,
    **kwargs,
):
    """Initialize a `TimmWrapperConfig`.

    Args:
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal initializer for weight matrices.
        do_pooling (`bool`, *optional*, defaults to `True`):
            Whether `TimmWrapperModel` pools the `last_hidden_state`.
        model_args (`dict[str, Any]`, *optional*):
            Extra keyword arguments forwarded verbatim to `timm.create_model`,
            e.g. `{"depth": 3}`. `None` means no extra arguments.
        **kwargs:
            Passed through to the parent `PretrainedConfig` constructor.
    """
    self.initializer_range = initializer_range
    self.do_pooling = do_pooling
    # Keep the attribute name "model_args" for backward compatibility with timm.
    self.model_args = model_args
    super().__init__(**kwargs)
|
||||
|
||||
@classmethod
|
||||
|
@ -116,7 +116,8 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
|
||||
def __init__(self, config: TimmWrapperConfig):
    """Instantiate the timm backbone described by `config`.

    Any user-supplied `config.model_args` are forwarded to `timm.create_model`;
    `num_classes=0` requests a headless model, so no classification head is built.
    """
    super().__init__(config)
    extra_init_kwargs = config.model_args or {}
    # num_classes=0 avoids creating a classification head.
    self.timm_model = timm.create_model(
        config.architecture,
        pretrained=False,
        num_classes=0,
        **extra_init_kwargs,
    )
    self.post_init()
|
||||
|
||||
@auto_docstring
|
||||
@ -233,7 +234,10 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
|
||||
"or use `TimmWrapperModel` for feature extraction."
|
||||
)
|
||||
|
||||
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
|
||||
extra_init_kwargs = config.model_args or {}
|
||||
self.timm_model = timm.create_model(
|
||||
config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs
|
||||
)
|
||||
self.num_labels = config.num_labels
|
||||
self.post_init()
|
||||
|
||||
|
@ -237,6 +237,24 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
self.assertEqual(config.id2label, restored_config.id2label)
|
||||
self.assertEqual(config.label2id, restored_config.label2id)
|
||||
|
||||
def test_model_init_args(self):
    """`model_args` must reach `timm.create_model` and survive save/load."""
    # test init from config: depth=3 should produce a 3-block ViT
    # (the checkpoint's default depth is larger — presumably 12; TODO confirm).
    config = TimmWrapperConfig.from_pretrained(
        "timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
        model_args={"depth": 3},
    )
    model = TimmWrapperModel(config)
    self.assertEqual(len(model.timm_model.blocks), 3)

    # The same config must drive the classification wrapper identically.
    cls_model = TimmWrapperForImageClassification(config)
    self.assertEqual(len(cls_model.timm_model.blocks), 3)

    # test save load: model_args must persist in the serialized config
    # so the restored model is rebuilt with 3 blocks as well.
    with tempfile.TemporaryDirectory() as tmpdirname:
        model.save_pretrained(tmpdirname)
        restored_model = TimmWrapperModel.from_pretrained(tmpdirname)
        self.assertEqual(len(restored_model.timm_model.blocks), 3)
|
||||
|
||||
|
||||
# We will verify our results on an image of cute cats
|
||||
def prepare_img():
|
||||
|
Loading…
Reference in New Issue
Block a user