Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 05:10:06 +06:00)
Add TimmBackbone model (#22619)
* Add test_backbone for convnext
* Add TimmBackbone model
* Add check for backbone type
* Tidying up - config checks
* Update convnextv2
* Tidy up
* Fix indices & clearer comment
* Exceptions for config checks
* Correctly update config for tests
* Safer imports
* Safer safer imports
* Fix where decorators go
* Update import logic and backbone tests
* More import fixes
* Fixup
* Only import all_models if torch available
* Fix kwarg updates in from_pretrained & main rebase
* Tidy up
* Add tests for AutoBackbone
* Tidy up
* Fix import error
* Fix up
* Install natten in doc_test_job
* Revert back to setting self._out_xxx directly
* Bug fix - out_indices mapping from out_features
* Fix tests
* Don't accept output_loading_info for Timm models
* Set out_xxx and don't remap
* Use smaller checkpoint for test
* Don't remap timm indices - check out_indices based on stage names
* Skip test as it's n/a
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Cleaner imports / spelling is hard

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent b8935980a2
commit a717e0318c
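The headline change: `AutoBackbone` can now load a backbone either from a timm checkpoint or from a transformers checkpoint, with the same API. A minimal usage sketch (assuming `torch` and `timm` are installed; `resnet18` and `microsoft/resnet-18` are the checkpoints used in this PR's tests, and the stage names shown are illustrative):

```python
from transformers import AutoBackbone

# New in this commit: load a checkpoint straight from timm ...
timm_backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)

# ... alongside the existing transformers backbones.
hf_backbone = AutoBackbone.from_pretrained("microsoft/resnet-18")

print(timm_backbone.stage_names)  # timm stage naming, e.g. ['act', 'layer1', ..., 'layer4']
print(hf_backbone.stage_names)    # transformers naming, e.g. ['stem', 'stage1', ..., 'stage4']
```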
@@ -456,6 +456,7 @@ doc_test_job = CircleCIJob(
        "pip install -e .[dev]",
        "pip install git+https://github.com/huggingface/accelerate",
        "pip install --upgrade pytest pytest-sugar",
        "pip install natten",
        "find -name __pycache__ -delete",
        "find . -name \*.pyc -delete",
        # Add an empty file to keep the test step running correctly even no file is selected to be tested.
@@ -424,6 +424,7 @@ Flax), PyTorch, and/or TensorFlow.
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimeSformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimmBackbone | ❌ | ❌ | ❌ | ❌ | ❌ |
| Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -483,6 +483,7 @@ _import_structure = {
        "TimeSeriesTransformerConfig",
    ],
    "models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
    "models.timm_backbone": ["TimmBackboneConfig"],
    "models.trajectory_transformer": [
        "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "TrajectoryTransformerConfig",
@@ -2578,6 +2579,7 @@ else:
            "TimesformerPreTrainedModel",
        ]
    )
    _import_structure["models.timm_backbone"].extend(["TimmBackbone"])
    _import_structure["models.trajectory_transformer"].extend(
        [
            "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4288,6 +4290,7 @@ if TYPE_CHECKING:
        TimeSeriesTransformerConfig,
    )
    from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
    from .models.timm_backbone import TimmBackboneConfig
    from .models.trajectory_transformer import (
        TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TrajectoryTransformerConfig,
@@ -6024,6 +6027,7 @@ if TYPE_CHECKING:
            TimesformerModel,
            TimesformerPreTrainedModel,
        )
        from .models.timm_backbone import TimmBackbone
        from .models.trajectory_transformer import (
            TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
            TrajectoryTransformerModel,
@@ -186,6 +186,7 @@ from . import (
    tapex,
    time_series_transformer,
    timesformer,
    timm_backbone,
    trajectory_transformer,
    transfo_xl,
    trocr,
@@ -19,7 +19,7 @@ from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...utils import copy_func, logging
from ...utils import copy_func, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings

@@ -515,6 +515,48 @@ class _BaseAutoModelClass:
        cls._model_mapping.register(config_class, model_class)


class _BaseAutoBackboneClass(_BaseAutoModelClass):
    # Base class for auto backbone models.
    _model_mapping = None

    @classmethod
    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        use_timm = kwargs.pop("use_timm_backbone", True)
        if not use_timm:
            raise ValueError("`use_timm_backbone` must be `True` for timm backbones")

        if kwargs.get("out_features", None) is not None:
            raise ValueError("Cannot specify `out_features` for timm backbones")

        if kwargs.get("output_loading_info", False):
            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
        )
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        if kwargs.get("use_timm_backbone", False):
            return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)


def insert_head_doc(docstring, head_doc=""):
    if len(head_doc) > 0:
        return docstring.replace(
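In short, `_BaseAutoBackboneClass.from_pretrained` routes to `_load_timm_backbone_from_pretrained` whenever `use_timm_backbone=True`, builds a fresh `TimmBackboneConfig` from the checkpoint name plus the remaining kwargs, and rejects options that only make sense for transformers checkpoints. A rough illustration of what this means for callers (behavior inferred from the code above and the PR's tests; `resnet18` is the timm checkpoint those tests use):

```python
from transformers import AutoBackbone

# Routed through _load_timm_backbone_from_pretrained; kwargs are folded into a TimmBackboneConfig.
backbone = AutoBackbone.from_pretrained(
    "resnet18",
    use_timm_backbone=True,
    out_indices=(-1, -2),
    use_pretrained_backbone=False,  # skip downloading timm weights, random init
)
print(backbone.config.backbone)  # "resnet18"
print(backbone.out_indices)      # (-1, -2)

# Both of these raise ValueError per the checks above:
# AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])
# AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)
```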
@@ -186,6 +186,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("tapas", "TapasConfig"),
        ("time_series_transformer", "TimeSeriesTransformerConfig"),
        ("timesformer", "TimesformerConfig"),
        ("timm_backbone", "TimmBackboneConfig"),
        ("trajectory_transformer", "TrajectoryTransformerConfig"),
        ("transfo-xl", "TransfoXLConfig"),
        ("trocr", "TrOCRConfig"),
@@ -579,6 +580,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("tapex", "TAPEX"),
        ("time_series_transformer", "Time Series Transformer"),
        ("timesformer", "TimeSformer"),
        ("timm_backbone", "TimmBackbone"),
        ("trajectory_transformer", "Trajectory Transformer"),
        ("transfo-xl", "Transformer-XL"),
        ("trocr", "TrOCR"),
@@ -18,7 +18,7 @@ import warnings
from collections import OrderedDict

from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES


@@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("tapas", "TapasModel"),
        ("time_series_transformer", "TimeSeriesTransformerModel"),
        ("timesformer", "TimesformerModel"),
        ("timm_backbone", "TimmBackbone"),
        ("trajectory_transformer", "TrajectoryTransformerModel"),
        ("transfo-xl", "TransfoXLModel"),
        ("tvlt", "TvltModel"),
@@ -999,6 +1000,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
        ("nat", "NatBackbone"),
        ("resnet", "ResNetBackbone"),
        ("swin", "SwinBackbone"),
        ("timm_backbone", "TimmBackbone"),
    ]
)

@@ -1330,7 +1332,7 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
    _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


class AutoBackbone(_BaseAutoModelClass):
class AutoBackbone(_BaseAutoBackboneClass):
    _model_mapping = MODEL_FOR_BACKBONE_MAPPING
@@ -39,7 +39,7 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_bit import BitConfig


@@ -845,14 +845,10 @@ class BitForImageClassification(BitPreTrainedModel):
class BitBackbone(BitPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.stage_names = config.stage_names
        self.bit = BitModel(config)

        self.num_features = [config.embedding_size] + config.hidden_sizes
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )

        # initialize weights and apply final processing
        self.post_init()
@@ -37,7 +37,7 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_convnext import ConvNextConfig


@@ -481,15 +481,11 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.stage_names = config.stage_names
        self.embeddings = ConvNextEmbeddings(config)
        self.encoder = ConvNextEncoder(config)

        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
@@ -37,7 +37,7 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_convnextv2 import ConvNextV2Config


@@ -504,15 +504,11 @@ class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.stage_names = config.stage_names
        self.embeddings = ConvNextV2Embeddings(config)
        self.encoder = ConvNextV2Encoder(config)

        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
@@ -39,7 +39,7 @@ from ...utils import (
    replace_return_docstrings,
    requires_backends,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_dinat import DinatConfig


@@ -883,17 +883,12 @@ class DinatForImageClassification(DinatPreTrainedModel):
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        requires_backends(self, ["natten"])

        self.stage_names = config.stage_names

        self.embeddings = DinatEmbeddings(config)
        self.encoder = DinatEncoder(config)

        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
@@ -36,7 +36,7 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_focalnet import FocalNetConfig


@@ -981,16 +981,12 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel):
    FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
    def __init__(self, config):
    def __init__(self, config: FocalNetConfig):
        super().__init__(config)

        self.stage_names = config.stage_names
        self.focalnet = FocalNetModel(config)
        super()._init_backbone(config)

        self.num_features = [config.embed_dim] + config.hidden_sizes
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )
        self.focalnet = FocalNetModel(config)

        # initialize weights and apply final processing
        self.post_init()
@@ -29,7 +29,7 @@ from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_maskformer_swin import MaskFormerSwinConfig


@@ -852,17 +852,11 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):

    def __init__(self, config: MaskFormerSwinConfig):
        super().__init__(config)
        super()._init_backbone(config)

        self.stage_names = config.stage_names
        self.model = MaskFormerSwinModel(config)

        self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
        if "stem" in self.out_features:
            raise ValueError("This backbone does not support 'stem' in the `out_features`.")

        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.hidden_states_norms = nn.ModuleList(
            [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
@@ -39,7 +39,7 @@ from ...utils import (
    replace_return_docstrings,
    requires_backends,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_nat import NatConfig


@@ -861,17 +861,12 @@ class NatForImageClassification(NatPreTrainedModel):
class NatBackbone(NatPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        requires_backends(self, ["natten"])

        self.stage_names = config.stage_names

        self.embeddings = NatEmbeddings(config)
        self.encoder = NatEncoder(config)

        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
@@ -36,7 +36,7 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_resnet import ResNetConfig


@@ -432,16 +432,12 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
    def __init__(self, config):
        super().__init__(config)
        super()._init_backbone(config)

        self.stage_names = config.stage_names
        self.num_features = [config.embedding_size] + config.hidden_sizes
        self.embedder = ResNetEmbeddings(config)
        self.encoder = ResNetEncoder(config)

        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )
        self.num_features = [config.embedding_size] + config.hidden_sizes

        # initialize weights and apply final processing
        self.post_init()
@@ -38,7 +38,7 @@ from ...utils import (
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_swin import SwinConfig


@@ -1259,17 +1259,12 @@ class SwinForImageClassification(SwinPreTrainedModel):
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
    def __init__(self, config: SwinConfig):
        super().__init__(config)
        super()._init_backbone(config)

        self.stage_names = config.stage_names

        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
        self.embeddings = SwinEmbeddings(config)
        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            config.out_features, config.out_indices, self.stage_names
        )
        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

        # Add layer norms to hidden states of out_features
        hidden_states_norms = {}
        for stage, num_channels in zip(self._out_features, self.channels):
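All of the backbone models above follow the same refactor: `__init__` now calls `super()._init_backbone(config)`, which sets `stage_names`, `_out_features`, and `_out_indices` on the `BackboneMixin` side, so each model only keeps its own submodules and `num_features`. A small sketch of what the mixin computes, using the existing `ResNetBackbone` with a default, randomly initialized config (assumes `torch` is installed):

```python
from transformers import ResNetBackbone, ResNetConfig

# Only out_features is given; _init_backbone aligns out_indices against stage_names.
config = ResNetConfig(out_features=["stage2", "stage4"])
backbone = ResNetBackbone(config)

print(backbone.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(backbone.out_features)  # ['stage2', 'stage4']
print(backbone.out_indices)   # [2, 4]
print(backbone.channels)      # channel counts for the selected stages
```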
src/transformers/models/timm_backbone/__init__.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {"configuration_timm_backbone": ["TimmBackboneConfig"]}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_timm_backbone"] = ["TimmBackbone"]


if TYPE_CHECKING:
    from .configuration_timm_backbone import TimmBackboneConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_timm_backbone import TimmBackbone

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
@@ -0,0 +1,78 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Configuration for Backbone models"""

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class TimmBackboneConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`].

    It is used to instantiate a timm backbone model according to the specified arguments, defining the model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone (`str`, *optional*):
            The timm checkpoint to load.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        features_only (`bool`, *optional*, defaults to `True`):
            Whether to output only the features or also the logits.
        use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
            Whether to use a pretrained backbone.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). Will default to the last stage if unset.

    Example:
    ```python
    >>> from transformers import TimmBackboneConfig, TimmBackbone

    >>> # Initializing a timm backbone
    >>> configuration = TimmBackboneConfig("resnet50")

    >>> # Initializing a model from the configuration
    >>> model = TimmBackbone(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """
    model_type = "timm_backbone"

    def __init__(
        self,
        backbone=None,
        num_channels=3,
        features_only=True,
        use_pretrained_backbone=True,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.backbone = backbone
        self.num_channels = num_channels
        self.features_only = features_only
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = True
        self.out_indices = out_indices if out_indices is not None else (-1,)
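Beyond the documented arguments, note that this config pins `use_timm_backbone = True` and defaults `out_indices` to `(-1,)`, i.e. the last feature stage. A quick check (sketch, no weights downloaded):

```python
from transformers import TimmBackboneConfig

config = TimmBackboneConfig(backbone="resnet18")
print(config.use_timm_backbone)  # True - always set by this config class
print(config.out_indices)        # (-1,) - last stage by default
print(config.num_channels, config.features_only, config.use_pretrained_backbone)  # 3 True True
```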
src/transformers/models/timm_backbone/modeling_timm_backbone.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import is_timm_available, is_torch_available, requires_backends
from ...utils.backbone_utils import BackboneMixin
from .configuration_timm_backbone import TimmBackboneConfig


if is_timm_available():
    import timm


if is_torch_available():
    from torch import Tensor


class TimmBackbone(PreTrainedModel, BackboneMixin):
    """
    Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
    other models in the library keeping the same API.
    """

    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    config_class = TimmBackboneConfig

    def __init__(self, config, **kwargs):
        requires_backends(self, "timm")
        super().__init__(config)
        self.config = config

        if config.backbone is None:
            raise ValueError("backbone is not set in the config. Please set it to a timm model name.")

        if config.backbone not in timm.list_models():
            raise ValueError(f"backbone {config.backbone} is not supported by timm.")

        if hasattr(config, "out_features") and config.out_features is not None:
            raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.")

        pretrained = getattr(config, "use_pretrained_backbone", None)
        if pretrained is None:
            raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.")

        # We just take the final layer by default. This matches the default for the transformers models.
        out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)

        self._backbone = timm.create_model(
            config.backbone,
            pretrained=pretrained,
            # This is currently not possible for transformer architectures.
            features_only=config.features_only,
            in_chans=config.num_channels,
            out_indices=out_indices,
            **kwargs,
        )
        # These are used to control the output of the model when called. If output_hidden_states is True, then
        # return_layers is modified to include all layers.
        self._return_layers = self._backbone.return_layers
        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
        super()._init_backbone(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        use_timm = kwargs.pop("use_timm_backbone", True)
        if not use_timm:
            raise ValueError("use_timm_backbone must be True for timm backbones")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
        )
        return super()._from_config(config, **kwargs)

    def _init_weights(self, module):
        """
        Empty init weights function to ensure compatibility of the class in the library.
        """
        pass

    def forward(
        self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
    ) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if output_attentions:
            raise ValueError("Cannot output attentions for timm backbones at the moment")

        if output_hidden_states:
            # We modify the return layers to include all the stages of the backbone
            self._backbone.return_layers = self._all_layers
            hidden_states = self._backbone(pixel_values, **kwargs)
            self._backbone.return_layers = self._return_layers
            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
        else:
            feature_maps = self._backbone(pixel_values, **kwargs)
            hidden_states = None

        feature_maps = tuple(feature_maps)
        hidden_states = tuple(hidden_states) if hidden_states is not None else None

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output = output + (hidden_states,)
            return output

        return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)
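A quick forward-pass sketch for the wrapper itself (assumes `torch` and `timm` are installed; `resnet18` is the timm checkpoint used in the PR's tests, and the exact feature-map shapes depend on the input resolution):

```python
import torch

from transformers import TimmBackbone, TimmBackboneConfig

config = TimmBackboneConfig(backbone="resnet18", out_indices=(-2, -1), use_pretrained_backbone=False)
model = TimmBackbone(config)
model.eval()

pixel_values = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True)

# feature_maps holds only the stages selected by out_indices;
# hidden_states holds every stage because output_hidden_states=True.
print([fm.shape for fm in outputs.feature_maps])
print(len(outputs.hidden_states))
```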
@@ -15,10 +15,16 @@

""" Collection of utils to be used by backbones and their components."""

import enum
import inspect
from typing import Iterable, List, Optional, Tuple, Union


class BackboneType(enum.Enum):
    TIMM = "timm"
    TRANSFORMERS = "transformers"


def verify_out_features_out_indices(
    out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
@@ -72,7 +78,7 @@ def _align_output_features_output_indices(
        out_indices = [len(stage_names) - 1]
        out_features = [stage_names[-1]]
    elif out_indices is None and out_features is not None:
        out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
        out_indices = [stage_names.index(layer) for layer in out_features]
    elif out_features is None and out_indices is not None:
        out_features = [stage_names[idx] for idx in out_indices]
    return out_features, out_indices
@@ -110,29 +116,57 @@ def get_aligned_output_features_output_indices(


class BackboneMixin:
    @property
    def out_feature_channels(self):
        # the current backbones will output the number of channels for each stage
        # even if that stage is not in the out_features list.
        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
    backbone_type: Optional[BackboneType] = None

    @property
    def channels(self):
        return [self.out_feature_channels[name] for name in self.out_features]
    def _init_timm_backbone(self, config) -> None:
        """
        Initialize the backbone model from timm. The backbone must already be loaded to self._backbone
        """
        if getattr(self, "_backbone", None) is None:
            raise ValueError("self._backbone must be set before calling _init_timm_backbone")

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        signature = dict(inspect.signature(self.forward).parameters)
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
        return self(*args, **filtered_kwargs)
        # These will disagree with the defaults for the transformers models e.g. for resnet50
        # the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
        # the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
        self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info]
        self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info]
        out_indices = self._backbone.feature_info.out_indices
        out_features = self._backbone.feature_info.module_name()

    def forward(
        self,
        pixel_values,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        raise NotImplementedError("This method should be implemented by the derived class.")
        # We verify the out indices and out features are valid
        verify_out_features_out_indices(
            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
        )
        self._out_features, self._out_indices = out_features, out_indices

    def _init_transformers_backbone(self, config) -> None:
        stage_names = getattr(config, "stage_names")
        out_features = getattr(config, "out_features", None)
        out_indices = getattr(config, "out_indices", None)

        self.stage_names = stage_names
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=out_indices, stage_names=stage_names
        )
        # Number of channels for each stage. This is set in the transformer backbone model init
        self.num_features = None

    def _init_backbone(self, config) -> None:
        """
        Method to initialize the backbone. This method is called by the constructor of the base class after the
        pretrained model weights have been loaded.
        """
        self.config = config

        self.use_timm_backbone = getattr(config, "use_timm_backbone", False)
        self.backbone_type = BackboneType.TIMM if self.use_timm_backbone else BackboneType.TRANSFORMERS

        if self.backbone_type == BackboneType.TIMM:
            self._init_timm_backbone(config)
        elif self.backbone_type == BackboneType.TRANSFORMERS:
            self._init_transformers_backbone(config)
        else:
            raise ValueError(f"backbone_type {self.backbone_type} not supported.")

    @property
    def out_features(self):
@@ -160,6 +194,40 @@ class BackboneMixin:
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )

    @property
    def out_feature_channels(self):
        # the current backbones will output the number of channels for each stage
        # even if that stage is not in the out_features list.
        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}

    @property
    def channels(self):
        return [self.out_feature_channels[name] for name in self.out_features]

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        signature = dict(inspect.signature(self.forward).parameters)
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
        return self(*args, **filtered_kwargs)

    def forward(
        self,
        pixel_values,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        raise NotImplementedError("This method should be implemented by the derived class.")

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
        include the `out_features` and `out_indices` attributes.
        """
        output = super().to_dict()
        output["out_features"] = output.pop("_out_features")
        output["out_indices"] = output.pop("_out_indices")
        return output


class BackboneConfigMixin:
    """
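The key new piece in `backbone_utils` is that `BackboneMixin._init_backbone` now branches on `BackboneType`: timm backbones read `stage_names`, `num_features`, and `out_indices` from `self._backbone.feature_info`, while transformers backbones keep going through `get_aligned_output_features_output_indices`, which fills in whichever of `out_features` / `out_indices` was not provided. A small sketch of that alignment helper, called positionally as in the modeling files above:

```python
from transformers.utils.backbone_utils import get_aligned_output_features_output_indices

stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]

# Both unset -> default to the last stage.
print(get_aligned_output_features_output_indices(None, None, stage_names))
# (['stage4'], [4])

# Only out_features given -> indices are looked up from stage_names.
print(get_aligned_output_features_output_indices(["stage1", "stage3"], None, stage_names))
# (['stage1', 'stage3'], [1, 3])

# Only out_indices given -> names are looked up from stage_names.
print(get_aligned_output_features_output_indices(None, [2, 4], stage_names))
# (['stage2', 'stage4'], [2, 4])
```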
@@ -6806,6 +6806,13 @@ class TimesformerPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class TimmBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -45,6 +45,7 @@ if is_torch_available():
    from test_module.custom_modeling import CustomModel

    from transformers import (
        AutoBackbone,
        AutoConfig,
        AutoModel,
        AutoModelForCausalLM,
@@ -66,11 +67,13 @@ if is_torch_available():
        FunnelModel,
        GPT2Config,
        GPT2LMHeadModel,
        ResNetBackbone,
        RobertaForMaskedLM,
        T5Config,
        T5ForConditionalGeneration,
        TapasConfig,
        TapasForQuestionAnswering,
        TimmBackbone,
    )
    from transformers.models.auto.modeling_auto import (
        MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -224,6 +227,42 @@ class AutoModelTest(unittest.TestCase):
        self.assertIsNotNone(model)
        self.assertIsInstance(model, BertForTokenClassification)

    @slow
    def test_auto_backbone_timm_model_from_pretrained(self):
        # Configs can't be loaded for timm models
        model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)

        with pytest.raises(ValueError):
            # We can't pass output_loading_info=True as we're loading from timm
            AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertIsInstance(model, TimmBackbone)

        # Check kwargs are correctly passed to the backbone
        model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
        self.assertEqual(model.out_indices, (-1, -2))

        # Check out_features cannot be passed to Timm backbones
        with self.assertRaises(ValueError):
            _ = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])

    @slow
    def test_auto_backbone_from_pretrained(self):
        model = AutoBackbone.from_pretrained("microsoft/resnet-18")
        model, loading_info = AutoBackbone.from_pretrained("microsoft/resnet-18", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, ResNetBackbone)

        # Check kwargs are correctly passed to the backbone
        model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2])
        self.assertEqual(model.out_indices, [-1, -2])
        self.assertEqual(model.out_features, ["stage4", "stage3"])

        model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"])
        self.assertEqual(model.out_indices, [2, 4])
        self.assertEqual(model.out_features, ["stage2", "stage4"])

    def test_from_pretrained_identifier(self):
        model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(model, BertForMaskedLM)
tests/models/timm_backbone/__init__.py (new file, empty)

tests/models/timm_backbone/test_modeling_timm_backbone.py (new file, 259 lines)
@@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import unittest

from transformers import AutoBackbone
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import require_timm, require_torch, torch_device
from transformers.utils.import_utils import is_torch_available

from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor


if is_torch_available():
    import torch

    from transformers import TimmBackbone, TimmBackboneConfig


class TimmBackboneModelTester:
    def __init__(
        self,
        parent,
        out_indices=None,
        out_features=None,
        stage_names=None,
        backbone="resnet50",
        batch_size=3,
        image_size=32,
        num_channels=3,
        is_training=True,
        use_pretrained_backbone=True,
    ):
        self.parent = parent
        self.out_indices = out_indices if out_indices is not None else [4]
        self.stage_names = stage_names
        self.out_features = out_features
        self.backbone = backbone
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.use_pretrained_backbone = use_pretrained_backbone
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return TimmBackboneConfig(
            image_size=self.image_size,
            num_channels=self.num_channels,
            out_features=self.out_features,
            out_indices=self.out_indices,
            stage_names=self.stage_names,
            use_pretrained_backbone=self.use_pretrained_backbone,
            backbone=self.backbone,
        )

    def create_and_check_model(self, config, pixel_values):
        model = TimmBackbone(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        self.parent.assertEqual(
            result.feature_map[-1].shape,
            (self.batch_size, model.channels[-1], 14, 14),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
@require_timm
class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase):
    all_model_classes = (TimmBackbone,) if is_torch_available() else ()
    test_resize_embeddings = False
    test_head_masking = False
    test_pruning = False
    has_attentions = False

    def setUp(self):
        self.model_tester = TimmBackboneModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PretrainedConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def test_timm_transformer_backbone_equivalence(self):
        timm_checkpoint = "resnet18"
        transformers_checkpoint = "microsoft/resnet-18"

        timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True)
        transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint)

        self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
        self.assertEqual(len(timm_model.stage_names), len(transformers_model.stage_names))
        self.assertEqual(timm_model.channels, transformers_model.channels)
        # Out indices are set to the last layer by default. For timm models, we don't know
        # the number of layers in advance, so we set it to (-1,), whereas for transformers
        # models, we set it to [len(stage_names) - 1] (kept for backward compatibility).
        self.assertEqual(timm_model.out_indices, (-1,))
        self.assertEqual(transformers_model.out_indices, [len(timm_model.stage_names) - 1])

        timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True, out_indices=[1, 2, 3])
        transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint, out_indices=[1, 2, 3])

        self.assertEqual(timm_model.out_indices, transformers_model.out_indices)
        self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
        self.assertEqual(timm_model.channels, transformers_model.channels)

    @unittest.skip("TimmBackbone doesn't support feed forward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip("TimmBackbone doesn't have num_hidden_layers attribute")
    def test_hidden_states_output(self):
        pass

    @unittest.skip("TimmBackbone initialization is managed on the timm side")
    def test_initialization(self):
        pass

    @unittest.skip("TimmBackbone models doesn't have inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip("TimmBackbone models doesn't have inputs_embeds")
    def test_model_common_attributes(self):
        pass

    @unittest.skip("TimmBackbone model cannot be created without specifying a backbone checkpoint")
    def test_from_pretrained_no_checkpoint(self):
        pass

    @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
    def test_save_load(self):
        pass

    @unittest.skip("model weights aren't tied in TimmBackbone.")
    def test_tie_model_weights(self):
        pass

    @unittest.skip("model weights aren't tied in TimmBackbone.")
    def test_tied_model_weights_key_ignore(self):
        pass

    @unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.")
    def test_channels(self):
        pass

    @unittest.skip("TimmBackbone doesn't support output_attentions.")
    def test_torchscript_output_attentions(self):
        pass

    @unittest.skip("Safetensors is not supported by timm.")
    def test_can_use_safetensors(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_retain_grad_hidden_states_attentions(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)
        outputs = model(**inputs)
        output = outputs[0][-1]

        # Encoder-/Decoder-only models
        hidden_states = outputs.hidden_states[0]
        hidden_states.retain_grad()

        if self.has_attentions:
            attentions = outputs.attentions[0]
            attentions.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)

        if self.has_attentions:
            self.assertIsNotNone(attentions.grad)

    # TimmBackbone config doesn't have out_features attribute
    def test_create_from_modified_config(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            result = model(**inputs_dict)

            self.assertEqual(len(result.feature_maps), len(config.out_indices))
            self.assertEqual(len(model.channels), len(config.out_indices))

            # Check output of last stage is taken if out_features=None, out_indices=None
            modified_config = copy.deepcopy(config)
            modified_config.out_indices = None
            model = model_class(modified_config)
            model.to(torch_device)
            model.eval()
            result = model(**inputs_dict)

            self.assertEqual(len(result.feature_maps), 1)
            self.assertEqual(len(model.channels), 1)

            # Check backbone can be initialized with fresh weights
            modified_config = copy.deepcopy(config)
            modified_config.use_pretrained_backbone = False
            model = model_class(modified_config)
            model.to(torch_device)
            model.eval()
            result = model(**inputs_dict)
@@ -17,6 +17,7 @@ import copy
import inspect

from transformers.testing_utils import require_torch, torch_device
from transformers.utils.backbone_utils import BackboneType


@require_torch
@@ -104,6 +105,8 @@ class BackboneTesterMixin:

            self.assertEqual(len(result.feature_maps), len(config.out_features))
            self.assertEqual(len(model.channels), len(config.out_features))
            self.assertEqual(len(result.feature_maps), len(config.out_indices))
            self.assertEqual(len(model.channels), len(config.out_indices))

            # Check output of last stage is taken if out_features=None, out_indices=None
            modified_config = copy.deepcopy(config)
@@ -140,6 +143,7 @@ class BackboneTesterMixin:
        for backbone_class in self.all_model_classes:
            backbone = backbone_class(config)

            self.assertTrue(hasattr(backbone, "backbone_type"))
            self.assertTrue(hasattr(backbone, "stage_names"))
            self.assertTrue(hasattr(backbone, "num_features"))
            self.assertTrue(hasattr(backbone, "out_indices"))
@@ -147,6 +151,7 @@ class BackboneTesterMixin:
            self.assertTrue(hasattr(backbone, "out_feature_channels"))
            self.assertTrue(hasattr(backbone, "channels"))

            self.assertIsInstance(backbone.backbone_type, BackboneType)
            # Verify num_features has been initialized in the backbone init
            self.assertIsNotNone(backbone.num_features)
            self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
@@ -77,6 +77,7 @@ SPECIAL_CASES_TO_ALLOW = {
    "AutoformerConfig": ["num_static_real_features", "num_time_features"],
}


# TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
SPECIAL_CASES_TO_ALLOW.update(
    {
@@ -172,6 +173,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
        "mask_index",
        "image_size",
        "use_cache",
        "out_features",
        "out_indices",
    ]
    attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
@@ -39,6 +39,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "EncoderDecoderConfig",
    "RagConfig",
    "SpeechEncoderDecoderConfig",
    "TimmBackboneConfig",
    "VisionEncoderDecoderConfig",
    "VisionTextDualEncoderConfig",
    "LlamaConfig",
@@ -517,6 +517,7 @@ MODELS_NOT_IN_README = [
    "Speech Encoder decoder",
    "Speech2Text",
    "Speech2Text2",
    "TimmBackbone",
    "Vision Encoder decoder",
    "VisionTextDualEncoder",
]
@@ -408,6 +408,7 @@ def get_model_modules():
        "modeling_speech_encoder_decoder",
        "modeling_flax_speech_encoder_decoder",
        "modeling_flax_vision_encoder_decoder",
        "modeling_timm_backbone",
        "modeling_transfo_xl_utilities",
        "modeling_tf_auto",
        "modeling_tf_encoder_decoder",
@@ -846,6 +847,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
    "NatBackbone",
    "ResNetBackbone",
    "SwinBackbone",
    "TimmBackbone",
    "TimmBackboneConfig",
]