Add TimmBackbone model (#22619)

* Add test_backbone for convnext

* Add TimmBackbone model

* Add check for backbone type

* Tidying up - config checks

* Update convnextv2

* Tidy up

* Fix indices & clearer comment

* Exceptions for config checks

* Correctly update config for tests

* Safer imports

* Safer safer imports

* Fix where decorators go

* Update import logic and backbone tests

* More import fixes

* Fixup

* Only import all_models if torch available

* Fix kwarg updates in from_pretrained & main rebase

* Tidy up

* Add tests for AutoBackbone

* Tidy up

* Fix import error

* Fix up

* Install natten in doc_test_job

* Revert back to setting self._out_xxx directly

* Bug fix - out_indices mapping from out_features

* Fix tests

* Don't accept output_loading_info for Timm models

* Set out_xxx and don't remap

* Use smaller checkpoint for test

* Don't remap timm indices - check out_indices based on stage names

* Skip test as it's n/a

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Cleaner imports / spelling is hard

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
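
In short, this adds a `TimmBackbone` model so that timm checkpoints can be loaded through `AutoBackbone` alongside the existing transformers backbones. A minimal usage sketch, mirroring the new tests in this commit (the checkpoint names are the ones used there):

```python
from transformers import AutoBackbone

# timm route: weights come from timm, selected with use_timm_backbone=True
timm_backbone = AutoBackbone.from_pretrained(
    "resnet18", use_timm_backbone=True, out_indices=(-1, -2)
)

# existing transformers route is unchanged
hf_backbone = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2])
```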
amyeroberts 2023-06-06 17:11:30 +01:00 committed by GitHub
parent b8935980a2
commit a717e0318c
29 changed files with 753 additions and 88 deletions

View File

@ -128,7 +128,7 @@ class CircleCIJob:
if self.command_timeout:
test_command = f"timeout {self.command_timeout} "
test_command += f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
if self.parallelism == 1:
if self.tests_to_run is None:
test_command += " << pipeline.parameters.tests_to_run >>"
@ -456,6 +456,7 @@ doc_test_job = CircleCIJob(
"pip install -e .[dev]",
"pip install git+https://github.com/huggingface/accelerate",
"pip install --upgrade pytest pytest-sugar",
"pip install natten",
"find -name __pycache__ -delete",
"find . -name \*.pyc -delete",
# Add an empty file to keep the test step running correctly even no file is selected to be tested.

View File

@ -424,6 +424,7 @@ Flax), PyTorch, and/or TensorFlow.
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimeSformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimmBackbone | ❌ | ❌ | ❌ | ❌ | ❌ |
| Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |

View File

@ -483,6 +483,7 @@ _import_structure = {
"TimeSeriesTransformerConfig",
],
"models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
"models.timm_backbone": ["TimmBackboneConfig"],
"models.trajectory_transformer": [
"TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TrajectoryTransformerConfig",
@ -2578,6 +2579,7 @@ else:
"TimesformerPreTrainedModel",
]
)
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
_import_structure["models.trajectory_transformer"].extend(
[
"TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -4288,6 +4290,7 @@ if TYPE_CHECKING:
TimeSeriesTransformerConfig,
)
from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
from .models.timm_backbone import TimmBackboneConfig
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TrajectoryTransformerConfig,
@ -6024,6 +6027,7 @@ if TYPE_CHECKING:
TimesformerModel,
TimesformerPreTrainedModel,
)
from .models.timm_backbone import TimmBackbone
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TrajectoryTransformerModel,

View File

@ -186,6 +186,7 @@ from . import (
tapex,
time_series_transformer,
timesformer,
timm_backbone,
trajectory_transformer,
transfo_xl,
trocr,

View File

@ -19,7 +19,7 @@ from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...utils import copy_func, logging
from ...utils import copy_func, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
@ -515,6 +515,48 @@ class _BaseAutoModelClass:
cls._model_mapping.register(config_class, model_class)
class _BaseAutoBackboneClass(_BaseAutoModelClass):
# Base class for auto backbone models.
_model_mapping = None
@classmethod
def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_backbone import TimmBackboneConfig
config = kwargs.pop("config", TimmBackboneConfig())
use_timm = kwargs.pop("use_timm_backbone", True)
if not use_timm:
raise ValueError("`use_timm_backbone` must be `True` for timm backbones")
if kwargs.get("out_features", None) is not None:
raise ValueError("Cannot specify `out_features` for timm backbones")
if kwargs.get("output_loading_info", False):
raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
num_channels = kwargs.pop("num_channels", config.num_channels)
features_only = kwargs.pop("features_only", config.features_only)
use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
out_indices = kwargs.pop("out_indices", config.out_indices)
config = TimmBackboneConfig(
backbone=pretrained_model_name_or_path,
num_channels=num_channels,
features_only=features_only,
use_pretrained_backbone=use_pretrained_backbone,
out_indices=out_indices,
)
return super().from_config(config, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if kwargs.get("use_timm_backbone", False):
return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def insert_head_doc(docstring, head_doc=""):
if len(head_doc) > 0:
return docstring.replace(
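
Note on the dispatch above: `_BaseAutoBackboneClass.from_pretrained` only takes the timm path when `use_timm_backbone=True` is passed, and in that path `out_features` and `output_loading_info` are rejected before a fresh `TimmBackboneConfig` is built from the remaining kwargs. A small sketch of the resulting behaviour, based on the checks above and the new tests:

```python
from transformers import AutoBackbone

# out_features is a transformers-side concept; timm backbones are addressed via out_indices
try:
    AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])
except ValueError:
    pass  # "Cannot specify `out_features` for timm backbones"

# loading info is not available when the weights come from timm
try:
    AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)
except ValueError:
    pass  # "Cannot specify `output_loading_info=True` when loading from timm"
```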

View File

@ -186,6 +186,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("tapas", "TapasConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
@ -579,6 +580,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("tapex", "TAPEX"),
("time_series_transformer", "Time Series Transformer"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),

View File

@ -18,7 +18,7 @@ import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("tapas", "TapasModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"),
("tvlt", "TvltModel"),
@ -999,6 +1000,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
("swin", "SwinBackbone"),
("timm_backbone", "TimmBackbone"),
]
)
@ -1330,7 +1332,7 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
class AutoBackbone(_BaseAutoModelClass):
class AutoBackbone(_BaseAutoBackboneClass):
_model_mapping = MODEL_FOR_BACKBONE_MAPPING

View File

@ -39,7 +39,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_bit import BitConfig
@ -845,14 +845,10 @@ class BitForImageClassification(BitPreTrainedModel):
class BitBackbone(BitPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.bit = BitModel(config)
self.num_features = [config.embedding_size] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
# initialize weights and apply final processing
self.post_init()

View File

@ -37,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_convnext import ConvNextConfig
@ -481,15 +481,11 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.embeddings = ConvNextEmbeddings(config)
self.encoder = ConvNextEncoder(config)
self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
# Add layer norms to hidden states of out_features
hidden_states_norms = {}

View File

@ -37,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_convnextv2 import ConvNextV2Config
@ -504,15 +504,11 @@ class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.embeddings = ConvNextV2Embeddings(config)
self.encoder = ConvNextV2Encoder(config)
self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
# Add layer norms to hidden states of out_features
hidden_states_norms = {}

View File

@ -39,7 +39,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_dinat import DinatConfig
@ -883,17 +883,12 @@ class DinatForImageClassification(DinatPreTrainedModel):
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
requires_backends(self, ["natten"])
self.stage_names = config.stage_names
self.embeddings = DinatEmbeddings(config)
self.encoder = DinatEncoder(config)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features

View File

@ -36,7 +36,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_focalnet import FocalNetConfig
@ -981,16 +981,12 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel):
FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
def __init__(self, config: FocalNetConfig):
super().__init__(config)
self.stage_names = config.stage_names
self.focalnet = FocalNetModel(config)
super()._init_backbone(config)
self.num_features = [config.embed_dim] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.focalnet = FocalNetModel(config)
# initialize weights and apply final processing
self.post_init()

View File

@ -29,7 +29,7 @@ from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_maskformer_swin import MaskFormerSwinConfig
@ -852,17 +852,11 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
def __init__(self, config: MaskFormerSwinConfig):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config)
self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.hidden_states_norms = nn.ModuleList(
[nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]

View File

@ -39,7 +39,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_nat import NatConfig
@ -861,17 +861,12 @@ class NatForImageClassification(NatPreTrainedModel):
class NatBackbone(NatPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
requires_backends(self, ["natten"])
self.stage_names = config.stage_names
self.embeddings = NatEmbeddings(config)
self.encoder = NatEncoder(config)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features

View File

@ -36,7 +36,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_resnet import ResNetConfig
@ -432,16 +432,12 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.num_features = [config.embedding_size] + config.hidden_sizes
self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embedding_size] + config.hidden_sizes
# initialize weights and apply final processing
self.post_init()

View File

@ -38,7 +38,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_swin import SwinConfig
@ -1259,17 +1259,12 @@ class SwinForImageClassification(SwinPreTrainedModel):
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
def __init__(self, config: SwinConfig):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features
hidden_states_norms = {}
for stage, num_channels in zip(self._out_features, self.channels):

View File

@ -0,0 +1,49 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_timm_backbone": ["TimmBackboneConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_timm_backbone"] = ["TimmBackbone"]
if TYPE_CHECKING:
from .configuration_timm_backbone import TimmBackboneConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_timm_backbone import TimmBackbone
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)

View File

@ -0,0 +1,78 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration for Backbone models"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class TimmBackboneConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`].
It is used to instantiate a timm backbone model according to the specified arguments, defining the model.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone (`str`, *optional*):
The timm checkpoint to load.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
features_only (`bool`, *optional*, defaults to `True`):
Whether to output only the features or also the logits.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use a pretrained backbone.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). Will default to the last stage if unset.
Example:
```python
>>> from transformers import TimmBackboneConfig, TimmBackbone
>>> # Initializing a timm backbone
>>> configuration = TimmBackboneConfig("resnet50")
>>> # Initializing a model from the configuration
>>> model = TimmBackbone(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "timm_backbone"
def __init__(
self,
backbone=None,
num_channels=3,
features_only=True,
use_pretrained_backbone=True,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
self.backbone = backbone
self.num_channels = num_channels
self.features_only = features_only
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = True
self.out_indices = out_indices if out_indices is not None else (-1,)

View File

@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import is_timm_available, is_torch_available, requires_backends
from ...utils.backbone_utils import BackboneMixin
from .configuration_timm_backbone import TimmBackboneConfig
if is_timm_available():
import timm
if is_torch_available():
from torch import Tensor
class TimmBackbone(PreTrainedModel, BackboneMixin):
"""
Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
other models in the library keeping the same API.
"""
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
config_class = TimmBackboneConfig
def __init__(self, config, **kwargs):
requires_backends(self, "timm")
super().__init__(config)
self.config = config
if config.backbone is None:
raise ValueError("backbone is not set in the config. Please set it to a timm model name.")
if config.backbone not in timm.list_models():
raise ValueError(f"backbone {config.backbone} is not supported by timm.")
if hasattr(config, "out_features") and config.out_features is not None:
raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.")
pretrained = getattr(config, "use_pretrained_backbone", None)
if pretrained is None:
raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.")
# We just take the final layer by default. This matches the default for the transformers models.
out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
self._backbone = timm.create_model(
config.backbone,
pretrained=pretrained,
# This is currently not possible for transformer architectures.
features_only=config.features_only,
in_chans=config.num_channels,
out_indices=out_indices,
**kwargs,
)
# These are used to control the output of the model when called. If output_hidden_states is True, then
# return_layers is modified to include all layers.
self._return_layers = self._backbone.return_layers
self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
super()._init_backbone(config)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_backbone import TimmBackboneConfig
config = kwargs.pop("config", TimmBackboneConfig())
use_timm = kwargs.pop("use_timm_backbone", True)
if not use_timm:
raise ValueError("use_timm_backbone must be True for timm backbones")
num_channels = kwargs.pop("num_channels", config.num_channels)
features_only = kwargs.pop("features_only", config.features_only)
use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
out_indices = kwargs.pop("out_indices", config.out_indices)
config = TimmBackboneConfig(
backbone=pretrained_model_name_or_path,
num_channels=num_channels,
features_only=features_only,
use_pretrained_backbone=use_pretrained_backbone,
out_indices=out_indices,
)
return super()._from_config(config, **kwargs)
def _init_weights(self, module):
"""
Empty init weights function to ensure compatibility of the class in the library.
"""
pass
def forward(
self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if output_attentions:
raise ValueError("Cannot output attentions for timm backbones at the moment")
if output_hidden_states:
# We modify the return layers to include all the stages of the backbone
self._backbone.return_layers = self._all_layers
hidden_states = self._backbone(pixel_values, **kwargs)
self._backbone.return_layers = self._return_layers
feature_maps = tuple(hidden_states[i] for i in self.out_indices)
else:
feature_maps = self._backbone(pixel_values, **kwargs)
hidden_states = None
feature_maps = tuple(feature_maps)
hidden_states = tuple(hidden_states) if hidden_states is not None else None
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output = output + (hidden_states,)
return output
return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)
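
When `output_hidden_states=True`, `forward` temporarily swaps `return_layers` to `self._all_layers` so timm returns every stage, then restores it and selects `feature_maps` via `out_indices`. A short sketch of calling the backbone directly (input size is illustrative; requires `torch` and `timm`):

```python
import torch
from transformers import TimmBackbone

# from_pretrained wraps the timm checkpoint name in a TimmBackboneConfig
model = TimmBackbone.from_pretrained("resnet18", out_indices=(-1,))
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values, output_hidden_states=True)

print(len(outputs.feature_maps))   # one feature map per entry in out_indices
print(len(outputs.hidden_states))  # all stages, because output_hidden_states=True
```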

View File

@ -15,10 +15,16 @@
""" Collection of utils to be used by backbones and their components."""
import enum
import inspect
from typing import Iterable, List, Optional, Tuple, Union
class BackboneType(enum.Enum):
TIMM = "timm"
TRANSFORMERS = "transformers"
def verify_out_features_out_indices(
out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
@ -72,7 +78,7 @@ def _align_output_features_output_indices(
out_indices = [len(stage_names) - 1]
out_features = [stage_names[-1]]
elif out_indices is None and out_features is not None:
out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
out_indices = [stage_names.index(layer) for layer in out_features]
elif out_features is None and out_indices is not None:
out_features = [stage_names[idx] for idx in out_indices]
return out_features, out_indices
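
The one-line change above makes the derived `out_indices` follow the order of `out_features` instead of the order of `stage_names`, so a request like `out_features=["stage4", "stage3"]` yields indices in the same order. A hypothetical before/after with ResNet-style stage names:

```python
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
out_features = ["stage4", "stage3"]

# old: iterate stage_names, so the result is forced back into stage order
old = [stage_names.index(layer) for layer in stage_names if layer in out_features]  # [3, 4]

# new: iterate out_features, so the indices line up with the requested order
new = [stage_names.index(layer) for layer in out_features]  # [4, 3]
```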
@ -110,29 +116,57 @@ def get_aligned_output_features_output_indices(
class BackboneMixin:
@property
def out_feature_channels(self):
# the current backbones will output the number of channels for each stage
# even if that stage is not in the out_features list.
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
backbone_type: Optional[BackboneType] = None
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def _init_timm_backbone(self, config) -> None:
"""
Initialize the backbone model from timm. The backbone must already be loaded to self._backbone
"""
if getattr(self, "_backbone", None) is None:
raise ValueError("self._backbone must be set before calling _init_timm_backbone")
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
# These will disagree with the defaults for the transformers models e.g. for resnet50
# the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
# the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info]
self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info]
out_indices = self._backbone.feature_info.out_indices
out_features = self._backbone.feature_info.module_name()
def forward(
self,
pixel_values,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
raise NotImplementedError("This method should be implemented by the derived class.")
# We verify the out indices and out features are valid
verify_out_features_out_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
self._out_features, self._out_indices = out_features, out_indices
def _init_transformers_backbone(self, config) -> None:
stage_names = getattr(config, "stage_names")
out_features = getattr(config, "out_features", None)
out_indices = getattr(config, "out_indices", None)
self.stage_names = stage_names
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=stage_names
)
# Number of channels for each stage. This is set in the transformer backbone model init
self.num_features = None
def _init_backbone(self, config) -> None:
"""
Method to initialize the backbone. This method is called by the constructor of the base class after the
pretrained model weights have been loaded.
"""
self.config = config
self.use_timm_backbone = getattr(config, "use_timm_backbone", False)
self.backbone_type = BackboneType.TIMM if self.use_timm_backbone else BackboneType.TRANSFORMERS
if self.backbone_type == BackboneType.TIMM:
self._init_timm_backbone(config)
elif self.backbone_type == BackboneType.TRANSFORMERS:
self._init_transformers_backbone(config)
else:
raise ValueError(f"backbone_type {self.backbone_type} not supported.")
@property
def out_features(self):
@ -160,6 +194,40 @@ class BackboneMixin:
out_features=None, out_indices=out_indices, stage_names=self.stage_names
)
@property
def out_feature_channels(self):
# the current backbones will output the number of channels for each stage
# even if that stage is not in the out_features list.
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
def forward(
self,
pixel_values,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
raise NotImplementedError("This method should be implemented by the derived class.")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
include the `out_features` and `out_indices` attributes.
"""
output = super().to_dict()
output["out_features"] = output.pop("_out_features")
output["out_indices"] = output.pop("_out_indices")
return output
class BackboneConfigMixin:
"""

View File

@ -6806,6 +6806,13 @@ class TimesformerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class TimmBackbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@ -45,6 +45,7 @@ if is_torch_available():
from test_module.custom_modeling import CustomModel
from transformers import (
AutoBackbone,
AutoConfig,
AutoModel,
AutoModelForCausalLM,
@ -66,11 +67,13 @@ if is_torch_available():
FunnelModel,
GPT2Config,
GPT2LMHeadModel,
ResNetBackbone,
RobertaForMaskedLM,
T5Config,
T5ForConditionalGeneration,
TapasConfig,
TapasForQuestionAnswering,
TimmBackbone,
)
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING,
@ -224,6 +227,42 @@ class AutoModelTest(unittest.TestCase):
self.assertIsNotNone(model)
self.assertIsInstance(model, BertForTokenClassification)
@slow
def test_auto_backbone_timm_model_from_pretrained(self):
# Configs can't be loaded for timm models
model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)
with pytest.raises(ValueError):
# We can't pass output_loading_info=True as we're loading from timm
AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)
self.assertIsNotNone(model)
self.assertIsInstance(model, TimmBackbone)
# Check kwargs are correctly passed to the backbone
model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
self.assertEqual(model.out_indices, (-1, -2))
# Check out_features cannot be passed to Timm backbones
with self.assertRaises(ValueError):
_ = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])
@slow
def test_auto_backbone_from_pretrained(self):
model = AutoBackbone.from_pretrained("microsoft/resnet-18")
model, loading_info = AutoBackbone.from_pretrained("microsoft/resnet-18", output_loading_info=True)
self.assertIsNotNone(model)
self.assertIsInstance(model, ResNetBackbone)
# Check kwargs are correctly passed to the backbone
model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2])
self.assertEqual(model.out_indices, [-1, -2])
self.assertEqual(model.out_features, ["stage4", "stage3"])
model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"])
self.assertEqual(model.out_indices, [2, 4])
self.assertEqual(model.out_features, ["stage2", "stage4"])
def test_from_pretrained_identifier(self):
model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
self.assertIsInstance(model, BertForMaskedLM)

View File

View File

@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import unittest
from transformers import AutoBackbone
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import require_timm, require_torch, torch_device
from transformers.utils.import_utils import is_torch_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor
if is_torch_available():
import torch
from transformers import TimmBackbone, TimmBackboneConfig
class TimmBackboneModelTester:
def __init__(
self,
parent,
out_indices=None,
out_features=None,
stage_names=None,
backbone="resnet50",
batch_size=3,
image_size=32,
num_channels=3,
is_training=True,
use_pretrained_backbone=True,
):
self.parent = parent
self.out_indices = out_indices if out_indices is not None else [4]
self.stage_names = stage_names
self.out_features = out_features
self.backbone = backbone
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.use_pretrained_backbone = use_pretrained_backbone
self.is_training = is_training
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return TimmBackboneConfig(
image_size=self.image_size,
num_channels=self.num_channels,
out_features=self.out_features,
out_indices=self.out_indices,
stage_names=self.stage_names,
use_pretrained_backbone=self.use_pretrained_backbone,
backbone=self.backbone,
)
def create_and_check_model(self, config, pixel_values):
model = TimmBackbone(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
self.parent.assertEqual(
result.feature_map[-1].shape,
(self.batch_size, model.channels[-1], 14, 14),
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
@require_timm
class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase):
all_model_classes = (TimmBackbone,) if is_torch_available() else ()
test_resize_embeddings = False
test_head_masking = False
test_pruning = False
has_attentions = False
def setUp(self):
self.model_tester = TimmBackboneModelTester(self)
self.config_tester = ConfigTester(self, config_class=PretrainedConfig, has_text_modality=False)
def test_config(self):
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def test_timm_transformer_backbone_equivalence(self):
timm_checkpoint = "resnet18"
transformers_checkpoint = "microsoft/resnet-18"
timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True)
transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint)
self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
self.assertEqual(len(timm_model.stage_names), len(transformers_model.stage_names))
self.assertEqual(timm_model.channels, transformers_model.channels)
# Out indices are set to the last layer by default. For timm models, we don't know
# the number of layers in advance, so we set it to (-1,), whereas for transformers
# models, we set it to [len(stage_names) - 1] (kept for backward compatibility).
self.assertEqual(timm_model.out_indices, (-1,))
self.assertEqual(transformers_model.out_indices, [len(timm_model.stage_names) - 1])
timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True, out_indices=[1, 2, 3])
transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint, out_indices=[1, 2, 3])
self.assertEqual(timm_model.out_indices, transformers_model.out_indices)
self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
self.assertEqual(timm_model.channels, transformers_model.channels)
@unittest.skip("TimmBackbone doesn't support feed forward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip("TimmBackbone doesn't have num_hidden_layers attribute")
def test_hidden_states_output(self):
pass
@unittest.skip("TimmBackbone initialization is managed on the timm side")
def test_initialization(self):
pass
@unittest.skip("TimmBackbone models doesn't have inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("TimmBackbone models doesn't have inputs_embeds")
def test_model_common_attributes(self):
pass
@unittest.skip("TimmBackbone model cannot be created without specifying a backbone checkpoint")
def test_from_pretrained_no_checkpoint(self):
pass
@unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
def test_save_load(self):
pass
@unittest.skip("model weights aren't tied in TimmBackbone.")
def test_tie_model_weights(self):
pass
@unittest.skip("model weights aren't tied in TimmBackbone.")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.")
def test_channels(self):
pass
@unittest.skip("TimmBackbone doesn't support output_attentions.")
def test_torchscript_output_attentions(self):
pass
@unittest.skip("Safetensors is not supported by timm.")
def test_can_use_safetensors(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
output = outputs[0][-1]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
hidden_states.retain_grad()
if self.has_attentions:
attentions = outputs.attentions[0]
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
if self.has_attentions:
self.assertIsNotNone(attentions.grad)
# TimmBackbone config doesn't have out_features attribute
def test_create_from_modified_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
self.assertEqual(len(result.feature_maps), len(config.out_indices))
self.assertEqual(len(model.channels), len(config.out_indices))
# Check output of last stage is taken if out_features=None, out_indices=None
modified_config = copy.deepcopy(config)
modified_config.out_indices = None
model = model_class(modified_config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
self.assertEqual(len(result.feature_maps), 1)
self.assertEqual(len(model.channels), 1)
# Check backbone can be initialized with fresh weights
modified_config = copy.deepcopy(config)
modified_config.use_pretrained_backbone = False
model = model_class(modified_config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)

View File

@ -17,6 +17,7 @@ import copy
import inspect
from transformers.testing_utils import require_torch, torch_device
from transformers.utils.backbone_utils import BackboneType
@require_torch
@ -104,6 +105,8 @@ class BackboneTesterMixin:
self.assertEqual(len(result.feature_maps), len(config.out_features))
self.assertEqual(len(model.channels), len(config.out_features))
self.assertEqual(len(result.feature_maps), len(config.out_indices))
self.assertEqual(len(model.channels), len(config.out_indices))
# Check output of last stage is taken if out_features=None, out_indices=None
modified_config = copy.deepcopy(config)
@ -140,6 +143,7 @@ class BackboneTesterMixin:
for backbone_class in self.all_model_classes:
backbone = backbone_class(config)
self.assertTrue(hasattr(backbone, "backbone_type"))
self.assertTrue(hasattr(backbone, "stage_names"))
self.assertTrue(hasattr(backbone, "num_features"))
self.assertTrue(hasattr(backbone, "out_indices"))
@ -147,6 +151,7 @@ class BackboneTesterMixin:
self.assertTrue(hasattr(backbone, "out_feature_channels"))
self.assertTrue(hasattr(backbone, "channels"))
self.assertIsInstance(backbone.backbone_type, BackboneType)
# Verify num_features has been initialized in the backbone init
self.assertIsNotNone(backbone.num_features)
self.assertTrue(len(backbone.channels) == len(backbone.out_indices))

View File

@ -77,6 +77,7 @@ SPECIAL_CASES_TO_ALLOW = {
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
}
# TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
SPECIAL_CASES_TO_ALLOW.update(
{
@ -172,6 +173,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"mask_index",
"image_size",
"use_cache",
"out_features",
"out_indices",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]

View File

@ -39,6 +39,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
"EncoderDecoderConfig",
"RagConfig",
"SpeechEncoderDecoderConfig",
"TimmBackboneConfig",
"VisionEncoderDecoderConfig",
"VisionTextDualEncoderConfig",
"LlamaConfig",

View File

@ -517,6 +517,7 @@ MODELS_NOT_IN_README = [
"Speech Encoder decoder",
"Speech2Text",
"Speech2Text2",
"TimmBackbone",
"Vision Encoder decoder",
"VisionTextDualEncoder",
]

View File

@ -408,6 +408,7 @@ def get_model_modules():
"modeling_speech_encoder_decoder",
"modeling_flax_speech_encoder_decoder",
"modeling_flax_vision_encoder_decoder",
"modeling_timm_backbone",
"modeling_transfo_xl_utilities",
"modeling_tf_auto",
"modeling_tf_encoder_decoder",
@ -846,6 +847,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
"NatBackbone",
"ResNetBackbone",
"SwinBackbone",
"TimmBackbone",
"TimmBackboneConfig",
]