Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-05 05:40:05 +06:00)
Add TimmBackbone model (#22619)
* Add test_backbone for convnext
* Add TimmBackbone model
* Add check for backbone type
* Tidying up - config checks
* Update convnextv2
* Tidy up
* Fix indices & clearer comment
* Exceptions for config checks
* Correctly update config for tests
* Safer imports
* Safer safer imports
* Fix where decorators go
* Update import logic and backbone tests
* More import fixes
* Fixup
* Only import all_models if torch available
* Fix kwarg updates in from_pretrained & main rebase
* Tidy up
* Add tests for AutoBackbone
* Tidy up
* Fix import error
* Fix up
* Install natten in doc_test_job
* Revert back to setting self._out_xxx directly
* Bug fix - out_indices mapping from out_features
* Fix tests
* Don't accept output_loading_info for Timm models
* Set out_xxx and don't remap
* Use smaller checkpoint for test
* Don't remap timm indices - check out_indices based on stage names
* Skip test as it's n/a
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Cleaner imports / spelling is hard

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent: b8935980a2
commit: a717e0318c
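At a glance, the commit wires a timm wrapper (`TimmBackbone`) into the `AutoBackbone` API. A minimal usage sketch, based on the tests added further down ("resnet18" is a timm checkpoint name, "microsoft/resnet-18" a Hub checkpoint):

```python
from transformers import AutoBackbone

# New path added by this commit: load a timm checkpoint as a backbone.
timm_backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)

# Existing path: load a transformers checkpoint (ResNetBackbone under the hood).
hf_backbone = AutoBackbone.from_pretrained("microsoft/resnet-18")

# Both expose the shared BackboneMixin API.
print(timm_backbone.out_indices)  # (-1,) by default for timm backbones
print(hf_backbone.out_indices)    # [len(stage_names) - 1] for transformers backbones
```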
@@ -456,6 +456,7 @@ doc_test_job = CircleCIJob(
         "pip install -e .[dev]",
         "pip install git+https://github.com/huggingface/accelerate",
         "pip install --upgrade pytest pytest-sugar",
+        "pip install natten",
         "find -name __pycache__ -delete",
         "find . -name \*.pyc -delete",
         # Add an empty file to keep the test step running correctly even no file is selected to be tested.
@@ -424,6 +424,7 @@ Flax), PyTorch, and/or TensorFlow.
 | TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
 | Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
 | TimeSformer | ❌ | ❌ | ✅ | ❌ | ❌ |
+| TimmBackbone | ❌ | ❌ | ❌ | ❌ | ❌ |
 | Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
 | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
 | TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -483,6 +483,7 @@ _import_structure = {
         "TimeSeriesTransformerConfig",
     ],
     "models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
+    "models.timm_backbone": ["TimmBackboneConfig"],
     "models.trajectory_transformer": [
         "TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "TrajectoryTransformerConfig",
@@ -2578,6 +2579,7 @@ else:
             "TimesformerPreTrainedModel",
         ]
     )
+    _import_structure["models.timm_backbone"].extend(["TimmBackbone"])
     _import_structure["models.trajectory_transformer"].extend(
         [
             "TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4288,6 +4290,7 @@ if TYPE_CHECKING:
         TimeSeriesTransformerConfig,
     )
     from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
+    from .models.timm_backbone import TimmBackboneConfig
     from .models.trajectory_transformer import (
         TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
         TrajectoryTransformerConfig,
@@ -6024,6 +6027,7 @@ if TYPE_CHECKING:
         TimesformerModel,
         TimesformerPreTrainedModel,
     )
+    from .models.timm_backbone import TimmBackbone
     from .models.trajectory_transformer import (
         TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
         TrajectoryTransformerModel,
@@ -186,6 +186,7 @@ from . import (
     tapex,
     time_series_transformer,
     timesformer,
+    timm_backbone,
     trajectory_transformer,
     transfo_xl,
     trocr,
@@ -19,7 +19,7 @@ from collections import OrderedDict

 from ...configuration_utils import PretrainedConfig
 from ...dynamic_module_utils import get_class_from_dynamic_module
-from ...utils import copy_func, logging
+from ...utils import copy_func, logging, requires_backends
 from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


@@ -515,6 +515,48 @@ class _BaseAutoModelClass:
         cls._model_mapping.register(config_class, model_class)


+class _BaseAutoBackboneClass(_BaseAutoModelClass):
+    # Base class for auto backbone models.
+    _model_mapping = None
+
+    @classmethod
+    def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        requires_backends(cls, ["vision", "timm"])
+        from ...models.timm_backbone import TimmBackboneConfig
+
+        config = kwargs.pop("config", TimmBackboneConfig())
+
+        use_timm = kwargs.pop("use_timm_backbone", True)
+        if not use_timm:
+            raise ValueError("`use_timm_backbone` must be `True` for timm backbones")
+
+        if kwargs.get("out_features", None) is not None:
+            raise ValueError("Cannot specify `out_features` for timm backbones")
+
+        if kwargs.get("output_loading_info", False):
+            raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
+
+        num_channels = kwargs.pop("num_channels", config.num_channels)
+        features_only = kwargs.pop("features_only", config.features_only)
+        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
+        out_indices = kwargs.pop("out_indices", config.out_indices)
+        config = TimmBackboneConfig(
+            backbone=pretrained_model_name_or_path,
+            num_channels=num_channels,
+            features_only=features_only,
+            use_pretrained_backbone=use_pretrained_backbone,
+            out_indices=out_indices,
+        )
+        return super().from_config(config, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        if kwargs.get("use_timm_backbone", False):
+            return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
 def insert_head_doc(docstring, head_doc=""):
     if len(head_doc) > 0:
         return docstring.replace(
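The class added above routes `from_pretrained` to timm only when `use_timm_backbone=True`, and folds kwargs such as `out_indices` into a fresh `TimmBackboneConfig`. A short sketch of the resulting behaviour, mirroring the tests added later in this commit:

```python
from transformers import AutoBackbone

# out_indices is popped from the kwargs and written into the TimmBackboneConfig.
model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
assert model.out_indices == (-1, -2)

# out_features and output_loading_info are rejected for timm backbones.
try:
    AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])
except ValueError as err:
    print(err)  # Cannot specify `out_features` for timm backbones
```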
@@ -186,6 +186,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("tapas", "TapasConfig"),
         ("time_series_transformer", "TimeSeriesTransformerConfig"),
         ("timesformer", "TimesformerConfig"),
+        ("timm_backbone", "TimmBackboneConfig"),
         ("trajectory_transformer", "TrajectoryTransformerConfig"),
         ("transfo-xl", "TransfoXLConfig"),
         ("trocr", "TrOCRConfig"),
@@ -579,6 +580,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("tapex", "TAPEX"),
         ("time_series_transformer", "Time Series Transformer"),
         ("timesformer", "TimeSformer"),
+        ("timm_backbone", "TimmBackbone"),
         ("trajectory_transformer", "Trajectory Transformer"),
         ("transfo-xl", "Transformer-XL"),
         ("trocr", "TrOCR"),
@@ -18,7 +18,7 @@ import warnings
 from collections import OrderedDict

 from ...utils import logging
-from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
+from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
 from .configuration_auto import CONFIG_MAPPING_NAMES


@@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("tapas", "TapasModel"),
         ("time_series_transformer", "TimeSeriesTransformerModel"),
         ("timesformer", "TimesformerModel"),
+        ("timm_backbone", "TimmBackbone"),
         ("trajectory_transformer", "TrajectoryTransformerModel"),
         ("transfo-xl", "TransfoXLModel"),
         ("tvlt", "TvltModel"),
@@ -999,6 +1000,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
         ("nat", "NatBackbone"),
         ("resnet", "ResNetBackbone"),
         ("swin", "SwinBackbone"),
+        ("timm_backbone", "TimmBackbone"),
     ]
 )

@@ -1330,7 +1332,7 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING


-class AutoBackbone(_BaseAutoModelClass):
+class AutoBackbone(_BaseAutoBackboneClass):
     _model_mapping = MODEL_FOR_BACKBONE_MAPPING


@@ -39,7 +39,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_bit import BitConfig


@@ -845,14 +845,10 @@ class BitForImageClassification(BitPreTrainedModel):
 class BitBackbone(BitPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
         self.bit = BitModel(config)
-
         self.num_features = [config.embedding_size] + config.hidden_sizes
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )

         # initialize weights and apply final processing
         self.post_init()
@@ -37,7 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_convnext import ConvNextConfig


@@ -481,15 +481,11 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
 class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
         self.embeddings = ConvNextEmbeddings(config)
         self.encoder = ConvNextEncoder(config)
-
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -37,7 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_convnextv2 import ConvNextV2Config


@@ -504,15 +504,11 @@ class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
 class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
         self.embeddings = ConvNextV2Embeddings(config)
         self.encoder = ConvNextV2Encoder(config)
-
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )

         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -39,7 +39,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_dinat import DinatConfig


@@ -883,17 +883,12 @@ class DinatForImageClassification(DinatPreTrainedModel):
 class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
+        super()._init_backbone(config)

         requires_backends(self, ["natten"])

-        self.stage_names = config.stage_names
-
         self.embeddings = DinatEmbeddings(config)
         self.encoder = DinatEncoder(config)
-
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

         # Add layer norms to hidden states of out_features
@@ -36,7 +36,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_focalnet import FocalNetConfig


@@ -981,16 +981,12 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel):
     FOCALNET_START_DOCSTRING,
 )
 class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
-    def __init__(self, config):
+    def __init__(self, config: FocalNetConfig):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
-        self.focalnet = FocalNetModel(config)
-
         self.num_features = [config.embed_dim] + config.hidden_sizes
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )
+        self.focalnet = FocalNetModel(config)

         # initialize weights and apply final processing
         self.post_init()
@@ -29,7 +29,7 @@ from ...file_utils import ModelOutput
 from ...modeling_outputs import BackboneOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_maskformer_swin import MaskFormerSwinConfig


@@ -852,17 +852,11 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):

     def __init__(self, config: MaskFormerSwinConfig):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
         self.model = MaskFormerSwinModel(config)
-
-        self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         if "stem" in self.out_features:
             raise ValueError("This backbone does not support 'stem' in the `out_features`.")

-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.hidden_states_norms = nn.ModuleList(
             [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
@@ -39,7 +39,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_nat import NatConfig


@@ -861,17 +861,12 @@ class NatForImageClassification(NatPreTrainedModel):
 class NatBackbone(NatPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
+        super()._init_backbone(config)

         requires_backends(self, ["natten"])

-        self.stage_names = config.stage_names
-
         self.embeddings = NatEmbeddings(config)
         self.encoder = NatEncoder(config)
-
-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]

         # Add layer norms to hidden states of out_features
@@ -36,7 +36,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_resnet import ResNetConfig


@@ -432,16 +432,12 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
 class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
     def __init__(self, config):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
+        self.num_features = [config.embedding_size] + config.hidden_sizes
         self.embedder = ResNetEmbeddings(config)
         self.encoder = ResNetEncoder(config)

-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )
-        self.num_features = [config.embedding_size] + config.hidden_sizes
-
         # initialize weights and apply final processing
         self.post_init()

@@ -38,7 +38,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
-from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_swin import SwinConfig


@@ -1259,17 +1259,12 @@ class SwinForImageClassification(SwinPreTrainedModel):
 class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
     def __init__(self, config: SwinConfig):
         super().__init__(config)
+        super()._init_backbone(config)

-        self.stage_names = config.stage_names
-
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.embeddings = SwinEmbeddings(config)
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

-        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
-            config.out_features, config.out_indices, self.stage_names
-        )
-        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
         for stage, num_channels in zip(self._out_features, self.channels):
src/transformers/models/timm_backbone/__init__.py (new file, 49 lines)
@@ -0,0 +1,49 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available


_import_structure = {"configuration_timm_backbone": ["TimmBackboneConfig"]}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_timm_backbone"] = ["TimmBackbone"]


if TYPE_CHECKING:
    from .configuration_timm_backbone import TimmBackboneConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_timm_backbone import TimmBackbone

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
@@ -0,0 +1,78 @@ (new file)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Configuration for Backbone models"""


from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class TimmBackboneConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`].

    It is used to instantiate a timm backbone model according to the specified arguments, defining the model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        backbone (`str`, *optional*):
            The timm checkpoint to load.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        features_only (`bool`, *optional*, defaults to `True`):
            Whether to output only the features or also the logits.
        use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
            Whether to use a pretrained backbone.
        out_indices (`List[int]`, *optional*):
            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
            many stages the model has). Will default to the last stage if unset.

    Example:
    ```python
    >>> from transformers import TimmBackboneConfig, TimmBackbone

    >>> # Initializing a timm backbone
    >>> configuration = TimmBackboneConfig("resnet50")

    >>> # Initializing a model from the configuration
    >>> model = TimmBackbone(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "timm_backbone"

    def __init__(
        self,
        backbone=None,
        num_channels=3,
        features_only=True,
        use_pretrained_backbone=True,
        out_indices=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.backbone = backbone
        self.num_channels = num_channels
        self.features_only = features_only
        self.use_pretrained_backbone = use_pretrained_backbone
        self.use_timm_backbone = True
        self.out_indices = out_indices if out_indices is not None else (-1,)
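A note on the defaults set above (a small sketch; "resnet50" as in the docstring example): `out_indices` falls back to `(-1,)` and `use_timm_backbone` is always forced to `True`, which is what makes `BackboneMixin` pick the timm initialization path.

```python
from transformers import TimmBackboneConfig

config = TimmBackboneConfig(backbone="resnet50")
print(config.out_indices)        # (-1,) - only the last stage by default
print(config.use_timm_backbone)  # True - set unconditionally in __init__
print(config.features_only)      # True
```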
src/transformers/models/timm_backbone/modeling_timm_backbone.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Union

from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import is_timm_available, is_torch_available, requires_backends
from ...utils.backbone_utils import BackboneMixin
from .configuration_timm_backbone import TimmBackboneConfig


if is_timm_available():
    import timm


if is_torch_available():
    from torch import Tensor


class TimmBackbone(PreTrainedModel, BackboneMixin):
    """
    Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
    other models in the library keeping the same API.
    """

    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    config_class = TimmBackboneConfig

    def __init__(self, config, **kwargs):
        requires_backends(self, "timm")
        super().__init__(config)
        self.config = config

        if config.backbone is None:
            raise ValueError("backbone is not set in the config. Please set it to a timm model name.")

        if config.backbone not in timm.list_models():
            raise ValueError(f"backbone {config.backbone} is not supported by timm.")

        if hasattr(config, "out_features") and config.out_features is not None:
            raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.")

        pretrained = getattr(config, "use_pretrained_backbone", None)
        if pretrained is None:
            raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.")

        # We just take the final layer by default. This matches the default for the transformers models.
        out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)

        self._backbone = timm.create_model(
            config.backbone,
            pretrained=pretrained,
            # This is currently not possible for transformer architectures.
            features_only=config.features_only,
            in_chans=config.num_channels,
            out_indices=out_indices,
            **kwargs,
        )
        # These are used to control the output of the model when called. If output_hidden_states is True, then
        # return_layers is modified to include all layers.
        self._return_layers = self._backbone.return_layers
        self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
        super()._init_backbone(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        requires_backends(cls, ["vision", "timm"])
        from ...models.timm_backbone import TimmBackboneConfig

        config = kwargs.pop("config", TimmBackboneConfig())

        use_timm = kwargs.pop("use_timm_backbone", True)
        if not use_timm:
            raise ValueError("use_timm_backbone must be True for timm backbones")

        num_channels = kwargs.pop("num_channels", config.num_channels)
        features_only = kwargs.pop("features_only", config.features_only)
        use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
        out_indices = kwargs.pop("out_indices", config.out_indices)
        config = TimmBackboneConfig(
            backbone=pretrained_model_name_or_path,
            num_channels=num_channels,
            features_only=features_only,
            use_pretrained_backbone=use_pretrained_backbone,
            out_indices=out_indices,
        )
        return super()._from_config(config, **kwargs)

    def _init_weights(self, module):
        """
        Empty init weights function to ensure compatibility of the class in the library.
        """
        pass

    def forward(
        self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
    ) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        if output_attentions:
            raise ValueError("Cannot output attentions for timm backbones at the moment")

        if output_hidden_states:
            # We modify the return layers to include all the stages of the backbone
            self._backbone.return_layers = self._all_layers
            hidden_states = self._backbone(pixel_values, **kwargs)
            self._backbone.return_layers = self._return_layers
            feature_maps = tuple(hidden_states[i] for i in self.out_indices)
        else:
            feature_maps = self._backbone(pixel_values, **kwargs)
            hidden_states = None

        feature_maps = tuple(feature_maps)
        hidden_states = tuple(hidden_states) if hidden_states is not None else None

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output = output + (hidden_states,)
            return output

        return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)
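To illustrate the forward path above, a sketch assuming torch and timm are installed (`use_pretrained_backbone=False` just avoids downloading weights for the illustration):

```python
import torch
from transformers import TimmBackbone, TimmBackboneConfig

config = TimmBackboneConfig(backbone="resnet50", use_pretrained_backbone=False)
model = TimmBackbone(config)

pixel_values = torch.rand(1, 3, 224, 224)
outputs = model(pixel_values, output_hidden_states=True)

# feature_maps holds only the stages selected by out_indices ((-1,) by default);
# hidden_states holds every stage because output_hidden_states=True.
print(len(outputs.feature_maps), len(outputs.hidden_states))
```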
@@ -15,10 +15,16 @@

 """ Collection of utils to be used by backbones and their components."""

+import enum
 import inspect
 from typing import Iterable, List, Optional, Tuple, Union


+class BackboneType(enum.Enum):
+    TIMM = "timm"
+    TRANSFORMERS = "transformers"
+
+
 def verify_out_features_out_indices(
     out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
 ):
@@ -72,7 +78,7 @@ def _align_output_features_output_indices(
         out_indices = [len(stage_names) - 1]
         out_features = [stage_names[-1]]
     elif out_indices is None and out_features is not None:
-        out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
+        out_indices = [stage_names.index(layer) for layer in out_features]
     elif out_features is None and out_indices is not None:
         out_features = [stage_names[idx] for idx in out_indices]
     return out_features, out_indices
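The one-line change above alters ordering: `out_indices` is now derived by iterating `out_features` (preserving the order the caller passed) instead of iterating `stage_names`. A small illustration with made-up ResNet-style stage names:

```python
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
out_features = ["stage4", "stage3"]

# Old behaviour: iterate stage_names, so indices always follow stage order.
old = [stage_names.index(layer) for layer in stage_names if layer in out_features]
print(old)  # [3, 4]

# New behaviour: iterate out_features, so indices follow the requested order.
new = [stage_names.index(layer) for layer in out_features]
print(new)  # [4, 3]
```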
@@ -110,29 +116,57 @@ get_aligned_output_features_output_indices(


 class BackboneMixin:
-    @property
-    def out_feature_channels(self):
-        # the current backbones will output the number of channels for each stage
-        # even if that stage is not in the out_features list.
-        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
-
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
-    def forward_with_filtered_kwargs(self, *args, **kwargs):
-        signature = dict(inspect.signature(self.forward).parameters)
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
-        return self(*args, **filtered_kwargs)
-
-    def forward(
-        self,
-        pixel_values,
-        output_hidden_states: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        raise NotImplementedError("This method should be implemented by the derived class.")
+    backbone_type: Optional[BackboneType] = None
+
+    def _init_timm_backbone(self, config) -> None:
+        """
+        Initialize the backbone model from timm The backbone must already be loaded to self._backbone
+        """
+        if getattr(self, "_backbone", None) is None:
+            raise ValueError("self._backbone must be set before calling _init_timm_backbone")
+
+        # These will diagree with the defaults for the transformers models e.g. for resnet50
+        # the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
+        # the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
+        self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info]
+        self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info]
+        out_indices = self._backbone.feature_info.out_indices
+        out_features = self._backbone.feature_info.module_name()
+
+        # We verify the out indices and out features are valid
+        verify_out_features_out_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+        self._out_features, self._out_indices = out_features, out_indices
+
+    def _init_transformers_backbone(self, config) -> None:
+        stage_names = getattr(config, "stage_names")
+        out_features = getattr(config, "out_features", None)
+        out_indices = getattr(config, "out_indices", None)
+
+        self.stage_names = stage_names
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=stage_names
+        )
+        # Number of channels for each stage. This is set in the transformer backbone model init
+        self.num_features = None
+
+    def _init_backbone(self, config) -> None:
+        """
+        Method to initialize the backbone. This method is called by the constructor of the base class after the
+        pretrained model weights have been loaded.
+        """
+        self.config = config
+
+        self.use_timm_backbone = getattr(config, "use_timm_backbone", False)
+        self.backbone_type = BackboneType.TIMM if self.use_timm_backbone else BackboneType.TRANSFORMERS
+
+        if self.backbone_type == BackboneType.TIMM:
+            self._init_timm_backbone(config)
+        elif self.backbone_type == BackboneType.TRANSFORMERS:
+            self._init_transformers_backbone(config)
+        else:
+            raise ValueError(f"backbone_type {self.backbone_type} not supported.")

     @property
     def out_features(self):
|
|||||||
out_features=None, out_indices=out_indices, stage_names=self.stage_names
|
out_features=None, out_indices=out_indices, stage_names=self.stage_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def out_feature_channels(self):
|
||||||
|
# the current backbones will output the number of channels for each stage
|
||||||
|
# even if that stage is not in the out_features list.
|
||||||
|
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self):
|
||||||
|
return [self.out_feature_channels[name] for name in self.out_features]
|
||||||
|
|
||||||
|
def forward_with_filtered_kwargs(self, *args, **kwargs):
|
||||||
|
signature = dict(inspect.signature(self.forward).parameters)
|
||||||
|
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
|
||||||
|
return self(*args, **filtered_kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
raise NotImplementedError("This method should be implemented by the derived class.")
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
|
||||||
|
include the `out_features` and `out_indices` attributes.
|
||||||
|
"""
|
||||||
|
output = super().to_dict()
|
||||||
|
output["out_features"] = output.pop("_out_features")
|
||||||
|
output["out_indices"] = output.pop("_out_indices")
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class BackboneConfigMixin:
|
class BackboneConfigMixin:
|
||||||
"""
|
"""
|
||||||
|
@@ -6806,6 +6806,13 @@ class TimesformerPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class TimmBackbone(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None

@@ -45,6 +45,7 @@ if is_torch_available():
     from test_module.custom_modeling import CustomModel

     from transformers import (
+        AutoBackbone,
         AutoConfig,
         AutoModel,
         AutoModelForCausalLM,
@@ -66,11 +67,13 @@ if is_torch_available():
         FunnelModel,
         GPT2Config,
         GPT2LMHeadModel,
+        ResNetBackbone,
         RobertaForMaskedLM,
         T5Config,
         T5ForConditionalGeneration,
         TapasConfig,
         TapasForQuestionAnswering,
+        TimmBackbone,
     )
     from transformers.models.auto.modeling_auto import (
         MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -224,6 +227,42 @@ class AutoModelTest(unittest.TestCase):
         self.assertIsNotNone(model)
         self.assertIsInstance(model, BertForTokenClassification)

+    @slow
+    def test_auto_backbone_timm_model_from_pretrained(self):
+        # Configs can't be loaded for timm models
+        model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)
+
+        with pytest.raises(ValueError):
+            # We can't pass output_loading_info=True as we're loading from timm
+            AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)
+
+        self.assertIsNotNone(model)
+        self.assertIsInstance(model, TimmBackbone)
+
+        # Check kwargs are correctly passed to the backbone
+        model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
+        self.assertEqual(model.out_indices, (-1, -2))
+
+        # Check out_features cannot be passed to Timm backbones
+        with self.assertRaises(ValueError):
+            _ = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])
+
+    @slow
+    def test_auto_backbone_from_pretrained(self):
+        model = AutoBackbone.from_pretrained("microsoft/resnet-18")
+        model, loading_info = AutoBackbone.from_pretrained("microsoft/resnet-18", output_loading_info=True)
+        self.assertIsNotNone(model)
+        self.assertIsInstance(model, ResNetBackbone)
+
+        # Check kwargs are correctly passed to the backbone
+        model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2])
+        self.assertEqual(model.out_indices, [-1, -2])
+        self.assertEqual(model.out_features, ["stage4", "stage3"])
+
+        model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"])
+        self.assertEqual(model.out_indices, [2, 4])
+        self.assertEqual(model.out_features, ["stage2", "stage4"])
+
     def test_from_pretrained_identifier(self):
         model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(model, BertForMaskedLM)
tests/models/timm_backbone/__init__.py (new file, empty)

tests/models/timm_backbone/test_modeling_timm_backbone.py (new file, 259 lines; truncated here)
@@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import inspect
import unittest

from transformers import AutoBackbone
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import require_timm, require_torch, torch_device
from transformers.utils.import_utils import is_torch_available

from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor


if is_torch_available():
    import torch

    from transformers import TimmBackbone, TimmBackboneConfig


class TimmBackboneModelTester:
    def __init__(
        self,
        parent,
        out_indices=None,
        out_features=None,
        stage_names=None,
        backbone="resnet50",
        batch_size=3,
        image_size=32,
        num_channels=3,
        is_training=True,
        use_pretrained_backbone=True,
    ):
        self.parent = parent
        self.out_indices = out_indices if out_indices is not None else [4]
        self.stage_names = stage_names
        self.out_features = out_features
        self.backbone = backbone
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.use_pretrained_backbone = use_pretrained_backbone
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        return TimmBackboneConfig(
            image_size=self.image_size,
            num_channels=self.num_channels,
            out_features=self.out_features,
            out_indices=self.out_indices,
            stage_names=self.stage_names,
            use_pretrained_backbone=self.use_pretrained_backbone,
            backbone=self.backbone,
        )

    def create_and_check_model(self, config, pixel_values):
        model = TimmBackbone(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        self.parent.assertEqual(
            result.feature_map[-1].shape,
            (self.batch_size, model.channels[-1], 14, 14),
        )

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
@require_timm
class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase):
    all_model_classes = (TimmBackbone,) if is_torch_available() else ()
    test_resize_embeddings = False
    test_head_masking = False
    test_pruning = False
    has_attentions = False

    def setUp(self):
        self.model_tester = TimmBackboneModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PretrainedConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def test_timm_transformer_backbone_equivalence(self):
        timm_checkpoint = "resnet18"
        transformers_checkpoint = "microsoft/resnet-18"

        timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True)
        transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint)

        self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
        self.assertEqual(len(timm_model.stage_names), len(transformers_model.stage_names))
        self.assertEqual(timm_model.channels, transformers_model.channels)
        # Out indices are set to the last layer by default. For timm models, we don't know
        # the number of layers in advance, so we set it to (-1,), whereas for transformers
        # models, we set it to [len(stage_names) - 1] (kept for backward compatibility).
        self.assertEqual(timm_model.out_indices, (-1,))
        self.assertEqual(transformers_model.out_indices, [len(timm_model.stage_names) - 1])

        timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True, out_indices=[1, 2, 3])
        transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint, out_indices=[1, 2, 3])

        self.assertEqual(timm_model.out_indices, transformers_model.out_indices)
        self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
        self.assertEqual(timm_model.channels, transformers_model.channels)

    @unittest.skip("TimmBackbone doesn't support feed forward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip("TimmBackbone doesn't have num_hidden_layers attribute")
    def test_hidden_states_output(self):
        pass

    @unittest.skip("TimmBackbone initialization is managed on the timm side")
    def test_initialization(self):
        pass

    @unittest.skip("TimmBackbone models doesn't have inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip("TimmBackbone models doesn't have inputs_embeds")
    def test_model_common_attributes(self):
        pass

    @unittest.skip("TimmBackbone model cannot be created without specifying a backbone checkpoint")
    def test_from_pretrained_no_checkpoint(self):
        pass

    @unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
    def test_save_load(self):
        pass

    @unittest.skip("model weights aren't tied in TimmBackbone.")
    def test_tie_model_weights(self):
        pass

    @unittest.skip("model weights aren't tied in TimmBackbone.")
    def test_tied_model_weights_key_ignore(self):
        pass

    @unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.")
    def test_channels(self):
        pass

    @unittest.skip("TimmBackbone doesn't support output_attentions.")
    def test_torchscript_output_attentions(self):
        pass

    @unittest.skip("Safetensors is not supported by timm.")
    def test_can_use_safetensors(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.output_hidden_states = True
|
||||||
|
config.output_attentions = self.has_attentions
|
||||||
|
|
||||||
|
# no need to test all models as different heads yield the same functionality
|
||||||
|
model_class = self.all_model_classes[0]
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
outputs = model(**inputs)
|
||||||
|
output = outputs[0][-1]
|
||||||
|
|
||||||
|
# Encoder-/Decoder-only models
|
||||||
|
hidden_states = outputs.hidden_states[0]
|
||||||
|
hidden_states.retain_grad()
|
||||||
|
|
||||||
|
if self.has_attentions:
|
||||||
|
attentions = outputs.attentions[0]
|
||||||
|
attentions.retain_grad()
|
||||||
|
|
||||||
|
output.flatten()[0].backward(retain_graph=True)
|
||||||
|
|
||||||
|
self.assertIsNotNone(hidden_states.grad)
|
||||||
|
|
||||||
|
if self.has_attentions:
|
||||||
|
self.assertIsNotNone(attentions.grad)
|
||||||
|
|
||||||
|
# TimmBackbone config doesn't have out_features attribute
|
||||||
|
def test_create_from_modified_config(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(**inputs_dict)
|
||||||
|
|
||||||
|
self.assertEqual(len(result.feature_maps), len(config.out_indices))
|
||||||
|
self.assertEqual(len(model.channels), len(config.out_indices))
|
||||||
|
|
||||||
|
# Check output of last stage is taken if out_features=None, out_indices=None
|
||||||
|
modified_config = copy.deepcopy(config)
|
||||||
|
modified_config.out_indices = None
|
||||||
|
model = model_class(modified_config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(**inputs_dict)
|
||||||
|
|
||||||
|
self.assertEqual(len(result.feature_maps), 1)
|
||||||
|
self.assertEqual(len(model.channels), 1)
|
||||||
|
|
||||||
|
# Check backbone can be initialized with fresh weights
|
||||||
|
modified_config = copy.deepcopy(config)
|
||||||
|
modified_config.use_pretrained_backbone = False
|
||||||
|
model = model_class(modified_config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(**inputs_dict)
|
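For orientation, a minimal standalone sketch of the loading path the equivalence test above exercises; the checkpoint name, use_timm_backbone flag, and out_indices values are taken from the test itself, and torch plus timm are assumed to be installed:

import torch

from transformers import AutoBackbone

# Load a timm checkpoint through the TimmBackbone path, requesting three intermediate stages
# (same checkpoint and out_indices as test_timm_transformer_backbone_equivalence above).
backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=[1, 2, 3])

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
with torch.no_grad():
    outputs = backbone(pixel_values)

# One feature map per requested stage; channel counts for those stages are exposed on the backbone.
print(len(outputs.feature_maps), backbone.channels)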
@@ -17,6 +17,7 @@ import copy
 import inspect
 
 from transformers.testing_utils import require_torch, torch_device
+from transformers.utils.backbone_utils import BackboneType
 
 
 @require_torch
@@ -104,6 +105,8 @@ class BackboneTesterMixin:
 
         self.assertEqual(len(result.feature_maps), len(config.out_features))
         self.assertEqual(len(model.channels), len(config.out_features))
+        self.assertEqual(len(result.feature_maps), len(config.out_indices))
+        self.assertEqual(len(model.channels), len(config.out_indices))
 
         # Check output of last stage is taken if out_features=None, out_indices=None
         modified_config = copy.deepcopy(config)
@@ -140,6 +143,7 @@ class BackboneTesterMixin:
         for backbone_class in self.all_model_classes:
             backbone = backbone_class(config)
 
+            self.assertTrue(hasattr(backbone, "backbone_type"))
             self.assertTrue(hasattr(backbone, "stage_names"))
             self.assertTrue(hasattr(backbone, "num_features"))
             self.assertTrue(hasattr(backbone, "out_indices"))
@@ -147,6 +151,7 @@ class BackboneTesterMixin:
             self.assertTrue(hasattr(backbone, "out_feature_channels"))
             self.assertTrue(hasattr(backbone, "channels"))
 
+            self.assertIsInstance(backbone.backbone_type, BackboneType)
             # Verify num_features has been initialized in the backbone init
             self.assertIsNotNone(backbone.num_features)
             self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
@@ -77,6 +77,7 @@ SPECIAL_CASES_TO_ALLOW = {
     "AutoformerConfig": ["num_static_real_features", "num_time_features"],
 }
 
 
 # TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
 SPECIAL_CASES_TO_ALLOW.update(
     {
@@ -172,6 +173,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
         "mask_index",
         "image_size",
         "use_cache",
+        "out_features",
+        "out_indices",
     ]
     attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
 
@@ -39,6 +39,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
     "EncoderDecoderConfig",
     "RagConfig",
     "SpeechEncoderDecoderConfig",
+    "TimmBackboneConfig",
     "VisionEncoderDecoderConfig",
     "VisionTextDualEncoderConfig",
     "LlamaConfig",
@@ -517,6 +517,7 @@ MODELS_NOT_IN_README = [
     "Speech Encoder decoder",
     "Speech2Text",
     "Speech2Text2",
+    "TimmBackbone",
     "Vision Encoder decoder",
     "VisionTextDualEncoder",
 ]
@@ -408,6 +408,7 @@ def get_model_modules():
         "modeling_speech_encoder_decoder",
         "modeling_flax_speech_encoder_decoder",
         "modeling_flax_vision_encoder_decoder",
+        "modeling_timm_backbone",
         "modeling_transfo_xl_utilities",
         "modeling_tf_auto",
         "modeling_tf_encoder_decoder",
@@ -846,6 +847,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
     "NatBackbone",
     "ResNetBackbone",
     "SwinBackbone",
+    "TimmBackbone",
+    "TimmBackboneConfig",
 ]
 
 