Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-23 22:38:58 +06:00)
enable gradient checkpointing in DetaObjectDetection and add tests in Swin/Donut_Swin (#28615)
* enable gradient checkpointing in DetaObjectDetection
* fix missing part in original DETA
* make style
* make fix-copies
* Revert "make fix-copies"
  This reverts commit 4041c86c29248f1673e8173b677c20b5a4511358.
* remove fix-copies of DetaDecoder
* enable swin gradient checkpointing
* fix gradient checkpointing in donut_swin
* add tests for deta/swin/donut
* Revert "fix gradient checkpointing in donut_swin"
  This reverts commit 1cf345e34d3cc0e09eb800d9895805b1dd9b474d.
* change supports_gradient_checkpointing pipeline to PreTrainedModel
* Revert "add tests for deta/swin/donut"
  This reverts commit 6056ffbb1eddc3cb3a99e4ebb231ae3edf295f5b.
* Revert "Revert "fix gradient checkpointing in donut_swin""
  This reverts commit 24e25d0a14891241de58a0d86f817d0b5d2a341f.
* Simple revert
* enable deformable detr gradient checkpointing
* add gradient in encoder
Parent: 7bc6d76396
Commit: e19c12e094
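As a usage note (not part of the commit): once a model class sets supports_gradient_checkpointing = True, the feature is turned on through the standard PreTrainedModel API. A minimal sketch, assuming the public transformers API; the checkpoint name and input shape are illustrative:

    import torch
    from transformers import DetaForObjectDetection

    # Load one of the models this commit touches; any DETA checkpoint works here.
    model = DetaForObjectDetection.from_pretrained("jozhang97/deta-swin-large")

    # This call requires supports_gradient_checkpointing = True on the model class,
    # which is what the DETA hunk below adds.
    model.gradient_checkpointing_enable()
    model.train()

    # With checkpointing active, activations inside checkpointed blocks are recomputed
    # during the backward pass instead of being stored, trading compute for peak memory.
    pixel_values = torch.randn(1, 3, 800, 800)
    outputs = model(pixel_values=pixel_values)
    # A real training loop would compute the detection loss from labeled targets
    # and call loss.backward() here.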
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1050,6 +1050,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         std = self.config.init_std
src/transformers/models/deta/modeling_deta.py
@@ -947,6 +947,7 @@ class DetaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     main_input_name = "pixel_values"
     _no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         std = self.config.init_std
@@ -1043,6 +1044,7 @@ class DetaEncoder(DetaPreTrainedModel):

         self.dropout = config.dropout
         self.layers = nn.ModuleList([DetaEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.gradient_checkpointing = False

         # Initialize weights and apply final processing
         self.post_init()
@@ -1265,9 +1267,13 @@ class DetaDecoder(DetaPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1712,6 +1718,11 @@ class DetaModel(DetaPreTrainedModel):
             init_reference_points = reference_points
             pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
             query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
+
+            topk_feats = torch.stack(
+                [object_query_embedding[b][topk_proposals[b]] for b in range(batch_size)]
+            ).detach()
+            target = target + self.pix_trans_norm(self.pix_trans(topk_feats))
         else:
             query_embed, target = torch.split(query_embeds, num_channels, dim=1)
             query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
src/transformers/models/donut/modeling_donut_swin.py
@@ -750,7 +750,12 @@ class DonutSwinEncoder(nn.Module):

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
                 )
             else:
                 layer_outputs = layer_module(
src/transformers/models/swin/modeling_swin.py
@@ -826,7 +826,12 @@ class SwinEncoder(nn.Module):

             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
                 )
             else:
                 layer_outputs = layer_module(
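Taken together, the DETA, DonutSwin, and Swin hunks converge on the same pattern: when gradient checkpointing is active, the layer's positional arguments are passed explicitly to the checkpointing helper so the block can be re-run with identical inputs during the backward pass. A standalone sketch of that pattern using torch.utils.checkpoint directly (the TinyEncoder module is illustrative, not transformers code):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class TinyEncoder(nn.Module):
        def __init__(self, hidden_size: int = 64, num_layers: int = 2):
            super().__init__()
            self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
            self.gradient_checkpointing = False  # mirrors the flag set in the encoders above

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                if self.gradient_checkpointing and self.training:
                    # Activations inside `layer` are not stored; they are recomputed in backward.
                    hidden_states = checkpoint(layer.__call__, hidden_states, use_reentrant=False)
                else:
                    hidden_states = layer(hidden_states)
            return hidden_states

    encoder = TinyEncoder()
    encoder.gradient_checkpointing = True
    encoder.train()
    out = encoder(torch.randn(4, 64, requires_grad=True))
    out.sum().backward()  # triggers recomputation of each checkpointed layer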