Enable Gradient Checkpointing in Deformable DETR (#28686)

* Enabled gradient checkpointing in Deformable DETR

* Enabled gradient checkpointing in Deformable DETR encoder

* Removed # Copied from headers in modeling_deta.py to break dependence on Deformable DETR code
Nate Cibik 2024-01-29 02:10:40 -08:00 committed by GitHub
parent f72c7c22d9
commit 0548af54cc
2 changed files with 28 additions and 13 deletions
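
With the new `supports_gradient_checkpointing` flag, the feature can be switched on through the standard `transformers` API. A minimal usage sketch (the checkpoint name is illustrative):

```python
from transformers import DeformableDetrForObjectDetection

# Any Deformable DETR checkpoint works; "SenseTime/deformable-detr" is used here as an example.
model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")

# Recompute encoder/decoder layer activations during the backward pass
# instead of storing them, trading extra compute for lower peak memory.
model.gradient_checkpointing_enable()
model.train()
```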

src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -1048,6 +1048,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     config_class = DeformableDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
 
     def _init_weights(self, module):
@@ -1143,6 +1144,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
+        self.gradient_checkpointing = False
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -1235,15 +1237,27 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
         for i, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask,
-                position_embeddings=position_embeddings,
-                reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
-                level_start_index=level_start_index,
-                output_attentions=output_attentions,
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings,
+                    reference_points,
+                    spatial_shapes,
+                    level_start_index,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings=position_embeddings,
+                    reference_points=reference_points,
+                    spatial_shapes=spatial_shapes,
+                    level_start_index=level_start_index,
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = layer_outputs[0]
@@ -1368,9 +1382,13 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
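
For reference, `self._gradient_checkpointing_func` is installed by `PreTrainedModel.gradient_checkpointing_enable()` and is expected to behave like `torch.utils.checkpoint.checkpoint`. A rough, standalone sketch of the pattern both loops now follow (simplified names, not the library code):

```python
from torch.utils.checkpoint import checkpoint


def run_layer(layer, hidden_states, *args, gradient_checkpointing=True, training=True):
    # While training with checkpointing enabled, the layer's activations are not
    # kept for backward; the layer is re-run during backprop to regenerate them.
    if gradient_checkpointing and training:
        return checkpoint(layer.__call__, hidden_states, *args, use_reentrant=False)
    # Otherwise call the layer normally (keyword arguments can be used here).
    return layer(hidden_states, *args)
```

Arguments go to the checkpointed call positionally, which is presumably why the encoder branch above forwards `position_embeddings`, `reference_points`, etc. without keywords.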

src/transformers/models/deta/modeling_deta.py

@@ -942,7 +942,6 @@ class DetaClassificationHead(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
 class DetaPreTrainedModel(PreTrainedModel):
     config_class = DetaConfig
     base_model_prefix = "model"
@@ -1028,7 +1027,6 @@ DETA_INPUTS_DOCSTRING = r"""
 """
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetr->Deta
 class DetaEncoder(DetaPreTrainedModel):
     """
     Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
@@ -1159,7 +1157,6 @@ class DetaEncoder(DetaPreTrainedModel):
         )
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoder with DeformableDetr->Deta,Deformable DETR->DETA
 class DetaDecoder(DetaPreTrainedModel):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetaDecoderLayer`].