From e19c12e094678dd0b355c1bdd529d35abcb7b34c Mon Sep 17 00:00:00 2001
From: Sangbum Daniel Choi <34004152+SangbumChoi@users.noreply.github.com>
Date: Fri, 2 Feb 2024 00:07:44 +0900
Subject: [PATCH] enable gradient checkpointing in DetaObjectDetection and add
 tests in Swin/Donut_Swin (#28615)

* enable gradient checkpointing in DetaObjectDetection

* fix missing part in original DETA

* make style

* make fix-copies

* Revert "make fix-copies"

This reverts commit 4041c86c29248f1673e8173b677c20b5a4511358.

* remove fix-copies of DetaDecoder

* enable swin gradient checkpointing

* fix gradient checkpointing in donut_swin

* add tests for deta/swin/donut

* Revert "fix gradient checkpointing in donut_swin"

This reverts commit 1cf345e34d3cc0e09eb800d9895805b1dd9b474d.

* change supports_gradient_checkpointing pipeline to PreTrainedModel

* Revert "add tests for deta/swin/donut"

This reverts commit 6056ffbb1eddc3cb3a99e4ebb231ae3edf295f5b.

* Revert "Revert "fix gradient checkpointing in donut_swin""

This reverts commit 24e25d0a14891241de58a0d86f817d0b5d2a341f.

* Simple revert

* enable deformable detr gradient checkpointing

* add gradient in encoder
---
 .../deformable_detr/modeling_deformable_detr.py |  1 +
 src/transformers/models/deta/modeling_deta.py   | 13 ++++++++++++-
 .../models/donut/modeling_donut_swin.py         |  7 ++++++-
 src/transformers/models/swin/modeling_swin.py   |  7 ++++++-
 4 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
index aea5b60bdee..001d379e9a1 100755
--- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1050,6 +1050,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index eb0336f85bc..b98b2318508 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -947,6 +947,7 @@ class DetaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     main_input_name = "pixel_values"
     _no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1043,6 +1044,7 @@ class DetaEncoder(DetaPreTrainedModel):
 
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DetaEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1265,9 +1267,13 @@ class DetaDecoder(DetaPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1712,6 +1718,11 @@ class DetaModel(DetaPreTrainedModel):
             init_reference_points = reference_points
             pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
             query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
+
+            topk_feats = torch.stack(
+                [object_query_embedding[b][topk_proposals[b]] for b in range(batch_size)]
+            ).detach()
+            target = target + self.pix_trans_norm(self.pix_trans(topk_feats))
         else:
             query_embed, target = torch.split(query_embeds, num_channels, dim=1)
             query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 65af7f5b1c2..ed79b8ef8ec 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -750,7 +750,12 @@ class DonutSwinEncoder(nn.Module):
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
                 )
             else:
                 layer_outputs = layer_module(
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 967f9440090..a3f0643512a 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -826,7 +826,12 @@ class SwinEncoder(nn.Module):
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
                 )
             else:
                 layer_outputs = layer_module(
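
Below is a minimal sketch of how the checkpointing paths touched by this patch would typically be exercised from user code. It relies on the standard PreTrainedModel.gradient_checkpointing_enable() API; the Swin model class and the microsoft/swin-tiny-patch4-window7-224 checkpoint are used purely as an example, and the same steps would apply to the DETA, Deformable DETR, and Donut-Swin models changed here.

import torch
from transformers import SwinModel

# Example checkpoint, chosen only for illustration.
model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

# Recompute activations during the backward pass instead of storing them,
# trading extra compute for lower peak memory.
model.gradient_checkpointing_enable()

# The checkpointed code path is only taken while the module is in training mode.
model.train()

pixel_values = torch.randn(1, 3, 224, 224)
outputs = model(pixel_values)
outputs.last_hidden_state.mean().backward()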