[Backbones] Improve out features (#20675)

* Improve ResNet backbone

* Improve Bit backbone

* Improve docstrings

* Fix default stage

* Apply suggestions from code review

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge 2022-12-09 09:14:52 +01:00 committed by GitHub
parent 9e56aff58a
commit 9a6c6ef97f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 41 additions and 10 deletions

View File

@ -63,7 +63,7 @@ class BitConfig(PretrainedConfig):
The width factor for the model.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has).
(depending on how many stages the model has). Will default to the last stage if unset.
Example:
```python

View File

@ -851,7 +851,7 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
self.stage_names = config.stage_names
self.bit = BitModel(config)
self.out_features = config.out_features
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
out_feature_channels = {}
out_feature_channels["stem"] = config.embedding_size

View File

@ -69,7 +69,8 @@ class MaskFormerSwinConfig(PretrainedConfig):
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
out_features (`List[str]`, *optional*):
If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example:

View File

@ -855,7 +855,7 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config)
self.out_features = config.out_features
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.")

View File

@ -59,8 +59,8 @@ class ResNetConfig(PretrainedConfig):
downsample_in_first_stage (`bool`, *optional*, defaults to `False`):
If `True`, the first stage will downsample the inputs using a `stride` of 2.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`,
`"stage3"`, `"stage4"`.
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). Will default to the last stage if unset.
Example:
```python

View File

@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.bias, 0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (ResNetModel, ResNetBackbone)):
if isinstance(module, ResNetEncoder):
module.gradient_checkpointing = value
@ -439,7 +439,7 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config)
self.out_features = config.out_features
self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
out_feature_channels = {}
out_feature_channels["stem"] = config.embedding_size

View File

@ -119,7 +119,7 @@ class BitModelTester:
model.eval()
result = model(pixel_values)
# verify hidden states
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
@ -127,6 +127,21 @@ class BitModelTester:
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = BitBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs

View File

@ -119,7 +119,7 @@ class ResNetModelTester:
model.eval()
result = model(pixel_values)
# verify hidden states
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
@ -127,6 +127,21 @@ class ResNetModelTester:
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = ResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs