From a717e0318ce4b4973cfebee44d1747f5ca828ac2 Mon Sep 17 00:00:00 2001 From: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 6 Jun 2023 17:11:30 +0100 Subject: [PATCH] Add TimmBackbone model (#22619) * Add test_backbone for convnext * Add TimmBackbone model * Add check for backbone type * Tidying up - config checks * Update convnextv2 * Tidy up * Fix indices & clearer comment * Exceptions for config checks * Correclty update config for tests * Safer imports * Safer safer imports * Fix where decorators go * Update import logic and backbone tests * More import fixes * Fixup * Only import all_models if torch available * Fix kwarg updates in from_pretrained & main rebase * Tidy up * Add tests for AutoBackbone * Tidy up * Fix import error * Fix up * Install nattan in doc_test_job * Revert back to setting self._out_xxx directly * Bug fix - out_indices mapping from out_features * Fix tests * Dont accept output_loading_info for Timm models * Set out_xxx and don't remap * Use smaller checkpoint for test * Don't remap timm indices - check out_indices based on stage names * Skip test as it's n/a * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Cleaner imports / spelling is hard --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .circleci/create_circleci_config.py | 3 +- docs/source/en/index.mdx | 1 + src/transformers/__init__.py | 4 + src/transformers/models/__init__.py | 1 + src/transformers/models/auto/auto_factory.py | 44 ++- .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 6 +- src/transformers/models/bit/modeling_bit.py | 8 +- .../models/convnext/modeling_convnext.py | 8 +- .../models/convnextv2/modeling_convnextv2.py | 8 +- .../models/dinat/modeling_dinat.py | 9 +- .../models/focalnet/modeling_focalnet.py | 12 +- .../maskformer/modeling_maskformer_swin.py | 10 +- src/transformers/models/nat/modeling_nat.py | 9 +- .../models/resnet/modeling_resnet.py | 10 +- src/transformers/models/swin/modeling_swin.py | 11 +- .../models/timm_backbone/__init__.py | 49 ++++ .../configuration_timm_backbone.py | 78 ++++++ .../timm_backbone/modeling_timm_backbone.py | 140 ++++++++++ src/transformers/utils/backbone_utils.py | 110 ++++++-- src/transformers/utils/dummy_pt_objects.py | 7 + tests/models/auto/test_modeling_auto.py | 39 +++ tests/models/timm_backbone/__init__.py | 0 .../test_modeling_timm_backbone.py | 259 ++++++++++++++++++ tests/test_backbone_common.py | 5 + utils/check_config_attributes.py | 3 + utils/check_config_docstrings.py | 1 + utils/check_copies.py | 1 + utils/check_repo.py | 3 + 29 files changed, 753 insertions(+), 88 deletions(-) create mode 100644 src/transformers/models/timm_backbone/__init__.py create mode 100644 src/transformers/models/timm_backbone/configuration_timm_backbone.py create mode 100644 src/transformers/models/timm_backbone/modeling_timm_backbone.py create mode 100644 tests/models/timm_backbone/__init__.py create mode 100644 tests/models/timm_backbone/test_modeling_timm_backbone.py diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index 3f7da48770e..2a09d5f8542 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -128,7 +128,7 @@ class CircleCIJob: if self.command_timeout: test_command = f"timeout {self.command_timeout} " test_command += f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags) - + if self.parallelism 
== 1: if self.tests_to_run is None: test_command += " << pipeline.parameters.tests_to_run >>" @@ -456,6 +456,7 @@ doc_test_job = CircleCIJob( "pip install -e .[dev]", "pip install git+https://github.com/huggingface/accelerate", "pip install --upgrade pytest pytest-sugar", + "pip install natten", "find -name __pycache__ -delete", "find . -name \*.pyc -delete", # Add an empty file to keep the test step running correctly even no file is selected to be tested. diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index ee47a3ab5e6..6075c0ecf3d 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -424,6 +424,7 @@ Flax), PyTorch, and/or TensorFlow. | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ | | Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | | TimeSformer | ❌ | ❌ | ✅ | ❌ | ❌ | +| TimmBackbone | ❌ | ❌ | ❌ | ❌ | ❌ | | Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ | | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ | | TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index edb15c037b3..c7deecf7552 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -483,6 +483,7 @@ _import_structure = { "TimeSeriesTransformerConfig", ], "models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"], + "models.timm_backbone": ["TimmBackboneConfig"], "models.trajectory_transformer": [ "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TrajectoryTransformerConfig", @@ -2578,6 +2579,7 @@ else: "TimesformerPreTrainedModel", ] ) + _import_structure["models.timm_backbone"].extend(["TimmBackbone"]) _import_structure["models.trajectory_transformer"].extend( [ "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -4288,6 +4290,7 @@ if TYPE_CHECKING: TimeSeriesTransformerConfig, ) from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig + from .models.timm_backbone import TimmBackboneConfig from .models.trajectory_transformer import ( TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TrajectoryTransformerConfig, @@ -6024,6 +6027,7 @@ if TYPE_CHECKING: TimesformerModel, TimesformerPreTrainedModel, ) + from .models.timm_backbone import TimmBackbone from .models.trajectory_transformer import ( TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, TrajectoryTransformerModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 26e4643728c..0ee88e0ba39 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -186,6 +186,7 @@ from . import ( tapex, time_series_transformer, timesformer, + timm_backbone, trajectory_transformer, transfo_xl, trocr, diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index eedecb0da9c..89193411664 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -19,7 +19,7 @@ from collections import OrderedDict from ...configuration_utils import PretrainedConfig from ...dynamic_module_utils import get_class_from_dynamic_module -from ...utils import copy_func, logging +from ...utils import copy_func, logging, requires_backends from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings @@ -515,6 +515,48 @@ class _BaseAutoModelClass: cls._model_mapping.register(config_class, model_class) +class _BaseAutoBackboneClass(_BaseAutoModelClass): + # Base class for auto backbone models. 
+ _model_mapping = None + + @classmethod + def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_backbone import TimmBackboneConfig + + config = kwargs.pop("config", TimmBackboneConfig()) + + use_timm = kwargs.pop("use_timm_backbone", True) + if not use_timm: + raise ValueError("`use_timm_backbone` must be `True` for timm backbones") + + if kwargs.get("out_features", None) is not None: + raise ValueError("Cannot specify `out_features` for timm backbones") + + if kwargs.get("output_loading_info", False): + raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") + + num_channels = kwargs.pop("num_channels", config.num_channels) + features_only = kwargs.pop("features_only", config.features_only) + use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone) + out_indices = kwargs.pop("out_indices", config.out_indices) + config = TimmBackboneConfig( + backbone=pretrained_model_name_or_path, + num_channels=num_channels, + features_only=features_only, + use_pretrained_backbone=use_pretrained_backbone, + out_indices=out_indices, + ) + return super().from_config(config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + if kwargs.get("use_timm_backbone", False): + return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + def insert_head_doc(docstring, head_doc=""): if len(head_doc) > 0: return docstring.replace( diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 1b82bc40da1..197d2b2e14b 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -186,6 +186,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ("tapas", "TapasConfig"), ("time_series_transformer", "TimeSeriesTransformerConfig"), ("timesformer", "TimesformerConfig"), + ("timm_backbone", "TimmBackboneConfig"), ("trajectory_transformer", "TrajectoryTransformerConfig"), ("transfo-xl", "TransfoXLConfig"), ("trocr", "TrOCRConfig"), @@ -579,6 +580,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ("tapex", "TAPEX"), ("time_series_transformer", "Time Series Transformer"), ("timesformer", "TimeSformer"), + ("timm_backbone", "TimmBackbone"), ("trajectory_transformer", "Trajectory Transformer"), ("transfo-xl", "Transformer-XL"), ("trocr", "TrOCR"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0a10eb96e9d..2c23da3c556 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -18,7 +18,7 @@ import warnings from collections import OrderedDict from ...utils import logging -from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update +from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update from .configuration_auto import CONFIG_MAPPING_NAMES @@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict( ("tapas", "TapasModel"), ("time_series_transformer", "TimeSeriesTransformerModel"), ("timesformer", "TimesformerModel"), + ("timm_backbone", "TimmBackbone"), ("trajectory_transformer", "TrajectoryTransformerModel"), ("transfo-xl", "TransfoXLModel"), ("tvlt", "TvltModel"), @@ -999,6 
+1000,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict( ("nat", "NatBackbone"), ("resnet", "ResNetBackbone"), ("swin", "SwinBackbone"), + ("timm_backbone", "TimmBackbone"), ] ) @@ -1330,7 +1332,7 @@ class AutoModelForAudioXVector(_BaseAutoModelClass): _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING -class AutoBackbone(_BaseAutoModelClass): +class AutoBackbone(_BaseAutoBackboneClass): _model_mapping = MODEL_FOR_BACKBONE_MAPPING diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index d440f180757..284ff5e2de8 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -39,7 +39,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_bit import BitConfig @@ -845,14 +845,10 @@ class BitForImageClassification(BitPreTrainedModel): class BitBackbone(BitPreTrainedModel, BackboneMixin): def __init__(self, config): super().__init__(config) + super()._init_backbone(config) - self.stage_names = config.stage_names self.bit = BitModel(config) - self.num_features = [config.embedding_size] + config.hidden_sizes - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) # initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index 1748e68aeec..3733fb94140 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -37,7 +37,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_convnext import ConvNextConfig @@ -481,15 +481,11 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel): class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin): def __init__(self, config): super().__init__(config) + super()._init_backbone(config) - self.stage_names = config.stage_names self.embeddings = ConvNextEmbeddings(config) self.encoder = ConvNextEncoder(config) - self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) # Add layer norms to hidden states of out_features hidden_states_norms = {} diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index c4cac4eb39f..70c35a85af3 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -37,7 +37,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_convnextv2 import ConvNextV2Config @@ -504,15 +504,11 @@ class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel): class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin): def __init__(self, config): super().__init__(config) + super()._init_backbone(config) - 
self.stage_names = config.stage_names self.embeddings = ConvNextV2Embeddings(config) self.encoder = ConvNextV2Encoder(config) - self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) # Add layer norms to hidden states of out_features hidden_states_norms = {} diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index 7e3809c1a30..b15d7d187ed 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -39,7 +39,7 @@ from ...utils import ( replace_return_docstrings, requires_backends, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_dinat import DinatConfig @@ -883,17 +883,12 @@ class DinatForImageClassification(DinatPreTrainedModel): class DinatBackbone(DinatPreTrainedModel, BackboneMixin): def __init__(self, config): super().__init__(config) + super()._init_backbone(config) requires_backends(self, ["natten"]) - self.stage_names = config.stage_names - self.embeddings = DinatEmbeddings(config) self.encoder = DinatEncoder(config) - - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] # Add layer norms to hidden states of out_features diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index e7ebdda5e5d..fc327ad0b39 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -36,7 +36,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_focalnet import FocalNetConfig @@ -981,16 +981,12 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel): FOCALNET_START_DOCSTRING, ) class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin): - def __init__(self, config): + def __init__(self, config: FocalNetConfig): super().__init__(config) - - self.stage_names = config.stage_names - self.focalnet = FocalNetModel(config) + super()._init_backbone(config) self.num_features = [config.embed_dim] + config.hidden_sizes - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) + self.focalnet = FocalNetModel(config) # initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index c7b74a6f2bd..7016b598e85 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -29,7 +29,7 @@ from ...file_utils import ModelOutput from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer -from ...utils.backbone_utils import BackboneMixin, 
get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_maskformer_swin import MaskFormerSwinConfig @@ -852,17 +852,11 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin): def __init__(self, config: MaskFormerSwinConfig): super().__init__(config) + super()._init_backbone(config) - self.stage_names = config.stage_names self.model = MaskFormerSwinModel(config) - - self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]] if "stem" in self.out_features: raise ValueError("This backbone does not support 'stem' in the `out_features`.") - - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] self.hidden_states_norms = nn.ModuleList( [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]] diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index 7634a08ad95..2293661f2b4 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -39,7 +39,7 @@ from ...utils import ( replace_return_docstrings, requires_backends, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_nat import NatConfig @@ -861,17 +861,12 @@ class NatForImageClassification(NatPreTrainedModel): class NatBackbone(NatPreTrainedModel, BackboneMixin): def __init__(self, config): super().__init__(config) + super()._init_backbone(config) requires_backends(self, ["natten"]) - self.stage_names = config.stage_names - self.embeddings = NatEmbeddings(config) self.encoder = NatEncoder(config) - - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] # Add layer norms to hidden states of out_features diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index b177cdeda6c..207a0d5196a 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -36,7 +36,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_resnet import ResNetConfig @@ -432,16 +432,12 @@ class ResNetForImageClassification(ResNetPreTrainedModel): class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin): def __init__(self, config): super().__init__(config) + super()._init_backbone(config) - self.stage_names = config.stage_names + self.num_features = [config.embedding_size] + config.hidden_sizes self.embedder = ResNetEmbeddings(config) self.encoder = ResNetEncoder(config) - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) - self.num_features = [config.embedding_size] + config.hidden_sizes - # initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/swin/modeling_swin.py 
b/src/transformers/models/swin/modeling_swin.py index 6482ff1b5bf..b324cfdcd93 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -38,7 +38,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices +from ...utils.backbone_utils import BackboneMixin from .configuration_swin import SwinConfig @@ -1259,17 +1259,12 @@ class SwinForImageClassification(SwinPreTrainedModel): class SwinBackbone(SwinPreTrainedModel, BackboneMixin): def __init__(self, config: SwinConfig): super().__init__(config) + super()._init_backbone(config) - self.stage_names = config.stage_names - + self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] self.embeddings = SwinEmbeddings(config) self.encoder = SwinEncoder(config, self.embeddings.patch_grid) - self._out_features, self._out_indices = get_aligned_output_features_output_indices( - config.out_features, config.out_indices, self.stage_names - ) - self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))] - # Add layer norms to hidden states of out_features hidden_states_norms = {} for stage, num_channels in zip(self._out_features, self.channels): diff --git a/src/transformers/models/timm_backbone/__init__.py b/src/transformers/models/timm_backbone/__init__.py new file mode 100644 index 00000000000..4c692f76432 --- /dev/null +++ b/src/transformers/models/timm_backbone/__init__.py @@ -0,0 +1,49 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + + +_import_structure = {"configuration_timm_backbone": ["TimmBackboneConfig"]} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_timm_backbone"] = ["TimmBackbone"] + + +if TYPE_CHECKING: + from .configuration_timm_backbone import TimmBackboneConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_timm_backbone import TimmBackbone + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/timm_backbone/configuration_timm_backbone.py b/src/transformers/models/timm_backbone/configuration_timm_backbone.py new file mode 100644 index 00000000000..19bfcbebf62 --- /dev/null +++ b/src/transformers/models/timm_backbone/configuration_timm_backbone.py @@ -0,0 +1,78 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Configuration for Backbone models""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class TimmBackboneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`]. + + It is used to instantiate a timm backbone model according to the specified arguments, defining the model. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone (`str`, *optional*): + The timm checkpoint to load. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + features_only (`bool`, *optional*, defaults to `True`): + Whether to output only the features or also the logits. + use_pretrained_backbone (`bool`, *optional*, defaults to `True`): + Whether to use a pretrained backbone. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). Will default to the last stage if unset. 
+ + Example: + ```python + >>> from transformers import TimmBackboneConfig, TimmBackbone + + >>> # Initializing a timm backbone + >>> configuration = TimmBackboneConfig("resnet50") + + >>> # Initializing a model from the configuration + >>> model = TimmBackbone(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + model_type = "timm_backbone" + + def __init__( + self, + backbone=None, + num_channels=3, + features_only=True, + use_pretrained_backbone=True, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + self.backbone = backbone + self.num_channels = num_channels + self.features_only = features_only + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = True + self.out_indices = out_indices if out_indices is not None else (-1,) diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py new file mode 100644 index 00000000000..2dcfca7e538 --- /dev/null +++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union + +from ...modeling_outputs import BackboneOutput +from ...modeling_utils import PreTrainedModel +from ...utils import is_timm_available, is_torch_available, requires_backends +from ...utils.backbone_utils import BackboneMixin +from .configuration_timm_backbone import TimmBackboneConfig + + +if is_timm_available(): + import timm + + +if is_torch_available(): + from torch import Tensor + + +class TimmBackbone(PreTrainedModel, BackboneMixin): + """ + Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the + other models in the library keeping the same API. + """ + + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + config_class = TimmBackboneConfig + + def __init__(self, config, **kwargs): + requires_backends(self, "timm") + super().__init__(config) + self.config = config + + if config.backbone is None: + raise ValueError("backbone is not set in the config. Please set it to a timm model name.") + + if config.backbone not in timm.list_models(): + raise ValueError(f"backbone {config.backbone} is not supported by timm.") + + if hasattr(config, "out_features") and config.out_features is not None: + raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.") + + pretrained = getattr(config, "use_pretrained_backbone", None) + if pretrained is None: + raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.") + + # We just take the final layer by default. This matches the default for the transformers models. 
+ out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,) + + self._backbone = timm.create_model( + config.backbone, + pretrained=pretrained, + # This is currently not possible for transformer architectures. + features_only=config.features_only, + in_chans=config.num_channels, + out_indices=out_indices, + **kwargs, + ) + # These are used to control the output of the model when called. If output_hidden_states is True, then + # return_layers is modified to include all layers. + self._return_layers = self._backbone.return_layers + self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)} + super()._init_backbone(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + requires_backends(cls, ["vision", "timm"]) + from ...models.timm_backbone import TimmBackboneConfig + + config = kwargs.pop("config", TimmBackboneConfig()) + + use_timm = kwargs.pop("use_timm_backbone", True) + if not use_timm: + raise ValueError("use_timm_backbone must be True for timm backbones") + + num_channels = kwargs.pop("num_channels", config.num_channels) + features_only = kwargs.pop("features_only", config.features_only) + use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone) + out_indices = kwargs.pop("out_indices", config.out_indices) + config = TimmBackboneConfig( + backbone=pretrained_model_name_or_path, + num_channels=num_channels, + features_only=features_only, + use_pretrained_backbone=use_pretrained_backbone, + out_indices=out_indices, + ) + return super()._from_config(config, **kwargs) + + def _init_weights(self, module): + """ + Empty init weights function to ensure compatibility of the class in the library. 
+ """ + pass + + def forward( + self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs + ) -> Union[BackboneOutput, Tuple[Tensor, ...]]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + if output_attentions: + raise ValueError("Cannot output attentions for timm backbones at the moment") + + if output_hidden_states: + # We modify the return layers to include all the stages of the backbone + self._backbone.return_layers = self._all_layers + hidden_states = self._backbone(pixel_values, **kwargs) + self._backbone.return_layers = self._return_layers + feature_maps = tuple(hidden_states[i] for i in self.out_indices) + else: + feature_maps = self._backbone(pixel_values, **kwargs) + hidden_states = None + + feature_maps = tuple(feature_maps) + hidden_states = tuple(hidden_states) if hidden_states is not None else None + + if not return_dict: + output = (feature_maps,) + if output_hidden_states: + output = output + (hidden_states,) + return output + + return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None) diff --git a/src/transformers/utils/backbone_utils.py b/src/transformers/utils/backbone_utils.py index 8c6b7107eb0..595aae18832 100644 --- a/src/transformers/utils/backbone_utils.py +++ b/src/transformers/utils/backbone_utils.py @@ -15,10 +15,16 @@ """ Collection of utils to be used by backbones and their components.""" +import enum import inspect from typing import Iterable, List, Optional, Tuple, Union +class BackboneType(enum.Enum): + TIMM = "timm" + TRANSFORMERS = "transformers" + + def verify_out_features_out_indices( out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]] ): @@ -72,7 +78,7 @@ def _align_output_features_output_indices( out_indices = [len(stage_names) - 1] out_features = [stage_names[-1]] elif out_indices is None and out_features is not None: - out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features] + out_indices = [stage_names.index(layer) for layer in out_features] elif out_features is None and out_indices is not None: out_features = [stage_names[idx] for idx in out_indices] return out_features, out_indices @@ -110,29 +116,57 @@ def get_aligned_output_features_output_indices( class BackboneMixin: - @property - def out_feature_channels(self): - # the current backbones will output the number of channels for each stage - # even if that stage is not in the out_features list. 
- return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)} + backbone_type: Optional[BackboneType] = None - @property - def channels(self): - return [self.out_feature_channels[name] for name in self.out_features] + def _init_timm_backbone(self, config) -> None: + """ + Initialize the backbone model from timm The backbone must already be loaded to self._backbone + """ + if getattr(self, "_backbone", None) is None: + raise ValueError("self._backbone must be set before calling _init_timm_backbone") - def forward_with_filtered_kwargs(self, *args, **kwargs): - signature = dict(inspect.signature(self.forward).parameters) - filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature} - return self(*args, **filtered_kwargs) + # These will diagree with the defaults for the transformers models e.g. for resnet50 + # the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4'] + # the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4'] + self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info] + self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info] + out_indices = self._backbone.feature_info.out_indices + out_features = self._backbone.feature_info.module_name() - def forward( - self, - pixel_values, - output_hidden_states: Optional[bool] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - raise NotImplementedError("This method should be implemented by the derived class.") + # We verify the out indices and out features are valid + verify_out_features_out_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + self._out_features, self._out_indices = out_features, out_indices + + def _init_transformers_backbone(self, config) -> None: + stage_names = getattr(config, "stage_names") + out_features = getattr(config, "out_features", None) + out_indices = getattr(config, "out_indices", None) + + self.stage_names = stage_names + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=stage_names + ) + # Number of channels for each stage. This is set in the transformer backbone model init + self.num_features = None + + def _init_backbone(self, config) -> None: + """ + Method to initialize the backbone. This method is called by the constructor of the base class after the + pretrained model weights have been loaded. + """ + self.config = config + + self.use_timm_backbone = getattr(config, "use_timm_backbone", False) + self.backbone_type = BackboneType.TIMM if self.use_timm_backbone else BackboneType.TRANSFORMERS + + if self.backbone_type == BackboneType.TIMM: + self._init_timm_backbone(config) + elif self.backbone_type == BackboneType.TRANSFORMERS: + self._init_transformers_backbone(config) + else: + raise ValueError(f"backbone_type {self.backbone_type} not supported.") @property def out_features(self): @@ -160,6 +194,40 @@ class BackboneMixin: out_features=None, out_indices=out_indices, stage_names=self.stage_names ) + @property + def out_feature_channels(self): + # the current backbones will output the number of channels for each stage + # even if that stage is not in the out_features list. 
+ return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)} + + @property + def channels(self): + return [self.out_feature_channels[name] for name in self.out_features] + + def forward_with_filtered_kwargs(self, *args, **kwargs): + signature = dict(inspect.signature(self.forward).parameters) + filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature} + return self(*args, **filtered_kwargs) + + def forward( + self, + pixel_values, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + raise NotImplementedError("This method should be implemented by the derived class.") + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to + include the `out_features` and `out_indices` attributes. + """ + output = super().to_dict() + output["out_features"] = output.pop("_out_features") + output["out_indices"] = output.pop("_out_indices") + return output + class BackboneConfigMixin: """ diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 069edb3ba85..1eba8c0ca75 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6806,6 +6806,13 @@ class TimesformerPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["torch"]) +class TimmBackbone(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 26eecd54299..9c788a61556 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -45,6 +45,7 @@ if is_torch_available(): from test_module.custom_modeling import CustomModel from transformers import ( + AutoBackbone, AutoConfig, AutoModel, AutoModelForCausalLM, @@ -66,11 +67,13 @@ if is_torch_available(): FunnelModel, GPT2Config, GPT2LMHeadModel, + ResNetBackbone, RobertaForMaskedLM, T5Config, T5ForConditionalGeneration, TapasConfig, TapasForQuestionAnswering, + TimmBackbone, ) from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING, @@ -224,6 +227,42 @@ class AutoModelTest(unittest.TestCase): self.assertIsNotNone(model) self.assertIsInstance(model, BertForTokenClassification) + @slow + def test_auto_backbone_timm_model_from_pretrained(self): + # Configs can't be loaded for timm models + model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True) + + with pytest.raises(ValueError): + # We can't pass output_loading_info=True as we're loading from timm + AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True) + + self.assertIsNotNone(model) + self.assertIsInstance(model, TimmBackbone) + + # Check kwargs are correctly passed to the backbone + model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2)) + self.assertEqual(model.out_indices, (-1, -2)) + + # Check out_features cannot be passed to Timm backbones + with self.assertRaises(ValueError): + _ = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"]) + + @slow + def test_auto_backbone_from_pretrained(self): + model = AutoBackbone.from_pretrained("microsoft/resnet-18") + model, loading_info = AutoBackbone.from_pretrained("microsoft/resnet-18", 
output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, ResNetBackbone) + + # Check kwargs are correctly passed to the backbone + model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2]) + self.assertEqual(model.out_indices, [-1, -2]) + self.assertEqual(model.out_features, ["stage4", "stage3"]) + + model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"]) + self.assertEqual(model.out_indices, [2, 4]) + self.assertEqual(model.out_features, ["stage2", "stage4"]) + def test_from_pretrained_identifier(self): model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER) self.assertIsInstance(model, BertForMaskedLM) diff --git a/tests/models/timm_backbone/__init__.py b/tests/models/timm_backbone/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py new file mode 100644 index 00000000000..f58716e0f2f --- /dev/null +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -0,0 +1,259 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import unittest + +from transformers import AutoBackbone +from transformers.configuration_utils import PretrainedConfig +from transformers.testing_utils import require_timm, require_torch, torch_device +from transformers.utils.import_utils import is_torch_available + +from ...test_backbone_common import BackboneTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor + + +if is_torch_available(): + import torch + + from transformers import TimmBackbone, TimmBackboneConfig + + +class TimmBackboneModelTester: + def __init__( + self, + parent, + out_indices=None, + out_features=None, + stage_names=None, + backbone="resnet50", + batch_size=3, + image_size=32, + num_channels=3, + is_training=True, + use_pretrained_backbone=True, + ): + self.parent = parent + self.out_indices = out_indices if out_indices is not None else [4] + self.stage_names = stage_names + self.out_features = out_features + self.backbone = backbone + self.batch_size = batch_size + self.image_size = image_size + self.num_channels = num_channels + self.use_pretrained_backbone = use_pretrained_backbone + self.is_training = is_training + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) + config = self.get_config() + + return config, pixel_values + + def get_config(self): + return TimmBackboneConfig( + image_size=self.image_size, + num_channels=self.num_channels, + out_features=self.out_features, + out_indices=self.out_indices, + stage_names=self.stage_names, + use_pretrained_backbone=self.use_pretrained_backbone, + backbone=self.backbone, + ) + + def create_and_check_model(self, config, 
pixel_values): + model = TimmBackbone(config=config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + result = model(pixel_values) + self.parent.assertEqual( + result.feature_map[-1].shape, + (self.batch_size, model.channels[-1], 14, 14), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +@require_timm +class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase): + all_model_classes = (TimmBackbone,) if is_torch_available() else () + test_resize_embeddings = False + test_head_masking = False + test_pruning = False + has_attentions = False + + def setUp(self): + self.model_tester = TimmBackboneModelTester(self) + self.config_tester = ConfigTester(self, config_class=PretrainedConfig, has_text_modality=False) + + def test_config(self): + self.config_tester.create_and_test_config_to_json_string() + self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + def test_timm_transformer_backbone_equivalence(self): + timm_checkpoint = "resnet18" + transformers_checkpoint = "microsoft/resnet-18" + + timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True) + transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint) + + self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features)) + self.assertEqual(len(timm_model.stage_names), len(transformers_model.stage_names)) + self.assertEqual(timm_model.channels, transformers_model.channels) + # Out indices are set to the last layer by default. For timm models, we don't know + # the number of layers in advance, so we set it to (-1,), whereas for transformers + # models, we set it to [len(stage_names) - 1] (kept for backward compatibility). 
+ self.assertEqual(timm_model.out_indices, (-1,)) + self.assertEqual(transformers_model.out_indices, [len(timm_model.stage_names) - 1]) + + timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True, out_indices=[1, 2, 3]) + transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint, out_indices=[1, 2, 3]) + + self.assertEqual(timm_model.out_indices, transformers_model.out_indices) + self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features)) + self.assertEqual(timm_model.channels, transformers_model.channels) + + @unittest.skip("TimmBackbone doesn't support feed forward chunking") + def test_feed_forward_chunking(self): + pass + + @unittest.skip("TimmBackbone doesn't have num_hidden_layers attribute") + def test_hidden_states_output(self): + pass + + @unittest.skip("TimmBackbone initialization is managed on the timm side") + def test_initialization(self): + pass + + @unittest.skip("TimmBackbone models doesn't have inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip("TimmBackbone models doesn't have inputs_embeds") + def test_model_common_attributes(self): + pass + + @unittest.skip("TimmBackbone model cannot be created without specifying a backbone checkpoint") + def test_from_pretrained_no_checkpoint(self): + pass + + @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone") + def test_save_load(self): + pass + + @unittest.skip("model weights aren't tied in TimmBackbone.") + def test_tie_model_weights(self): + pass + + @unittest.skip("model weights aren't tied in TimmBackbone.") + def test_tied_model_weights_key_ignore(self): + pass + + @unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.") + def test_channels(self): + pass + + @unittest.skip("TimmBackbone doesn't support output_attentions.") + def test_torchscript_output_attentions(self): + pass + + @unittest.skip("Safetensors is not supported by timm.") + def test_can_use_safetensors(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + outputs = model(**inputs) + output = outputs[0][-1] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + hidden_states.retain_grad() + + if self.has_attentions: + attentions = outputs.attentions[0] + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + + if self.has_attentions: + self.assertIsNotNone(attentions.grad) + + # TimmBackbone config doesn't have out_features attribute + def test_create_from_modified_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for 
model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + result = model(**inputs_dict) + + self.assertEqual(len(result.feature_maps), len(config.out_indices)) + self.assertEqual(len(model.channels), len(config.out_indices)) + + # Check output of last stage is taken if out_features=None, out_indices=None + modified_config = copy.deepcopy(config) + modified_config.out_indices = None + model = model_class(modified_config) + model.to(torch_device) + model.eval() + result = model(**inputs_dict) + + self.assertEqual(len(result.feature_maps), 1) + self.assertEqual(len(model.channels), 1) + + # Check backbone can be initialized with fresh weights + modified_config = copy.deepcopy(config) + modified_config.use_pretrained_backbone = False + model = model_class(modified_config) + model.to(torch_device) + model.eval() + result = model(**inputs_dict) diff --git a/tests/test_backbone_common.py b/tests/test_backbone_common.py index fd9bbe3bfbf..1700ab98eed 100644 --- a/tests/test_backbone_common.py +++ b/tests/test_backbone_common.py @@ -17,6 +17,7 @@ import copy import inspect from transformers.testing_utils import require_torch, torch_device +from transformers.utils.backbone_utils import BackboneType @require_torch @@ -104,6 +105,8 @@ class BackboneTesterMixin: self.assertEqual(len(result.feature_maps), len(config.out_features)) self.assertEqual(len(model.channels), len(config.out_features)) + self.assertEqual(len(result.feature_maps), len(config.out_indices)) + self.assertEqual(len(model.channels), len(config.out_indices)) # Check output of last stage is taken if out_features=None, out_indices=None modified_config = copy.deepcopy(config) @@ -140,6 +143,7 @@ class BackboneTesterMixin: for backbone_class in self.all_model_classes: backbone = backbone_class(config) + self.assertTrue(hasattr(backbone, "backbone_type")) self.assertTrue(hasattr(backbone, "stage_names")) self.assertTrue(hasattr(backbone, "num_features")) self.assertTrue(hasattr(backbone, "out_indices")) @@ -147,6 +151,7 @@ class BackboneTesterMixin: self.assertTrue(hasattr(backbone, "out_feature_channels")) self.assertTrue(hasattr(backbone, "channels")) + self.assertIsInstance(backbone.backbone_type, BackboneType) # Verify num_features has been initialized in the backbone init self.assertIsNotNone(backbone.num_features) self.assertTrue(len(backbone.channels) == len(backbone.out_indices)) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 929f3a51b1c..02c3d2276f5 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -77,6 +77,7 @@ SPECIAL_CASES_TO_ALLOW = { "AutoformerConfig": ["num_static_real_features", "num_time_features"], } + # TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure SPECIAL_CASES_TO_ALLOW.update( { @@ -172,6 +173,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "mask_index", "image_size", "use_cache", + "out_features", + "out_indices", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] diff --git a/utils/check_config_docstrings.py b/utils/check_config_docstrings.py index a4a18a66293..93385b127d7 100644 --- a/utils/check_config_docstrings.py +++ b/utils/check_config_docstrings.py @@ -39,6 +39,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = { "EncoderDecoderConfig", "RagConfig", "SpeechEncoderDecoderConfig", + "TimmBackboneConfig", "VisionEncoderDecoderConfig", 
"VisionTextDualEncoderConfig", "LlamaConfig", diff --git a/utils/check_copies.py b/utils/check_copies.py index 5fa0c8bfb98..bd79aa9e86c 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -517,6 +517,7 @@ MODELS_NOT_IN_README = [ "Speech Encoder decoder", "Speech2Text", "Speech2Text2", + "TimmBackbone", "Vision Encoder decoder", "VisionTextDualEncoder", ] diff --git a/utils/check_repo.py b/utils/check_repo.py index 8d1760d1335..12615550a8d 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -408,6 +408,7 @@ def get_model_modules(): "modeling_speech_encoder_decoder", "modeling_flax_speech_encoder_decoder", "modeling_flax_vision_encoder_decoder", + "modeling_timm_backbone", "modeling_transfo_xl_utilities", "modeling_tf_auto", "modeling_tf_encoder_decoder", @@ -846,6 +847,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ "NatBackbone", "ResNetBackbone", "SwinBackbone", + "TimmBackbone", + "TimmBackboneConfig", ]