diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
index aea5b60bdee..001d379e9a1 100755
--- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1050,5 +1050,6 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index eb0336f85bc..b98b2318508 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -947,6 +947,7 @@ class DetaPreTrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     main_input_name = "pixel_values"
     _no_split_modules = [r"DetaBackboneWithPositionalEncodings", r"DetaEncoderLayer", r"DetaDecoderLayer"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1043,6 +1044,7 @@ class DetaEncoder(DetaPreTrainedModel):
 
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DetaEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1265,9 +1267,13 @@ class DetaDecoder(DetaPreTrainedModel):
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1712,6 +1718,11 @@ class DetaModel(DetaPreTrainedModel):
             init_reference_points = reference_points
             pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_logits)))
             query_embed, target = torch.split(pos_trans_out, num_channels, dim=2)
+
+            topk_feats = torch.stack(
+                [object_query_embedding[b][topk_proposals[b]] for b in range(batch_size)]
+            ).detach()
+            target = target + self.pix_trans_norm(self.pix_trans(topk_feats))
         else:
             query_embed, target = torch.split(query_embeds, num_channels, dim=1)
             query_embed = query_embed.unsqueeze(0).expand(batch_size, -1, -1)
diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py
index 65af7f5b1c2..ed79b8ef8ec 100644
--- a/src/transformers/models/donut/modeling_donut_swin.py
+++ b/src/transformers/models/donut/modeling_donut_swin.py
@@ -750,7 +750,12 @@ class DonutSwinEncoder(nn.Module):
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
                 )
             else:
                 layer_outputs = layer_module(
diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index 967f9440090..a3f0643512a 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -826,7 +826,12 @@ class SwinEncoder(nn.Module):
 
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
                )
             else:
                 layer_outputs = layer_module(
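
Why the `DetaDecoder` argument list matters: `torch.utils.checkpoint.checkpoint`, which backs `_gradient_checkpointing_func`, forwards its arguments positionally, so the old call bound `encoder_hidden_states` to the layer's `position_embeddings` parameter and left a literal `None` where `output_attentions` belongs. A minimal sketch of that failure mode, using only plain PyTorch; `toy_decoder_layer` is a hypothetical stand-in for `DetaDecoderLayer.forward`, not code from this diff:

```python
import torch
from torch.utils.checkpoint import checkpoint


def toy_decoder_layer(
    hidden_states,
    position_embeddings=None,
    reference_points=None,
    spatial_shapes=None,
    level_start_index=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    output_attentions=False,
):
    # With the pre-fix call below, the encoder output lands in
    # `position_embeddings` and `encoder_hidden_states` stays None.
    print("encoder_hidden_states is None:", encoder_hidden_states is None)
    return hidden_states


hidden = torch.randn(1, 4, 8, requires_grad=True)
encoder_out = torch.randn(1, 4, 8)

# Pre-fix argument order: hidden_states, encoder_hidden_states,
# encoder_attention_mask, None. Every tensor after `hidden` shifts left.
checkpoint(toy_decoder_layer, hidden, encoder_out, None, None, use_reentrant=False)
```

This prints `encoder_hidden_states is None: True`: under the old ordering the encoder output never reached the cross-attention input. The fix passes every positional parameter up to `output_attentions` explicitly, matching the non-checkpointed branch.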
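With `supports_gradient_checkpointing = True` set on the pretrained base classes, checkpointing can be toggled through the standard `PreTrainedModel` API. A short usage sketch; the Hub checkpoint name is illustrative and any DETA checkpoint should behave the same:

```python
import torch
from transformers import DetaForObjectDetection

model = DetaForObjectDetection.from_pretrained("jozhang97/deta-resnet-50")
model.gradient_checkpointing_enable()  # standard PreTrainedModel hook, now honored by DETA
model.train()  # checkpointing is only active in training mode

# Activations are recomputed during backward, trading compute for memory.
pixel_values = torch.randn(1, 3, 256, 256)
outputs = model(pixel_values=pixel_values)
```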