Mirror of https://github.com/huggingface/transformers.git
Enable Gradient Checkpointing in Deformable DETR (#28686)
* Enabled gradient checkpointing in Deformable DETR
* Enabled gradient checkpointing in Deformable DETR encoder
* Removed # Copied from headers in modeling_deta.py to break dependence on Deformable DETR code
parent f72c7c22d9
commit 0548af54cc
src/transformers/models/deformable_detr/modeling_deformable_detr.py

```diff
@@ -1048,6 +1048,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     config_class = DeformableDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]

     def _init_weights(self, module):
```
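With `supports_gradient_checkpointing = True` in place, the standard `PreTrainedModel` API applies to this model. A minimal usage sketch; the checkpoint id is the reference Hub checkpoint, while the dummy batch and labels are illustrative, not part of this commit:

```python
import torch
from transformers import DeformableDetrForObjectDetection

# Load the reference checkpoint and switch activation checkpointing on;
# gradient_checkpointing_enable() is the standard PreTrainedModel entry point.
model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")
model.gradient_checkpointing_enable()
model.train()

# Dummy batch: one 800x800 image and one box in normalized (cx, cy, w, h) format.
pixel_values = torch.randn(1, 3, 800, 800)
labels = [{"class_labels": torch.tensor([0]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])}]

# Forward stores fewer activations; backward recomputes them layer by layer.
outputs = model(pixel_values=pixel_values, labels=labels)
outputs.loss.backward()
```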
```diff
@@ -1143,6 +1144,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
 
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
+        self.gradient_checkpointing = False
 
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
```
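For context, the flag initialized here is what `gradient_checkpointing_enable()` flips at runtime. A simplified sketch of that mechanism (not the exact `PreTrainedModel` implementation, just the idea):

```python
from functools import partial

import torch

def enable_gradient_checkpointing(model: torch.nn.Module) -> None:
    """Simplified stand-in for PreTrainedModel.gradient_checkpointing_enable()."""
    checkpoint_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    for module in model.modules():
        # Only modules that declare the flag (like DeformableDetrEncoder above)
        # take part in checkpointing; the bound function is what forward calls
        # as self._gradient_checkpointing_func.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = True
            module._gradient_checkpointing_func = checkpoint_func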
```diff
@@ -1235,15 +1237,27 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         for i, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask,
-                position_embeddings=position_embeddings,
-                reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
-                level_start_index=level_start_index,
-                output_attentions=output_attentions,
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings,
+                    reference_points,
+                    spatial_shapes,
+                    level_start_index,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings=position_embeddings,
+                    reference_points=reference_points,
+                    spatial_shapes=spatial_shapes,
+                    level_start_index=level_start_index,
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = layer_outputs[0]
```
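The `_gradient_checkpointing_func` branch above trades compute for memory: the layer's intermediate activations are discarded during forward and recomputed during backward. Note that the checkpointed call passes everything positionally, presumably because the reentrant form of `torch.utils.checkpoint` only forwards positional arguments to the wrapped callable. A minimal standalone illustration of the primitive:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Instead of layer(x), route the call through checkpoint(): the output is
# kept, the intermediate activations are not, and backward reruns `layer`
# to rebuild them before computing gradients.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```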
```diff
@@ -1368,9 +1382,13 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
```
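Both the encoder and decoder loops gate checkpointing on `self.gradient_checkpointing and self.training`, so inference keeps the ordinary, faster path. A self-contained toy version of that guard; the names here are illustrative, not taken from the modeling file:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ToyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gradient_checkpointing = False
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # train mode + flag set: recompute activations on backward
                hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
            else:
                # eval mode (or flag off): plain forward, activations stored
                hidden_states = layer(hidden_states)
        return hidden_states

encoder = ToyEncoder()
encoder.gradient_checkpointing = True
encoder.eval()
_ = encoder(torch.randn(2, 8))              # guard fails: ordinary path
encoder.train()
out = encoder(torch.randn(2, 8).requires_grad_())
out.sum().backward()                        # guard passes: checkpointed path
```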
src/transformers/models/deta/modeling_deta.py

```diff
@@ -942,7 +942,6 @@ class DetaClassificationHead(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
 class DetaPreTrainedModel(PreTrainedModel):
     config_class = DetaConfig
     base_model_prefix = "model"
@@ -1028,7 +1027,6 @@ DETA_INPUTS_DOCSTRING = r"""
 """
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetr->Deta
 class DetaEncoder(DetaPreTrainedModel):
     """
     Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
@@ -1159,7 +1157,6 @@ class DetaEncoder(DetaPreTrainedModel):
         )
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoder with DeformableDetr->Deta,Deformable DETR->DETA
 class DetaDecoder(DetaPreTrainedModel):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetaDecoderLayer`].
```
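The deleted `# Copied from` headers are what transformers' `utils/check_copies.py` uses to keep DETA's classes in sync with their Deformable DETR sources (modulo the declared renames); removing them opts these classes out of that check, so the DETA code can now diverge from the modified Deformable DETR code. A toy sketch of the idea, not the actual transformers checker:

```python
import inspect

def is_faithful_copy(target, source, renames: dict[str, str]) -> bool:
    """Toy version of the '# Copied from ... with A->B' consistency rule:
    the target's source code must equal the source's code after the renames."""
    code = inspect.getsource(source)
    for old, new in renames.items():
        code = code.replace(old, new)
    return code == inspect.getsource(target)

# e.g. is_faithful_copy(DetaEncoder, DeformableDetrEncoder, {"DeformableDetr": "Deta"})
```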