mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-02 12:20:05 +06:00
Add kwargs for timm.create_model in TimmWrapper (#38860)
* Add init kwargs for timm wrapper * model_init_kwargs -> model_args * add save-load test * fixup
This commit is contained in:
parent
ff95974bc6
commit
9120567b02
@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
"""Configuration for TimmWrapper models"""
|
"""Configuration for TimmWrapper models"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import is_timm_available, logging, requires_backends
|
from ...utils import is_timm_available, logging, requires_backends
|
||||||
@ -45,6 +45,9 @@ class TimmWrapperConfig(PretrainedConfig):
|
|||||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
do_pooling (`bool`, *optional*, defaults to `True`):
|
do_pooling (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
|
Whether to do pooling for the last_hidden_state in `TimmWrapperModel` or not.
|
||||||
|
model_args (`dict[str, Any]`, *optional*):
|
||||||
|
Additional keyword arguments to pass to the `timm.create_model` function. e.g. `model_args={"depth": 3}`
|
||||||
|
for `timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k` to create a model with 3 blocks. Defaults to `None`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
```python
|
```python
|
||||||
@ -60,9 +63,16 @@ class TimmWrapperConfig(PretrainedConfig):
|
|||||||
|
|
||||||
model_type = "timm_wrapper"
|
model_type = "timm_wrapper"
|
||||||
|
|
||||||
def __init__(self, initializer_range: float = 0.02, do_pooling: bool = True, **kwargs):
|
def __init__(
|
||||||
|
self,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
do_pooling: bool = True,
|
||||||
|
model_args: Optional[dict[str, Any]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.do_pooling = do_pooling
|
self.do_pooling = do_pooling
|
||||||
|
self.model_args = model_args # named "model_args" for BC with timm
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -116,7 +116,8 @@ class TimmWrapperModel(TimmWrapperPreTrainedModel):
|
|||||||
def __init__(self, config: TimmWrapperConfig):
|
def __init__(self, config: TimmWrapperConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
# using num_classes=0 to avoid creating classification head
|
# using num_classes=0 to avoid creating classification head
|
||||||
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0)
|
extra_init_kwargs = config.model_args or {}
|
||||||
|
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=0, **extra_init_kwargs)
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
@ -233,7 +234,10 @@ class TimmWrapperForImageClassification(TimmWrapperPreTrainedModel):
|
|||||||
"or use `TimmWrapperModel` for feature extraction."
|
"or use `TimmWrapperModel` for feature extraction."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.timm_model = timm.create_model(config.architecture, pretrained=False, num_classes=config.num_labels)
|
extra_init_kwargs = config.model_args or {}
|
||||||
|
self.timm_model = timm.create_model(
|
||||||
|
config.architecture, pretrained=False, num_classes=config.num_labels, **extra_init_kwargs
|
||||||
|
)
|
||||||
self.num_labels = config.num_labels
|
self.num_labels = config.num_labels
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
@ -237,6 +237,24 @@ class TimmWrapperModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
self.assertEqual(config.id2label, restored_config.id2label)
|
self.assertEqual(config.id2label, restored_config.id2label)
|
||||||
self.assertEqual(config.label2id, restored_config.label2id)
|
self.assertEqual(config.label2id, restored_config.label2id)
|
||||||
|
|
||||||
|
def test_model_init_args(self):
|
||||||
|
# test init from config
|
||||||
|
config = TimmWrapperConfig.from_pretrained(
|
||||||
|
"timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k",
|
||||||
|
model_args={"depth": 3},
|
||||||
|
)
|
||||||
|
model = TimmWrapperModel(config)
|
||||||
|
self.assertEqual(len(model.timm_model.blocks), 3)
|
||||||
|
|
||||||
|
cls_model = TimmWrapperForImageClassification(config)
|
||||||
|
self.assertEqual(len(cls_model.timm_model.blocks), 3)
|
||||||
|
|
||||||
|
# test save load
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
restored_model = TimmWrapperModel.from_pretrained(tmpdirname)
|
||||||
|
self.assertEqual(len(restored_model.timm_model.blocks), 3)
|
||||||
|
|
||||||
|
|
||||||
# We will verify our results on an image of cute cats
|
# We will verify our results on an image of cute cats
|
||||||
def prepare_img():
|
def prepare_img():
|
||||||
|
Loading…
Reference in New Issue
Block a user