Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-30 17:52:35 +06:00)
Add methods to update and verify out_features out_indices (#23031)
* Add methods to update and verify out_features out_indices
* Safe update for config attributes
* Fix function names
* Save config correctly
* PR comments - use property setters
* PR comment - directly set attributes
* Update test
* Add updates to recently merged focalnet backbone
This commit is contained in:
parent 78b7debf56, commit 90e8263d91
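In short, the commit replaces the per-model `out_features`/`out_indices` validation boilerplate with one shared helper plus property setters that keep the two attributes aligned. A minimal sketch of the helper's behavior, following the logic documented and tested in the diff below (the stage names here are only illustrative):

    from transformers.utils.backbone_utils import get_aligned_output_features_output_indices

    stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]

    # Neither side set: both default to the last stage.
    features, indices = get_aligned_output_features_output_indices(None, None, stage_names)
    print(features, indices)  # ['stage4'] [4]

    # Only one side set: the other is derived from it.
    features, indices = get_aligned_output_features_output_indices(["stage1", "stage3"], None, stage_names)
    print(features, indices)  # ['stage1', 'stage3'] [1, 3]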
@@ -1006,32 +1006,6 @@ class ModuleUtilsMixin:
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
 
 
-class BackboneMixin:
-    @property
-    def out_feature_channels(self):
-        # the current backbones will output the number of channels for each stage
-        # even if that stage is not in the out_features list.
-        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
-
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
-    def forward_with_filtered_kwargs(self, *args, **kwargs):
-        signature = dict(inspect.signature(self.forward).parameters)
-        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
-        return self(*args, **filtered_kwargs)
-
-    def forward(
-        self,
-        pixel_values: Tensor,
-        output_hidden_states: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        raise NotImplementedError("This method should be implemented by the derived class.")
-
-
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""
     Base class for all models.
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -25,7 +26,7 @@ BIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class BitConfig(PretrainedConfig):
+class BitConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`BitModel`]. It is used to instantiate an BiT
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -128,35 +129,6 @@ class BitConfig(PretrainedConfig):
         self.width_factor = width_factor
 
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
@@ -31,7 +31,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -39,6 +39,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_bit import BitConfig
@@ -848,12 +849,10 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
         self.stage_names = config.stage_names
         self.bit = BitModel(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.embedding_size] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # initialize weights and apply final processing
         self.post_init()
@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -32,7 +33,7 @@ CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class ConvNextConfig(PretrainedConfig):
+class ConvNextConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ConvNextModel`]. It is used to instantiate an
     ConvNeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -119,38 +120,9 @@ class ConvNextConfig(PretrainedConfig):
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
 
 
 class ConvNextOnnxConfig(OnnxConfig):
@@ -29,7 +29,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,6 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_convnext import ConvNextConfig
@@ -485,16 +486,14 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         self.embeddings = ConvNextEmbeddings(config)
         self.encoder = ConvNextEncoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = ConvNextLayerNorm(num_channels, data_format="channels_first")
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
@@ -17,6 +17,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@ CONVNEXTV2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class ConvNextV2Config(PretrainedConfig):
+class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ConvNextV2Model`]. It is used to instantiate an
     ConvNeXTV2 model according to the specified arguments, defining the model architecture. Instantiating a
@@ -109,35 +110,6 @@ class ConvNextV2Config(PretrainedConfig):
         self.drop_path_rate = drop_path_rate
         self.image_size = image_size
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
@@ -29,7 +29,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -37,6 +37,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_convnextv2 import ConvNextV2Config
@@ -508,16 +509,14 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         self.embeddings = ConvNextV2Embeddings(config)
         self.encoder = ConvNextV2Encoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = ConvNextV2LayerNorm(num_channels, data_format="channels_first")
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@ DINAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class DinatConfig(PretrainedConfig):
+class DinatConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DinatModel`]. It is used to instantiate a Dinat
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -145,35 +146,6 @@ class DinatConfig(PretrainedConfig):
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
@@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -39,6 +39,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_dinat import DinatConfig
@@ -890,16 +891,14 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
         self.embeddings = DinatEmbeddings(config)
         self.encoder = DinatEncoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = nn.LayerNorm(num_channels)
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -25,7 +26,7 @@ FOCALNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class FocalNetConfig(PretrainedConfig):
+class FocalNetConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`FocalNetModel`]. It is used to instantiate a
     FocalNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -156,35 +157,6 @@ class FocalNetConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.encoder_stride = encoder_stride
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(self.depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
@@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -36,6 +36,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_focalnet import FocalNetConfig
@@ -987,11 +988,9 @@ class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
         self.focalnet = FocalNetModel(config)
 
         self.num_features = [config.embed_dim] + config.hidden_sizes
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
 
         # initialize weights and apply final processing
         self.post_init()
@@ -16,12 +16,13 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
 
 
-class MaskFormerSwinConfig(PretrainedConfig):
+class MaskFormerSwinConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`MaskFormerSwinModel`]. It is used to instantiate
     a Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -141,35 +142,6 @@ class MaskFormerSwinConfig(PretrainedConfig):
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
@@ -27,8 +27,9 @@ from torch import Tensor, nn
 from ...activations import ACT2FN
 from ...file_utils import ModelOutput
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -855,14 +856,13 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
         self.stage_names = config.stage_names
         self.model = MaskFormerSwinModel(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+        self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
         if "stem" in self.out_features:
             raise ValueError("This backbone does not support 'stem' in the `out_features`.")
 
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.hidden_states_norms = nn.ModuleList(
             [nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
@@ -16,6 +16,7 @@
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -26,7 +27,7 @@ NAT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class NatConfig(PretrainedConfig):
+class NatConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`NatModel`]. It is used to instantiate a Nat model
     according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -141,35 +142,6 @@ class NatConfig(PretrainedConfig):
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.layer_scale_init_value = layer_scale_init_value
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
@@ -26,7 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -39,6 +39,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_nat import NatConfig
@@ -868,11 +869,9 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
         self.embeddings = NatEmbeddings(config)
         self.encoder = NatEncoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -31,7 +32,7 @@ RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class ResNetConfig(PretrainedConfig):
+class ResNetConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`ResNetModel`]. It is used to instantiate an
     ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration
@@ -108,38 +109,9 @@ class ResNetConfig(PretrainedConfig):
         self.hidden_act = hidden_act
         self.downsample_in_first_stage = downsample_in_first_stage
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
 
 
 class ResNetOnnxConfig(OnnxConfig):
@@ -28,7 +28,7 @@ from ...modeling_outputs import (
     BaseModelOutputWithPoolingAndNoAttention,
     ImageClassifierOutputWithNoAttention,
 )
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -36,6 +36,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_resnet import ResNetConfig
@@ -436,11 +437,9 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
         self.embedder = ResNetEmbeddings(config)
         self.encoder = ResNetEncoder(config)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
         self.num_features = [config.embedding_size] + config.hidden_sizes
 
         # initialize weights and apply final processing
@@ -22,6 +22,7 @@ from packaging import version
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfig
 from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
 
 
 logger = logging.get_logger(__name__)
@@ -34,7 +35,7 @@ SWIN_PRETRAINED_CONFIG_ARCHIVE_MAP = {
 }
 
 
-class SwinConfig(PretrainedConfig):
+class SwinConfig(BackboneConfigMixin, PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`SwinModel`]. It is used to instantiate a Swin
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -158,38 +159,9 @@ class SwinConfig(PretrainedConfig):
         # this indicates the channel dimension after the last stage of the model
         self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
         self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
-
-        if out_features is not None and out_indices is not None:
-            if len(out_features) != len(out_indices):
-                raise ValueError("out_features and out_indices should have the same length if both are set")
-            elif out_features != [self.stage_names[idx] for idx in out_indices]:
-                raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
-
-        if out_features is None and out_indices is not None:
-            out_features = [self.stage_names[idx] for idx in out_indices]
-        elif out_features is not None and out_indices is None:
-            out_indices = [self.stage_names.index(feature) for feature in out_features]
-        elif out_features is None and out_indices is None:
-            out_features = [self.stage_names[-1]]
-            out_indices = [len(self.stage_names) - 1]
-
-        if out_features is not None:
-            if not isinstance(out_features, list):
-                raise ValueError("out_features should be a list")
-            for feature in out_features:
-                if feature not in self.stage_names:
-                    raise ValueError(
-                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
-                    )
-        if out_indices is not None:
-            if not isinstance(out_indices, (list, tuple)):
-                raise ValueError("out_indices should be a list or tuple")
-            for idx in out_indices:
-                if idx >= len(self.stage_names):
-                    raise ValueError(f"Index {idx} is not a valid index for a list of length {len(self.stage_names)}")
-
-        self.out_features = out_features
-        self.out_indices = out_indices
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
 
 
 class SwinOnnxConfig(OnnxConfig):
@@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...modeling_outputs import BackboneOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
 from ...utils import (
     ModelOutput,
@@ -38,6 +38,7 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
 from .configuration_swin import SwinConfig
@@ -1264,16 +1265,14 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
         self.embeddings = SwinEmbeddings(config)
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
 
-        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-        if config.out_indices is not None:
-            self.out_indices = config.out_indices
-        else:
-            self.out_indices = tuple(i for i, layer in enumerate(self.stage_names) if layer in self.out_features)
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            config.out_features, config.out_indices, self.stage_names
+        )
         self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
-        for stage, num_channels in zip(self.out_features, self.channels):
+        for stage, num_channels in zip(self._out_features, self.channels):
             hidden_states_norms[stage] = nn.LayerNorm(num_channels)
         self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
@@ -22,8 +22,9 @@ from torch.nn import CrossEntropyLoss
 
 from ... import AutoBackbone
 from ...modeling_outputs import SemanticSegmenterOutput
-from ...modeling_utils import BackboneMixin, PreTrainedModel
+from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...utils.backbone_utils import BackboneMixin
 from .configuration_upernet import UperNetConfig
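All of the backbone modules above now reach `BackboneMixin` through the new `utils.backbone_utils` module that follows. The property logic those models rely on is small enough to sketch standalone; the stage names and feature sizes below are made up purely for illustration:

    stage_names = ["stem", "stage1", "stage2"]
    num_features = [64, 128, 256]
    out_features = ["stage2"]

    # BackboneMixin.out_feature_channels reports a channel count for every stage...
    out_feature_channels = {stage: num_features[i] for i, stage in enumerate(stage_names)}
    # ...and BackboneMixin.channels keeps only the stages listed in out_features.
    channels = [out_feature_channels[name] for name in out_features]  # [256]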
src/transformers/utils/backbone_utils.py (new file, 203 lines)
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Collection of utils to be used by backbones and their components."""

import inspect
from typing import Iterable, List, Optional, Tuple, Union


def verify_out_features_out_indices(
    out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
    """
    Verify that out_indices and out_features are valid for the given stage_names.
    """
    if stage_names is None:
        raise ValueError("stage_names must be set for transformers backbones")

    if out_features is not None:
        if not isinstance(out_features, (list,)):
            raise ValueError(f"out_features must be a list, got {type(out_features)}")
        if any(feat not in stage_names for feat in out_features):
            raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")

    if out_indices is not None:
        if not isinstance(out_indices, (list, tuple)):
            raise ValueError(f"out_indices must be a list or tuple, got {type(out_indices)}")
        if any(idx >= len(stage_names) for idx in out_indices):
            raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")

    if out_features is not None and out_indices is not None:
        if len(out_features) != len(out_indices):
            raise ValueError("out_features and out_indices should have the same length if both are set")
        if out_features != [stage_names[idx] for idx in out_indices]:
            raise ValueError("out_features and out_indices should correspond to the same stages if both are set")


def _align_output_features_output_indices(
    out_features: Optional[List[str]],
    out_indices: Optional[Union[List[int], Tuple[int]]],
    stage_names: List[str],
):
    """
    Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.

    The logic is as follows:
    - `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
      `out_indices`.
    - `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
      `out_features`.
    - `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
    - `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned.

    Args:
        out_features (`List[str]`): The names of the features for the backbone to output.
        out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
        stage_names (`List[str]`): The names of the stages of the backbone.
    """
    if out_indices is None and out_features is None:
        out_indices = [len(stage_names) - 1]
        out_features = [stage_names[-1]]
    elif out_indices is None and out_features is not None:
        out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
    elif out_features is None and out_indices is not None:
        out_features = [stage_names[idx] for idx in out_indices]
    return out_features, out_indices


def get_aligned_output_features_output_indices(
    out_features: Optional[List[str]],
    out_indices: Optional[Union[List[int], Tuple[int]]],
    stage_names: List[str],
) -> Tuple[List[str], List[int]]:
    """
    Get the `out_features` and `out_indices` so that they are aligned.

    The logic is as follows:
    - `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
      `out_indices`.
    - `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
      `out_features`.
    - `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
    - `out_indices` and `out_features` set: they are verified to be aligned.

    Args:
        out_features (`List[str]`): The names of the features for the backbone to output.
        out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
        stage_names (`List[str]`): The names of the stages of the backbone.
    """
    # First verify that the out_features and out_indices are valid
    verify_out_features_out_indices(out_features=out_features, out_indices=out_indices, stage_names=stage_names)
    output_features, output_indices = _align_output_features_output_indices(
        out_features=out_features, out_indices=out_indices, stage_names=stage_names
    )
    # Verify that the aligned out_features and out_indices are valid
    verify_out_features_out_indices(out_features=output_features, out_indices=output_indices, stage_names=stage_names)
    return output_features, output_indices


class BackboneMixin:
    @property
    def out_feature_channels(self):
        # the current backbones will output the number of channels for each stage
        # even if that stage is not in the out_features list.
        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}

    @property
    def channels(self):
        return [self.out_feature_channels[name] for name in self.out_features]

    def forward_with_filtered_kwargs(self, *args, **kwargs):
        signature = dict(inspect.signature(self.forward).parameters)
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
        return self(*args, **filtered_kwargs)

    def forward(
        self,
        pixel_values,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        raise NotImplementedError("This method should be implemented by the derived class.")

    @property
    def out_features(self):
        return self._out_features

    @out_features.setter
    def out_features(self, out_features: List[str]):
        """
        Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
        """
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )

    @property
    def out_indices(self):
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
        """
        Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
        """
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )


class BackboneConfigMixin:
    """
    A Mixin to support handling the `out_features` and `out_indices` attributes for the backbone configurations.
    """

    @property
    def out_features(self):
        return self._out_features

    @out_features.setter
    def out_features(self, out_features: List[str]):
        """
        Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
        """
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=out_features, out_indices=None, stage_names=self.stage_names
        )

    @property
    def out_indices(self):
        return self._out_indices

    @out_indices.setter
    def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
        """
        Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
        """
        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
            out_features=None, out_indices=out_indices, stage_names=self.stage_names
        )

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
        include the `out_features` and `out_indices` attributes.
        """
        output = super().to_dict()
        output["out_features"] = output.pop("_out_features")
        output["out_indices"] = output.pop("_out_indices")
        return output
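A short usage sketch of the new config-side mixin, assuming one of the configs converted in this commit, e.g. `ResNetConfig` with its default stage names ("stem", "stage1" through "stage4"):

    from transformers import ResNetConfig

    config = ResNetConfig(out_features=["stage2", "stage4"])

    # Assigning one attribute re-derives the other through
    # get_aligned_output_features_output_indices.
    config.out_indices = [4]
    print(config.out_features)  # ['stage4']

    # to_dict() serializes the public names, not _out_features/_out_indices.
    d = config.to_dict()
    print(d["out_features"], d["out_indices"])  # ['stage4'] [4]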
@@ -81,9 +81,15 @@ class BackboneTesterMixin:
         out_channels = [num_features[idx] for idx in out_indices]
         self.assertListEqual(model.channels, out_channels)
 
-        config.out_features = None
-        config.out_indices = None
-        model = model_class(config)
+        new_config = copy.deepcopy(config)
+        new_config.out_features = None
+        model = model_class(new_config)
+        self.assertEqual(len(model.channels), 1)
+        self.assertListEqual(model.channels, [num_features[-1]])
+
+        new_config = copy.deepcopy(config)
+        new_config.out_indices = None
+        model = model_class(new_config)
         self.assertEqual(len(model.channels), 1)
         self.assertListEqual(model.channels, [num_features[-1]])
@@ -102,6 +108,15 @@ class BackboneTesterMixin:
         # Check output of last stage is taken if out_features=None, out_indices=None
         modified_config = copy.deepcopy(config)
         modified_config.out_features = None
+        model = model_class(modified_config)
+        model.to(torch_device)
+        model.eval()
+        result = model(**inputs_dict)
+
+        self.assertEqual(len(result.feature_maps), 1)
+        self.assertEqual(len(model.channels), 1)
+
+        modified_config = copy.deepcopy(config)
         modified_config.out_indices = None
         model = model_class(modified_config)
         model.to(torch_device)
tests/utils/test_backbone_utils.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.utils.backbone_utils import (
    BackboneMixin,
    get_aligned_output_features_output_indices,
    verify_out_features_out_indices,
)


class BackboneUtilsTester(unittest.TestCase):
    def test_get_aligned_output_features_output_indices(self):
        stage_names = ["a", "b", "c"]

        # Defaults to last layer if both are None
        out_features, out_indices = get_aligned_output_features_output_indices(None, None, stage_names)
        self.assertEqual(out_features, ["c"])
        self.assertEqual(out_indices, [2])

        # Out indices set to match out features
        out_features, out_indices = get_aligned_output_features_output_indices(["a", "c"], None, stage_names)
        self.assertEqual(out_features, ["a", "c"])
        self.assertEqual(out_indices, [0, 2])

        # Out features set to match out indices
        out_features, out_indices = get_aligned_output_features_output_indices(None, [0, 2], stage_names)
        self.assertEqual(out_features, ["a", "c"])
        self.assertEqual(out_indices, [0, 2])

        # Out features selected from negative indices
        out_features, out_indices = get_aligned_output_features_output_indices(None, [-3, -1], stage_names)
        self.assertEqual(out_features, ["a", "c"])
        self.assertEqual(out_indices, [-3, -1])

    def test_verify_out_features_out_indices(self):
        # Stage names must be set
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(["a", "b"], (0, 1), None)

        # Out features must be a list
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(("a", "b"), (0, 1), ["a", "b"])

        # Out features must be a subset of stage names
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(["a", "b"], (0, 1), ["a"])

        # Out indices must be a list or tuple
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(None, 0, ["a", "b"])

        # Out indices must be a subset of stage names
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(None, (0, 1), ["a"])

        # Out features and out indices must be the same length
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(["a", "b"], (0,), ["a", "b", "c"])

        # Out features should match out indices
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(["a", "b"], (0, 2), ["a", "b", "c"])

        # Out features and out indices should be in order
        with self.assertRaises(ValueError):
            verify_out_features_out_indices(["b", "a"], (0, 1), ["a", "b"])

        # Check passes with valid inputs
        verify_out_features_out_indices(["a", "b", "d"], (0, 1, -1), ["a", "b", "c", "d"])

    def test_backbone_mixin(self):
        backbone = BackboneMixin()

        backbone.stage_names = ["a", "b", "c"]
        backbone._out_features = ["a", "c"]
        backbone._out_indices = [0, 2]

        # Check that the output features and indices are set correctly
        self.assertEqual(backbone.out_features, ["a", "c"])
        self.assertEqual(backbone.out_indices, [0, 2])

        # Check out features and indices are updated correctly
        backbone.out_features = ["a", "b"]
        self.assertEqual(backbone.out_features, ["a", "b"])
        self.assertEqual(backbone.out_indices, [0, 1])

        backbone.out_indices = [-3, -1]
        self.assertEqual(backbone.out_features, ["a", "c"])
        self.assertEqual(backbone.out_indices, [-3, -1])