Move common properties to BackboneMixin (#21855)

* Move common properties to BackboneMixin

* Fix failing tests

* Update ConvNextV2 backbone
amyeroberts 2023-03-30 10:04:11 +01:00 committed by GitHub
parent cd73b9a8c1
commit c15f937581
9 changed files with 27 additions and 84 deletions

src/transformers/modeling_utils.py

@@ -968,12 +968,30 @@ class ModuleUtilsMixin:
 class BackboneMixin:
+    @property
+    def out_feature_channels(self):
+        # the current backbones will output the number of channels for each stage
+        # even if that stage is not in the out_features list.
+        return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
+
+    @property
+    def channels(self):
+        return [self.out_feature_channels[name] for name in self.out_features]
+
     def forward_with_filtered_kwargs(self, *args, **kwargs):
         signature = dict(inspect.signature(self.forward).parameters)
         filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
         return self(*args, **filtered_kwargs)
 
+    def forward(
+        self,
+        pixel_values: Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        raise NotImplementedError("This method should be implemented by the derived class.")
+
 
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""

src/transformers/models/bit/modeling_bit.py

@@ -849,21 +849,11 @@ class BitBackbone(BitPreTrainedModel, BackboneMixin):
         self.bit = BitModel(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
 
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.embedding_size
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-
-        self.out_feature_channels = out_feature_channels
+        self.num_features = [config.embedding_size] + config.hidden_sizes
 
         # initialize weights and apply final processing
         self.post_init()
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(BIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
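For the convolutional backbones this is a pure simplification: the hand-rolled dict and the new one-line `num_features` list encode the same stage-to-channels mapping. A quick equivalence check, using BiT-like sizes that are illustrative rather than read from any config:

# Old BitBackbone.__init__ logic vs. the new mixin derivation; sizes are made up.
embedding_size = 64
hidden_sizes = [256, 512, 1024, 2048]
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]

# Removed code: build the mapping by hand, stem first.
old = {"stem": embedding_size}
for idx, stage in enumerate(stage_names[1:]):
    old[stage] = hidden_sizes[idx]

# New code: one list, from which the mixin property derives the dict.
num_features = [embedding_size] + hidden_sizes
new = {stage: num_features[i] for i, stage in enumerate(stage_names)}

assert old == new  # identical stage -> channel mapping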

src/transformers/models/convnext/modeling_convnext.py

@@ -486,13 +486,7 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         self.encoder = ConvNextEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
 
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.hidden_sizes[0]
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-
-        self.out_feature_channels = out_feature_channels
+        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -503,10 +497,6 @@ class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
         # initialize weights and apply final processing
         self.post_init()
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(CONVNEXT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(

src/transformers/models/convnextv2/modeling_convnextv2.py

@@ -509,13 +509,7 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         self.encoder = ConvNextV2Encoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
 
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.hidden_sizes[0]
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-
-        self.out_feature_channels = out_feature_channels
+        self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -526,10 +520,6 @@ class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
         # initialize weights and apply final processing
         self.post_init()
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(

src/transformers/models/dinat/modeling_dinat.py

@@ -891,12 +891,7 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
         self.encoder = DinatEncoder(config)
 
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        self.out_feature_channels["stem"] = config.embed_dim
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -910,10 +905,6 @@ class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(DINAT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
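The hierarchical transformer backbones (Dinat, Nat, Swin, MaskFormerSwin) all share the same channel-doubling progression, so their `num_features` lists collapse identically. A small illustration with made-up Swin-like values:

# Channel progression for the hierarchical backbones; embed_dim and depths are illustrative.
embed_dim = 96
depths = [2, 2, 6, 2]  # one entry per stage after the stem

num_features = [embed_dim] + [int(embed_dim * 2**i) for i in range(len(depths))]
print(num_features)  # [96, 96, 192, 384, 768]: the stem keeps embed_dim, each stage doubles it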

src/transformers/models/maskformer/modeling_maskformer_swin.py

@@ -859,20 +859,12 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
         if "stem" in self.out_features:
             raise ValueError("This backbone does not support 'stem' in the `out_features`.")
 
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]
-
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
         self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(num_channels) for num_channels in self.channels])
 
         # Initialize weights and apply final processing
         self.post_init()
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     def forward(
         self,
         pixel_values: Tensor,

src/transformers/models/nat/modeling_nat.py

@@ -869,12 +869,7 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
         self.encoder = NatEncoder(config)
 
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        self.out_feature_channels["stem"] = config.embed_dim
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -888,10 +883,6 @@ class NatBackbone(NatPreTrainedModel, BackboneMixin):
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(NAT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(

src/transformers/models/resnet/modeling_resnet.py

@@ -437,21 +437,11 @@ class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
         self.encoder = ResNetEncoder(config)
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
 
-        out_feature_channels = {}
-        out_feature_channels["stem"] = config.embedding_size
-        for idx, stage in enumerate(self.stage_names[1:]):
-            out_feature_channels[stage] = config.hidden_sizes[idx]
-
-        self.out_feature_channels = out_feature_channels
+        self.num_features = [config.embedding_size] + config.hidden_sizes
 
         # initialize weights and apply final processing
         self.post_init()
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     @add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
     def forward(

src/transformers/models/swin/modeling_swin.py

@@ -1255,12 +1255,7 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
 
         self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
-
-        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
-        self.out_feature_channels = {}
-        self.out_feature_channels["stem"] = config.embed_dim
-        for i, stage in enumerate(self.stage_names[1:]):
-            self.out_feature_channels[stage] = num_features[i]
+        self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
 
         # Add layer norms to hidden states of out_features
         hidden_states_norms = {}
@@ -1274,10 +1269,6 @@ class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
     def get_input_embeddings(self):
         return self.embeddings.patch_embeddings
 
-    @property
-    def channels(self):
-        return [self.out_feature_channels[name] for name in self.out_features]
-
     def forward(
         self,
         pixel_values: torch.Tensor,
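End to end, nothing changes for callers: `channels` still reflects the requested `out_features`, it is simply served by `BackboneMixin` now. A hedged usage sketch (the checkpoint and stage names are illustrative, assuming `AutoBackbone` as available in the library at this point):

from transformers import AutoBackbone

# ResNet-50's stage widths are [256, 512, 1024, 2048] after a 64-channel stem.
backbone = AutoBackbone.from_pretrained("microsoft/resnet-50", out_features=["stage2", "stage4"])
print(backbone.channels)  # expected [512, 2048], now computed by BackboneMixin.channels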