mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[Backbones] Improve out features (#20675)
* Improve ResNet backbone * Improve Bit backbone * Improve docstrings * Fix default stage * Apply suggestions from code review Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
parent
9e56aff58a
commit
9a6c6ef97f
@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig):
|
||||
The width factor for the model.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has).
|
||||
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
@ -851,7 +851,7 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
|
||||
self.stage_names = config.stage_names
|
||||
self.bit = BitModel(config)
|
||||
|
||||
self.out_features = config.out_features
|
||||
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||
|
||||
out_feature_channels = {}
|
||||
out_feature_channels["stem"] = config.embedding_size
|
||||
|
@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig):
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the layer normalization layers.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -855,7 +855,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
|
||||
self.stage_names = config.stage_names
|
||||
self.model = MaskFormerSwinModel(config)
|
||||
|
||||
self.out_features = config.out_features
|
||||
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||
if "stem" in self.out_features:
|
||||
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
|
||||
|
||||
|
@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig):
|
||||
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, the first stage will downsample the inputs using a `stride` of 2.
|
||||
out_features (`List[str]`, *optional*):
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`,
|
||||
`"stage3"`, `"stage4"`.
|
||||
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
|
||||
(depending on how many stages the model has). Will default to the last stage if unset.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (ResNetModel, ResNetBackbone)):
|
||||
if isinstance(module, ResNetEncoder):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
@ -439,7 +439,7 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
|
||||
self.embedder = ResNetEmbeddings(config)
|
||||
self.encoder = ResNetEncoder(config)
|
||||
|
||||
self.out_features = config.out_features
|
||||
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
|
||||
|
||||
out_feature_channels = {}
|
||||
out_feature_channels["stem"] = config.embedding_size
|
||||
|
@ -119,7 +119,7 @@ class BitModelTester:
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify hidden states
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
||||
|
||||
@ -127,6 +127,21 @@ class BitModelTester:
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = BitBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
|
@ -119,7 +119,7 @@ class ResNetModelTester:
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify hidden states
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
|
||||
|
||||
@ -127,6 +127,21 @@ class ResNetModelTester:
|
||||
self.parent.assertEqual(len(model.channels), len(config.out_features))
|
||||
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
|
||||
|
||||
# verify backbone works with out_features=None
|
||||
config.out_features = None
|
||||
model = ResNetBackbone(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(pixel_values)
|
||||
|
||||
# verify feature maps
|
||||
self.parent.assertEqual(len(result.feature_maps), 1)
|
||||
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
|
||||
|
||||
# verify channels
|
||||
self.parent.assertEqual(len(model.channels), 1)
|
||||
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
config, pixel_values, labels = config_and_inputs
|
||||
|
Loading…
Reference in New Issue
Block a user