enable graident checkpointing in DetaObjectDetection and add tests in Swin/Donut_Swin (#28615)

* enable graident checkpointing in DetaObjectDetection

* fix missing part in original DETA

* make style

* make fix-copies

* Revert "make fix-copies"

This reverts commit 4041c86c29248f1673e8173b677c20b5a4511358.

* remove fix-copies of DetaDecoder

* enable swin gradient checkpointing

* fix gradient checkpointing in donut_swin

* add tests for deta/swin/donut

* Revert "fix gradient checkpointing in donut_swin"

This reverts commit 1cf345e34d3cc0e09eb800d9895805b1dd9b474d.

* change supports_gradient_checkpointing pipeline to PreTrainedModel

* Revert "add tests for deta/swin/donut"

This reverts commit 6056ffbb1eddc3cb3a99e4ebb231ae3edf295f5b.

* Revert "Revert "fix gradient checkpointing in donut_swin""

This reverts commit 24e25d0a14891241de58a0d86f817d0b5d2a341f.

* Simple revert

* enable deformable detr gradient checkpointing

* add gradient in encoder
This commit is contained in:
Sangbum Daniel Choi 2024-02-02 00:07:44 +09:00 committed by GitHub
parent 7bc6d76396
commit e19c12e094
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 25 additions and 3 deletions

View File

@ -1050,6 +1050,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std

View File

@ -947,6 +947,7 @@ class DetaPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "pixel_values"
_no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"]
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
@ -1043,6 +1044,7 @@ class DetaEncoder(DetaPreTrainedModel):
self.dropout = config.dropout
self.layers = nn.ModuleList([DetaEncoderLayer(config) for _ in range(config.encoder_layers)])
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
@ -1265,9 +1267,13 @@ class DetaDecoder(DetaPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
position_embeddings,
reference_points_input,
spatial_shapes,
level_start_index,
encoder_hidden_states,
encoder_attention_mask,
None,
output_attentions,
)
else:
layer_outputs = decoder_layer(
@ -1712,6 +1718,11 @@ class DetaModel(DetaPreTrainedModel):
init_reference_points = reference_points
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
topk_feats = torch.stack(
[object_query_embedding[b][topk_proposals[b]] for b in range(batch_size)]
).detach()
target = target + self.pix_trans_norm(self.pix_trans(topk_feats))
else:
query_embed, target = torch.split(query_embeds, num_channels, dim=1)
query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)

View File

@ -750,7 +750,12 @@ class DonutSwinEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(

View File

@ -826,7 +826,12 @@ class SwinEncoder(nn.Module):
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
layer_module.__call__,
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
else:
layer_outputs = layer_module(