From 84d19be41e0131e6f2a660fe6af8b77094906af7 Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Mon, 23 Jun 2025 13:24:48 +0100
Subject: [PATCH] Apply GradientCheckpointingLayer to the whole repo (#38913)

* first batch (4)
* align
* altclip
* beit
* bert
* yolos
* dino, pvt_v2
* bark, bart, bert_generation
* big_bird, biogpt
* blnderbot, bloom
* bridgetower
* camambert, canine, chameleon
* chinese clip, clap, clip
* codegen, conditional detr, convbert
* dab_detr, data2vec
* dbrx, deberta
* deberta, decicion_tranformer, deformable_detr
* deit, deta, mctct
* detr, dinov2, distilbert
* donut, dpt, electra
* ernie, esm, falcon
* flava, fnet, falcon_mamba
* focalnet, git, gpt2
* gpt - bigcode, neo, neox
* gptj, groupvit
* idefics2, idefics3
* ijepa, imagegpt, internvl
* jetmoe, kosmos2, layoutlm
* layoutlm2-3, led
* lilt, longformer, longt5, luke
* m2m, mamba1-2
* marian, markuplm, mask2former
* maskformer
* mbart, megatron_bert, mimi
* mixtral, mlcd
* mobilevit1-2, modernbert
* moshi, mpt, mra
* mt5, musicgen
* mvp, nemotron
* nllb_moe
* nystromformer, omdet_turbo
* opt, owlvit, owlv2
* pegasus, pegasus_x, presimmon
* phimoe, pix2struct, pixtral
* plbart, pop2piano, prophetnet
* qwen2*
* qwen2, qwen3 moe, rec gemma
* rembert
* roberta
* roberta prelayernorm
* roc_bert, roformer, rwkv
* sam, sam_hq
* seggpt, smolvlm, speech_to_text
* splinter, stablelm, swin
* swin2sr, switch_transformer, t5, table_transformer
* tapas, time_series_tranformer, timesformer
* trocr, tvp, umt5
* videomae, vilt, visual_bert
* vit, vit_mae, vit_msn
* vitpose_backbone, vits, vivit
* whisper. x_clip, xglm
* xlm_roberta, xmod
* yoso
* zamba
* vitdet, wav2vec2, wav2vec2_bert
* unispeech, wav2vec_conformer
* wavlm
* speecht5
* swinv2
* sew / _d
* seamless_mt4 / _v2
* deprecated models update
* bros
* gemma2, gemma3
* got, hiera, hubert, llama4, mllama, oneformer, phi, olmoe, informer
* fixup
* Add use_cache=False and past_key_value=None to GradientCheckpointingLayer
* fixup
* fix prophetnet
* fix bigbird_pegasus
* fix blenderbot
* fix mbart
* fix mvp
* fix zamba2
* fix bart
* fix blenderbot_small
* fix codegen
* Update gradient checkpointing layer to support more past_key_values arg names
* fix data2vec vision
* fix deformable_detr
* fix gptj
* fix led
* fix m2m_100
* add comment
* fix nnlb_moe
* Fix pegasus_x
* fix plbart
* udop
* fix-copies: beit, wav2vec2
* fix gpt_bigcode
* fixup
* fix t5
* fix switch_transformers
* fix longt5
* fix mt5
* update tapas
* fix blip2
* update blip
* fix musicgen
* fix gpt2, trocr
* fix copies
* !!!
Revert zamba, mllama * update autoformer * update bros * update args / kwargs for BERT and copies * 2nd round of updates * update conditional detr * Pass encoder_hidden_states as positional arg * Update to pass encoder_decoder_position_bias as positional arg * fixup * biogpt modular * modular gemma2 * modular gemma3 * modular gpt_neox * modular informer * modular internvl * modular mixtral * modular mlcd * modular modernbert * modular phi * modular qwen2_5_omni * modular qwen2_5_vl * modular sam_hq * modular sew * wav2vec2_bert * modular wav2vec2_conformer * modular wavlm * fixup * Update by modular instructblipvideo * modular data2vec_audio * nit modular mistral * apply modular minimax * fix modular moonshine * revert zamba2 * fix mask2former * refactor idefics --- src/transformers/modeling_layers.py | 35 ++++++ .../models/align/modeling_align.py | 33 ++---- .../models/altclip/modeling_altclip.py | 56 +++------- .../modeling_audio_spectrogram_transformer.py | 14 +-- .../models/autoformer/modeling_autoformer.py | 75 +++++-------- src/transformers/models/bark/modeling_bark.py | 30 ++--- src/transformers/models/bart/modeling_bart.py | 67 ++++------- src/transformers/models/beit/modeling_beit.py | 32 ++---- src/transformers/models/bert/modeling_bert.py | 33 ++---- .../modeling_bert_generation.py | 33 ++---- .../models/big_bird/modeling_big_bird.py | 45 +++----- .../modeling_bigbird_pegasus.py | 82 +++++--------- .../models/biogpt/modeling_biogpt.py | 39 +++---- .../models/biogpt/modular_biogpt.py | 36 ++---- .../models/blenderbot/modeling_blenderbot.py | 67 ++++------- .../modeling_blenderbot_small.py | 67 ++++------- src/transformers/models/blip/modeling_blip.py | 2 +- .../models/blip_2/modeling_blip_2.py | 12 +- .../models/bloom/modeling_bloom.py | 36 ++---- .../bridgetower/modeling_bridgetower.py | 33 ++---- src/transformers/models/bros/modeling_bros.py | 41 ++----- .../models/camembert/modeling_camembert.py | 33 ++---- .../models/canine/modeling_canine.py | 14 +-- .../models/chameleon/modeling_chameleon.py | 37 +++---- .../chinese_clip/modeling_chinese_clip.py | 50 +++------ src/transformers/models/clap/modeling_clap.py | 46 +++----- src/transformers/models/clip/modeling_clip.py | 24 ++-- .../models/clipseg/modeling_clipseg.py | 25 ++--- .../models/codegen/modeling_codegen.py | 36 ++---- .../modeling_conditional_detr.py | 40 +++---- .../models/convbert/modeling_convbert.py | 30 ++--- .../models/dab_detr/modeling_dab_detr.py | 59 ++++------ .../data2vec/modeling_data2vec_audio.py | 27 ++--- .../models/data2vec/modeling_data2vec_text.py | 33 ++---- .../data2vec/modeling_data2vec_vision.py | 32 ++---- .../models/data2vec/modular_data2vec_audio.py | 3 +- src/transformers/models/dbrx/modeling_dbrx.py | 36 ++---- .../models/deberta/modeling_deberta.py | 30 ++--- .../models/deberta_v2/modeling_deberta_v2.py | 30 ++--- .../modeling_decision_transformer.py | 39 +++---- .../modeling_deformable_detr.py | 74 ++++--------- src/transformers/models/deit/modeling_deit.py | 13 +-- .../models/deprecated/deta/modeling_deta.py | 36 ++---- .../models/deprecated/mctct/modeling_mctct.py | 22 ++-- .../models/deprecated/nezha/modeling_nezha.py | 33 ++---- .../open_llama/modeling_open_llama.py | 30 ++--- .../deprecated/qdqbert/modeling_qdqbert.py | 38 ++----- .../models/deprecated/realm/modeling_realm.py | 33 ++---- .../modeling_speech_to_text_2.py | 39 +++---- .../modeling_trajectory_transformer.py | 14 +-- .../models/deprecated/tvlt/modeling_tvlt.py | 24 +--- .../vit_hybrid/modeling_vit_hybrid.py | 13 +-- 
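From the user side, the conversion above does not change how checkpointing is switched on: it is still enabled through the model API, and the per-layer rerouting now happens inside each GradientCheckpointingLayer subclass. A minimal usage sketch, assuming a small Hub checkpoint such as "sshleifer/tiny-gpt2" (any PreTrainedModel works) and a transformers build that includes this change; gradient_checkpointing_enable and its gradient_checkpointing_kwargs argument are existing PreTrainedModel API:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
# enables checkpointing on the submodules that support it and installs the
# checkpointing function (torch.utils.checkpoint under the hood)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()  # per-layer activations are recomputed here instead of being stored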
.../xlm_prophetnet/modeling_xlm_prophetnet.py | 76 ++++--------- src/transformers/models/detr/modeling_detr.py | 31 ++---- .../models/dinov2/modeling_dinov2.py | 13 +-- .../modeling_dinov2_with_registers.py | 13 +-- .../models/distilbert/modeling_distilbert.py | 24 ++-- .../models/donut/modeling_donut_swin.py | 19 +--- src/transformers/models/dpt/modeling_dpt.py | 13 +-- .../models/electra/modeling_electra.py | 33 ++---- .../models/ernie/modeling_ernie.py | 33 ++---- src/transformers/models/esm/modeling_esm.py | 33 ++---- .../models/falcon/modeling_falcon.py | 42 +++---- .../falcon_mamba/modeling_falcon_mamba.py | 20 ++-- .../models/flava/modeling_flava.py | 14 +-- src/transformers/models/fnet/modeling_fnet.py | 8 +- .../models/focalnet/modeling_focalnet.py | 12 +- .../models/gemma2/modeling_gemma2.py | 39 +++---- .../models/gemma2/modular_gemma2.py | 39 +++---- .../models/gemma3/modeling_gemma3.py | 42 +++---- .../models/gemma3/modular_gemma3.py | 42 +++---- src/transformers/models/git/modeling_git.py | 52 +++------ .../models/got_ocr2/modeling_got_ocr2.py | 11 +- src/transformers/models/gpt2/modeling_gpt2.py | 41 +++---- .../gpt_bigcode/modeling_gpt_bigcode.py | 42 +++---- .../models/gpt_neo/modeling_gpt_neo.py | 33 ++---- .../models/gpt_neox/modeling_gpt_neox.py | 41 +++---- .../models/gpt_neox/modular_gpt_neox.py | 41 +++---- src/transformers/models/gptj/modeling_gptj.py | 36 ++---- .../models/groupvit/modeling_groupvit.py | 24 ++-- .../models/hiera/modeling_hiera.py | 10 +- .../models/hubert/modeling_hubert.py | 47 +++----- .../models/idefics/modeling_idefics.py | 100 ++++------------- src/transformers/models/idefics/vision.py | 24 ++-- .../models/idefics2/modeling_idefics2.py | 21 ++-- .../models/idefics3/modeling_idefics3.py | 21 ++-- .../models/ijepa/modeling_ijepa.py | 13 +-- .../models/imagegpt/modeling_imagegpt.py | 36 ++---- .../models/informer/modeling_informer.py | 78 +++++-------- .../models/informer/modular_informer.py | 33 ++---- .../instructblip/modeling_instructblip.py | 12 +- .../modeling_instructblipvideo.py | 12 +- .../models/internvl/modeling_internvl.py | 10 +- .../models/internvl/modular_internvl.py | 10 +- .../models/jetmoe/modeling_jetmoe.py | 34 ++---- .../models/kosmos2/modeling_kosmos2.py | 66 ++++------- .../models/layoutlm/modeling_layoutlm.py | 33 ++---- .../models/layoutlmv2/modeling_layoutlmv2.py | 30 ++--- .../models/layoutlmv3/modeling_layoutlmv3.py | 30 ++--- src/transformers/models/led/modeling_led.py | 73 ++++-------- src/transformers/models/lilt/modeling_lilt.py | 27 ++--- .../models/llama4/modeling_llama4.py | 63 ++++------- .../models/longformer/modeling_longformer.py | 33 ++---- .../models/longt5/modeling_longt5.py | 51 +++------ src/transformers/models/luke/modeling_luke.py | 27 ++--- .../models/m2m_100/modeling_m2m_100.py | 69 ++++-------- .../models/mamba/modeling_mamba.py | 20 ++-- .../models/mamba2/modeling_mamba2.py | 20 ++-- .../models/marian/modeling_marian.py | 67 ++++------- .../models/markuplm/modeling_markuplm.py | 33 ++---- .../mask2former/modeling_mask2former.py | 60 +++++----- .../models/maskformer/modeling_maskformer.py | 32 ++---- .../maskformer/modeling_maskformer_swin.py | 25 ++--- .../models/mbart/modeling_mbart.py | 67 ++++------- .../megatron_bert/modeling_megatron_bert.py | 33 ++---- src/transformers/models/mimi/modeling_mimi.py | 33 ++---- .../models/minimax/modeling_minimax.py | 3 +- .../models/mixtral/modeling_mixtral.py | 42 +++---- .../models/mixtral/modular_mixtral.py | 42 +++---- 
src/transformers/models/mlcd/modeling_mlcd.py | 24 ++-- src/transformers/models/mlcd/modular_mlcd.py | 21 +--- .../models/mobilevit/modeling_mobilevit.py | 11 +- .../mobilevitv2/modeling_mobilevitv2.py | 11 +- .../models/modernbert/modeling_modernbert.py | 33 ++---- .../models/modernbert/modular_modernbert.py | 33 ++---- .../models/moonshine/modeling_moonshine.py | 4 +- .../models/moonshine/modular_moonshine.py | 4 +- .../models/moshi/modeling_moshi.py | 63 ++++------- src/transformers/models/mpt/modeling_mpt.py | 30 ++--- src/transformers/models/mra/modeling_mra.py | 12 +- src/transformers/models/mt5/modeling_mt5.py | 51 +++------ .../models/musicgen/modeling_musicgen.py | 41 +++---- .../modeling_musicgen_melody.py | 30 ++--- src/transformers/models/mvp/modeling_mvp.py | 73 ++++-------- .../models/nemotron/modeling_nemotron.py | 36 ++---- .../models/nllb_moe/modeling_nllb_moe.py | 71 ++++-------- .../nystromformer/modeling_nystromformer.py | 13 +-- .../models/olmoe/modeling_olmoe.py | 39 +++---- .../omdet_turbo/modeling_omdet_turbo.py | 47 +++----- .../models/oneformer/modeling_oneformer.py | 8 +- src/transformers/models/opt/modeling_opt.py | 38 +++---- .../models/owlv2/modeling_owlv2.py | 24 ++-- .../models/owlvit/modeling_owlvit.py | 24 ++-- .../models/pegasus/modeling_pegasus.py | 67 ++++------- .../models/pegasus_x/modeling_pegasus_x.py | 59 ++++------ .../models/persimmon/modeling_persimmon.py | 38 +++---- src/transformers/models/phi/modeling_phi.py | 39 +++---- src/transformers/models/phi/modular_phi.py | 39 +++---- .../models/phimoe/modeling_phimoe.py | 39 +++---- .../models/pix2struct/modeling_pix2struct.py | 72 ++++-------- .../models/pixtral/modeling_pixtral.py | 26 ++--- .../models/plbart/modeling_plbart.py | 67 ++++------- .../models/pop2piano/modeling_pop2piano.py | 48 +++----- .../models/prophetnet/modeling_prophetnet.py | 76 ++++--------- .../models/pvt_v2/modeling_pvt_v2.py | 8 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 104 ++++++------------ .../qwen2_5_omni/modular_qwen2_5_omni.py | 30 ++--- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 44 +++----- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 10 +- .../qwen2_audio/modeling_qwen2_audio.py | 24 ++-- .../models/qwen2_moe/modeling_qwen2_moe.py | 39 +++---- .../models/qwen2_vl/modeling_qwen2_vl.py | 44 +++----- .../models/qwen3_moe/modeling_qwen3_moe.py | 42 +++---- .../modeling_recurrent_gemma.py | 10 +- .../models/rembert/modeling_rembert.py | 33 ++---- .../models/roberta/modeling_roberta.py | 33 ++---- .../modeling_roberta_prelayernorm.py | 33 ++---- .../models/roc_bert/modeling_roc_bert.py | 33 ++---- .../models/roformer/modeling_roformer.py | 36 ++---- src/transformers/models/rwkv/modeling_rwkv.py | 14 +-- src/transformers/models/sam/modeling_sam.py | 11 +- .../models/sam_hq/modeling_sam_hq.py | 15 +-- .../models/sam_hq/modular_sam_hq.py | 12 +- .../seamless_m4t/modeling_seamless_m4t.py | 79 +++++-------- .../modeling_seamless_m4t_v2.py | 99 ++++++----------- .../models/seggpt/modeling_seggpt.py | 14 +-- src/transformers/models/sew/modeling_sew.py | 31 ++---- src/transformers/models/sew/modular_sew.py | 14 +-- .../models/sew_d/modeling_sew_d.py | 44 +++----- .../models/smolvlm/modeling_smolvlm.py | 21 ++-- .../speech_to_text/modeling_speech_to_text.py | 64 ++++------- .../models/speecht5/modeling_speecht5.py | 81 +++++--------- .../models/splinter/modeling_splinter.py | 33 ++---- .../models/stablelm/modeling_stablelm.py | 36 ++---- src/transformers/models/swin/modeling_swin.py | 19 +--- 
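The modeling_layers.py hunk below adds the shared base class that every converted layer inherits from. A self-contained toy sketch of the mechanism, under the assumption of a deliberately simplified re-implementation (ToyCheckpointingLayer and ToyDecoderLayer are illustrative names, not part of the patch): when checkpointing is active and the module is training, the layer call is rerouted through torch.utils.checkpoint and cache-related keyword arguments are neutralized, since caching cannot be combined with activation recomputation.

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyCheckpointingLayer(nn.Module):
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # caching is incompatible with recomputing activations, so the
            # cache-related kwargs are disabled before checkpointing
            if kwargs.get("use_cache"):
                kwargs["use_cache"] = False
            if kwargs.get("past_key_value") is not None:
                kwargs["past_key_value"] = None
            # keyword args ride along in the partial; positional args become
            # the inputs of the checkpoint call
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)


class ToyDecoderLayer(ToyCheckpointingLayer):
    def __init__(self, dim=16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
        return (self.proj(hidden_states),)


layer = ToyDecoderLayer().train()
layer.gradient_checkpointing = True
hidden = torch.randn(2, 4, 16, requires_grad=True)
(out,) = layer(hidden, use_cache=True)
out.sum().backward()  # the activations of self.proj are recomputed during backward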
.../models/swin2sr/modeling_swin2sr.py | 10 +- .../models/swinv2/modeling_swinv2.py | 20 ++-- .../modeling_switch_transformers.py | 54 +++------ src/transformers/models/t5/modeling_t5.py | 51 +++------ .../modeling_table_transformer.py | 31 ++---- .../models/tapas/modeling_tapas.py | 33 ++---- .../modeling_time_series_transformer.py | 67 ++++------- .../timesformer/modeling_timesformer.py | 12 +- .../models/trocr/modeling_trocr.py | 41 +++---- src/transformers/models/tvp/modeling_tvp.py | 14 +-- src/transformers/models/udop/modeling_udop.py | 13 ++- src/transformers/models/umt5/modeling_umt5.py | 44 +++----- .../models/unispeech/modeling_unispeech.py | 47 +++----- .../unispeech_sat/modeling_unispeech_sat.py | 47 +++----- .../models/videomae/modeling_videomae.py | 23 +--- src/transformers/models/vilt/modeling_vilt.py | 14 +-- .../visual_bert/modeling_visual_bert.py | 14 +-- src/transformers/models/vit/modeling_vit.py | 13 +-- .../models/vit_mae/modeling_vit_mae.py | 23 +--- .../models/vit_msn/modeling_vit_msn.py | 13 +-- .../models/vitdet/modeling_vitdet.py | 13 +-- .../modeling_vitpose_backbone.py | 14 +-- src/transformers/models/vits/modeling_vits.py | 24 ++-- .../models/vivit/modeling_vivit.py | 13 +-- .../models/wav2vec2/modeling_wav2vec2.py | 47 +++----- .../wav2vec2_bert/modeling_wav2vec2_bert.py | 27 ++--- .../wav2vec2_bert/modular_wav2vec2_bert.py | 27 ++--- .../modeling_wav2vec2_conformer.py | 38 ++----- .../modular_wav2vec2_conformer.py | 24 ++-- .../models/wavlm/modeling_wavlm.py | 63 ++++------- .../models/wavlm/modular_wavlm.py | 49 +++------ .../models/whisper/modeling_whisper.py | 65 ++++------- .../models/x_clip/modeling_x_clip.py | 47 +++----- src/transformers/models/xglm/modeling_xglm.py | 41 +++---- .../xlm_roberta/modeling_xlm_roberta.py | 33 ++---- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 33 ++---- src/transformers/models/xmod/modeling_xmod.py | 36 ++---- .../models/yolos/modeling_yolos.py | 13 +-- src/transformers/models/yoso/modeling_yoso.py | 13 +-- 224 files changed, 2513 insertions(+), 5280 deletions(-) diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index 57be2d8e0d7..5179cfa6571 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -16,6 +16,11 @@ from functools import partial import torch.nn as nn +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + class GradientCheckpointingLayer(nn.Module): """Base class for layers with gradient checkpointing. @@ -44,5 +49,35 @@ class GradientCheckpointingLayer(nn.Module): def __call__(self, *args, **kwargs): if self.gradient_checkpointing and self.training: + do_warn = False + layer_name = self.__class__.__name__ + message = f"Caching is incompatible with gradient checkpointing in {layer_name}. 
Setting" + + if "use_cache" in kwargs and kwargs["use_cache"]: + kwargs["use_cache"] = False + message += " `use_cache=False`," + do_warn = True + + # different names for the same thing in different layers + if "past_key_value" in kwargs and kwargs["past_key_value"] is not None: + kwargs["past_key_value"] = None + message += " `past_key_value=None`," + do_warn = True + + if "past_key_values" in kwargs and kwargs["past_key_values"] is not None: + kwargs["past_key_values"] = None + message += " `past_key_values=None`," + do_warn = True + + if "layer_past" in kwargs and kwargs["layer_past"] is not None: + kwargs["layer_past"] = None + message += " `layer_past=None`," + do_warn = True + + # warn if anything was changed + if do_warn: + message = message.rstrip(",") + "." + logger.warning(message) + return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args) return super().__call__(*args, **kwargs) diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 952fe0bdc9e..6ff99d6a491 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithNoAttention, BaseModelOutputWithPastAndCrossAttentions, @@ -827,7 +828,7 @@ class AlignTextOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText -class AlignTextLayer(nn.Module): +class AlignTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -953,27 +954,15 @@ class AlignTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 3e917940809..41d4c595c8d 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -23,6 +23,7 @@ import torch.nn as nn import torch.utils.checkpoint from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -418,7 +419,7 @@ class AltRobertaOutput(nn.Module): # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->AltRoberta -class AltRobertaLayer(nn.Module): +class 
AltRobertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -544,27 +545,15 @@ class AltRobertaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: @@ -732,7 +721,7 @@ class AltCLIPMLP(nn.Module): return hidden_states -class AltCLIPEncoderLayer(nn.Module): +class AltCLIPEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: AltCLIPConfig): super().__init__() self.embed_dim = config.hidden_size @@ -848,21 +837,12 @@ class AltCLIPEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index d3ccf24153b..602de3ff72b 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -282,7 +283,7 @@ class ASTOutput(nn.Module): # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->AST,VIT->AST -class ASTLayer(nn.Module): +class ASTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ASTConfig) -> None: @@ -349,16 +350,7 @@ class ASTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = 
self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) - + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 6db63b4945f..be8c8c621a1 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput @@ -670,7 +671,7 @@ class AutoformerAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -class AutoformerEncoderLayer(nn.Module): +class AutoformerEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: AutoformerConfig): super().__init__() self.embed_dim = config.d_model @@ -744,7 +745,7 @@ class AutoformerEncoderLayer(nn.Module): return outputs -class AutoformerDecoderLayer(nn.Module): +class AutoformerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: AutoformerConfig): super().__init__() self.embed_dim = config.d_model @@ -1042,21 +1043,12 @@ class AutoformerEncoder(AutoformerPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1186,6 +1178,12 @@ class AutoformerDecoder(AutoformerPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + input_shape = inputs_embeds.size()[:-1] # expand encoder attention mask @@ -1228,38 +1226,17 @@ class AutoformerDecoder(AutoformerPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
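Several of the rewritten call sites above keep encoder_hidden_states positional while moving everything else to keywords (the "# as a positional argument for gradient checkpointing" comments). The reason follows from the base class: keyword arguments are bound into a functools.partial, whereas positional arguments are handed to the checkpointing function as its actual inputs. A small standalone sketch of that call convention, where layer_forward and the tensor names are illustrative rather than taken from the patch:

from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def layer_forward(hidden_states, encoder_hidden_states, attention_mask=None, output_attentions=False):
    # stand-in for a decoder layer with cross-attention
    return hidden_states + encoder_hidden_states.mean()


hidden = torch.randn(2, 4, requires_grad=True)
encoder_out = torch.randn(2, 4, requires_grad=True)  # would come from the encoder
mask = torch.ones(2, 4)

# flags and masks are closed over in the partial; tensors that must receive
# gradients (hidden_states, encoder_hidden_states) stay positional so the
# checkpoint call sees them as inputs
fn = partial(layer_forward, attention_mask=mask, output_attentions=False)
out = checkpoint(fn, hidden, encoder_out, use_reentrant=False)
out.sum().backward()
print(hidden.grad.shape, encoder_out.grad.shape)  # both receive gradients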
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) (hidden_states, residual_trend) = layer_outputs[0] trend = trend + residual_trend diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 4ee608d9aec..8ace5221c08 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -31,6 +31,7 @@ from ...generation.logits_process import ( ) from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( @@ -309,7 +310,7 @@ class BarkMLP(nn.Module): return hidden_states -class BarkBlock(nn.Module): +class BarkBlock(GradientCheckpointingLayer): def __init__(self, config, is_causal=False): super().__init__() @@ -606,25 +607,14 @@ class BarkCausalModel(BarkPreTrainedModel, GenerationMixin): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - past_key_values=past_layer_key_values, - attention_mask=attention_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + past_key_values=past_layer_key_values, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index f0adc76924f..994bf9d85dc 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers 
import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -270,7 +271,7 @@ class BartAttention(nn.Module): return attn_output, attn_weights, past_key_value -class BartEncoderLayer(nn.Module): +class BartEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -341,7 +342,7 @@ class BartEncoderLayer(nn.Module): return outputs -class BartDecoderLayer(nn.Module): +class BartDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -875,21 +876,12 @@ class BartEncoder(BartPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1137,35 +1129,18 @@ class BartDecoder(BartPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 086e62561fd..347471fc7f7 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( 
BackboneOutput, BaseModelOutput, @@ -497,7 +498,7 @@ class BeitOutput(nn.Module): return hidden_states -class BeitLayer(nn.Module): +class BeitLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0) -> None: @@ -525,7 +526,7 @@ class BeitLayer(nn.Module): output_attentions: bool = False, relative_position_bias: Optional[torch.Tensor] = None, interpolate_pos_encoding: bool = False, - resolution: Optional[tuple[int]] = None, + resolution: Optional[tuple[int, int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention @@ -695,25 +696,14 @@ class BeitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) - else: - layer_outputs = layer_module( - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) + layer_outputs = layer_module( + hidden_states, + head_mask=layer_head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 12080dfff6f..e508a98614a 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -522,7 +523,7 @@ class BertOutput(nn.Module): return hidden_states -class BertLayer(nn.Module): +class BertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -647,27 +648,15 @@ class BertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + 
output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 959a3cce077..bd65e88ae85 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -275,7 +276,7 @@ class BertGenerationOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->BertGeneration -class BertGenerationLayer(nn.Module): +class BertGenerationLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -401,27 +402,15 @@ class BertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index e06c8f87d5f..a7b3c064703 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1419,7 +1420,7 @@ class BigBirdOutput(nn.Module): return hidden_states -class BigBirdLayer(nn.Module): +class BigBirdLayer(GradientCheckpointingLayer): def __init__(self, config, seed=None): super().__init__() self.config = config @@ -1593,35 +1594,19 @@ class BigBirdEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - band_mask, - from_mask, - to_mask, - 
blocked_encoder_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - band_mask, - from_mask, - to_mask, - blocked_encoder_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + band_mask, + from_mask, + to_mask, + blocked_encoder_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index bc72d16bf54..465b94e13be 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1333,7 +1334,7 @@ class BigBirdPegasusDecoderAttention(nn.Module): return attn_output, attn_weights, past_key_value -class BigBirdPegasusEncoderLayer(nn.Module): +class BigBirdPegasusEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BigBirdPegasusConfig, seed=None): super().__init__() self.attention_type = config.attention_type @@ -1420,7 +1421,7 @@ class BigBirdPegasusEncoderLayer(nn.Module): self.self_attn.set_attention_type(value) -class BigBirdPegasusDecoderLayer(nn.Module): +class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1947,31 +1948,17 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - band_mask, - from_mask, - to_mask, - blocked_encoder_mask, - blocked_encoder_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - band_mask=band_mask, - from_mask=from_mask, - to_mask=to_mask, - from_blocked_mask=blocked_encoder_mask, - to_blocked_mask=blocked_encoder_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + band_mask=band_mask, + from_mask=from_mask, + to_mask=to_mask, + from_blocked_mask=blocked_encoder_mask, + to_blocked_mask=blocked_encoder_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -2297,35 +2284,18 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] 
if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a1fba008841..8a0c43eafd3 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -20,7 +20,6 @@ # limitations under the License. import math -from functools import partial from typing import Callable, Optional, Union import torch @@ -32,6 +31,7 @@ from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -248,7 +248,7 @@ class BioGptAttention(nn.Module): return attn_output, attn_weights, past_key_value -class BioGptDecoderLayer(nn.Module): +class BioGptDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.hidden_size @@ -646,30 +646,17 @@ class BioGptModel(BioGptPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - position_ids, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_ids=position_ids, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/biogpt/modular_biogpt.py 
b/src/transformers/models/biogpt/modular_biogpt.py index d639f44ffec..938b1c9d8be 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -15,7 +15,6 @@ """PyTorch BioGPT model.""" import math -from functools import partial from typing import Optional, Union import torch @@ -473,30 +472,17 @@ class BioGptModel(BioGptPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - position_ids, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_ids=position_ids, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 6be0c98869a..7821e1c7b4f 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -270,7 +271,7 @@ class BlenderbotAttention(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT -class BlenderbotEncoderLayer(nn.Module): +class BlenderbotEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model @@ -339,7 +340,7 @@ class BlenderbotEncoderLayer(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT -class BlenderbotDecoderLayer(nn.Module): +class BlenderbotDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -825,21 +826,12 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask 
is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1090,35 +1082,18 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 99666356d09..550e5122192 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -254,7 +255,7 @@ class BlenderbotSmallAttention(nn.Module): # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL -class BlenderbotSmallEncoderLayer(nn.Module): +class BlenderbotSmallEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -326,7 +327,7 @@ class BlenderbotSmallEncoderLayer(nn.Module): # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL -class BlenderbotSmallDecoderLayer(nn.Module): +class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -812,21 +813,12 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else 
None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1073,35 +1065,18 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 3967f8bce47..e43b79595ca 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -552,7 +552,7 @@ class BlipEncoder(nn.Module): layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 4382296969e..04b20b7513a 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -531,7 +531,7 @@ class Blip2Encoder(nn.Module): layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) @@ -992,11 +992,11 @@ class Blip2QFormerEncoder(nn.Module): hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + query_length=query_length, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 66dfc0c1fa8..41442684451 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ 
b/src/transformers/models/bloom/modeling_bloom.py @@ -27,6 +27,7 @@ from torch.nn import functional as F from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -366,7 +367,7 @@ class BloomMLP(nn.Module): return output -class BloomBlock(nn.Module): +class BloomBlock(GradientCheckpointingLayer): def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None): super().__init__() hidden_size = config.hidden_size @@ -605,29 +606,16 @@ class BloomModel(BloomPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - past_key_values, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - cache_position=cache_position, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache: diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 10db0bbb62f..bb9ac6c1bd3 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN, QuickGELUActivation +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -662,7 +663,7 @@ class BridgeTowerBertCrossLayer(nn.Module): return layer_output -class BridgeTowerTextLayer(nn.Module): +class BridgeTowerTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -788,27 +789,15 @@ class BridgeTowerTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git 
a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index 33d1a44a2fe..aab9d62d646 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -428,7 +429,7 @@ class BrosOutput(nn.Module): return hidden_states -class BrosLayer(nn.Module): +class BrosLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -550,34 +551,16 @@ class BrosEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." - ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - bbox_pos_emb, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states=hidden_states, - bbox_pos_emb=bbox_pos_emb, - attention_mask=attention_mask, - head_mask=layer_head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + bbox_pos_emb, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 1b4a52295f2..3dff8f3b2cf 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -478,7 +479,7 @@ class CamembertOutput(nn.Module): # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->Camembert -class CamembertLayer(nn.Module): +class CamembertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -604,27 +605,15 @@ class CamembertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - 
layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index d55c600d05d..a5f5552b78e 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, ModelOutput, @@ -672,7 +673,7 @@ class CanineOutput(nn.Module): return hidden_states -class CanineLayer(nn.Module): +class CanineLayer(GradientCheckpointingLayer): def __init__( self, config, @@ -779,16 +780,7 @@ class CanineEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 25b9923620a..0aaf197dbe8 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -383,7 +384,7 @@ class ChameleonAttention(nn.Module): # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON -class ChameleonDecoderLayer(nn.Module): +class ChameleonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -458,7 +459,7 @@ class ChameleonDecoderLayer(nn.Module): return outputs -class ChameleonSwinDecoderLayer(nn.Module): +class ChameleonSwinDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ChameleonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1011,28 +1012,16 @@ class ChameleonModel(ChameleonPreTrainedModel): if output_hidden_states: 
all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index 5de98397cad..8d676fc5013 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -577,7 +578,7 @@ class ChineseCLIPVisionMLP(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ChineseCLIPText -class ChineseCLIPTextLayer(nn.Module): +class ChineseCLIPTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -663,7 +664,7 @@ class ChineseCLIPTextLayer(nn.Module): return layer_output -class ChineseCLIPVisionLayer(nn.Module): +class ChineseCLIPVisionLayer(GradientCheckpointingLayer): def __init__(self, config: ChineseCLIPConfig): super().__init__() self.embed_dim = config.hidden_size @@ -816,27 +817,15 @@ class ChineseCLIPTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: @@ -920,17 +909,10 @@ class ChineseCLIPVisionEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( 
+ hidden_states, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 6a44e36ade4..973ae11fd37 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPooling, @@ -691,7 +692,7 @@ class ClapAudioLayer(nn.Module): # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->ClapAudio -class ClapAudioStage(nn.Module): +class ClapAudioStage(GradientCheckpointingLayer): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config @@ -928,14 +929,9 @@ class ClapAudioEncoder(nn.Module): input_dimensions = self.input_resolutions[i] - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions - ) - else: - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition - ) + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) hidden_states = layer_outputs[0] @@ -1355,7 +1351,7 @@ class ClapTextOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->ClapText -class ClapTextLayer(nn.Module): +class ClapTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -1481,27 +1477,15 @@ class ClapTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index b93f63bcea9..3e8a898b35d 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import 
ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int @@ -393,7 +394,7 @@ class CLIPMLP(nn.Module): return hidden_states -class CLIPEncoderLayer(nn.Module): +class CLIPEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Union[CLIPVisionConfig, CLIPTextConfig]): super().__init__() self.embed_dim = config.hidden_size @@ -575,21 +576,12 @@ class CLIPEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index c68404cb66c..cff0471c813 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -374,7 +375,7 @@ class CLIPSegMLP(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg -class CLIPSegEncoderLayer(nn.Module): +class CLIPSegEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: CLIPSegConfig): super().__init__() self.embed_dim = config.hidden_size @@ -539,22 +540,12 @@ class CLIPSegEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) - + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 6a99d0fa390..d00528b1406 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils 
import PreTrainedModel from ...utils import ( @@ -245,7 +246,7 @@ class CodeGenMLP(nn.Module): # Copied from transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen -class CodeGenBlock(nn.Module): +class CodeGenBlock(GradientCheckpointingLayer): # Ignore copy def __init__(self, config, layer_idx=None): super().__init__() @@ -437,29 +438,16 @@ class CodeGenModel(CodeGenPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - causal_mask, - position_ids, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states=hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 19b1439302e..2042817a210 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_timm_available, logging, requires_backends @@ -827,7 +828,7 @@ class ConditionalDetrEncoderLayer(nn.Module): return outputs -class ConditionalDetrDecoderLayer(nn.Module): +class ConditionalDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: ConditionalDetrConfig): super().__init__() self.embed_dim = config.d_model @@ -1297,31 +1298,18 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): pos_transformation = self.query_scale(hidden_states) # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - None, - object_queries, - query_position_embeddings, - query_sine_embed, - encoder_hidden_states, - encoder_attention_mask, - None, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - query_sine_embed=query_sine_embed, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - is_first=(idx == 0), - ) + + layer_outputs = decoder_layer( + hidden_states, + None, # attention_mask + object_queries, + query_position_embeddings, + query_sine_embed, + encoder_hidden_states, # as a positional argument for gradient checkpointing + 
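# (Hedged note on the "as a positional argument for gradient checkpointing" comment above:
# keeping encoder_hidden_states positional presumably lets it reach the checkpointing
# wrapper as a positional input to torch.utils.checkpoint rather than being bound away
# as a keyword; the exact mechanics live in modeling_layers.py, which this section does
# not show.)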
encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + is_first=(idx == 0), + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index 1a443c575ab..bdac1fecc1c 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, MaskedLMOutput, @@ -532,7 +533,7 @@ class ConvBertOutput(nn.Module): return hidden_states -class ConvBertLayer(nn.Module): +class ConvBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -620,25 +621,14 @@ class ConvBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/dab_detr/modeling_dab_detr.py b/src/transformers/models/dab_detr/modeling_dab_detr.py index c977f4b923b..47b67f4f7c9 100644 --- a/src/transformers/models/dab_detr/modeling_dab_detr.py +++ b/src/transformers/models/dab_detr/modeling_dab_detr.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -702,7 +703,7 @@ class DabDetrDecoderLayerFFN(nn.Module): # Modified from transformers.models.detr.modeling_detr.DetrEncoderLayer with DetrEncoderLayer->DabDetrEncoderLayer,DetrConfig->DabDetrConfig -class DabDetrEncoderLayer(nn.Module): +class DabDetrEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: DabDetrConfig): super().__init__() self.hidden_size = config.hidden_size @@ -764,7 +765,7 @@ class DabDetrEncoderLayer(nn.Module): # Modified from transformers.models.conditional_detr.modeling_conditional_detr.ConditionalDetrDecoderLayer with ConditionalDetr->DabDetr -class DabDetrDecoderLayer(nn.Module): +class DabDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DabDetrConfig, is_first: bool = False): super().__init__() self.self_attn = DabDetrDecoderLayerSelfAttention(config) @@ -976,21 +977,12 @@ class DabDetrEncoder(DabDetrPreTrainedModel): # we add object_queries * pos_scaler as extra input to the encoder_layer scaled_object_queries = object_queries * pos_scales - if self.gradient_checkpointing and 
self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - scaled_object_queries, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - object_queries=scaled_object_queries, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + object_queries=scaled_object_queries, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1138,29 +1130,16 @@ class DabDetrDecoder(DabDetrPreTrainedModel): reference_anchor_size[..., 1] / obj_center[..., 3] ).unsqueeze(-1) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - None, - object_queries, - query_pos, - query_sine_embed, - encoder_hidden_states, - memory_key_padding_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_pos, - query_sine_embed=query_sine_embed, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=memory_key_padding_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + None, # attention_mask + object_queries, + query_pos, + query_sine_embed, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=memory_key_padding_mask, + output_attentions=output_attentions, + ) # iter update hidden_states = layer_outputs[0] diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 9c62b5fc8c1..e4ddac37541 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -33,6 +33,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -51,7 +52,7 @@ if is_torch_flex_attn_available(): from ...integrations.flex_attention import make_flex_block_causal_mask -class Data2VecAudioConvLayer(nn.Module): +class Data2VecAudioConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -155,13 +156,7 @@ class Data2VecAudioFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -357,7 +352,7 @@ class Data2VecAudioFeedForward(nn.Module): return hidden_states -class Data2VecAudioEncoderLayer(nn.Module): +class Data2VecAudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = Data2VecAudioAttention( @@ -441,17 +436,9 @@ class Data2VecAudioEncoder(nn.Module): skip_the_layer = True if self.training and 
(dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 0d7e8513490..97bca6d0d69 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -375,7 +376,7 @@ class Data2VecTextOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Data2VecText -class Data2VecTextLayer(nn.Module): +class Data2VecTextLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -501,27 +502,15 @@ class Data2VecTextEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index c48782d2477..381d354e3e9 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -497,7 +498,7 @@ class Data2VecVisionOutput(nn.Module): # Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision -class Data2VecVisionLayer(nn.Module): +class Data2VecVisionLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__( 
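A hedged aside on the pattern these hunks apply: the per-model `if self.gradient_checkpointing and self.training:` branch disappears because each layer class now inherits from GradientCheckpointingLayer. The real base class is defined in src/transformers/modeling_layers.py, which is not shown in this section, so the Python sketch below is only an approximation of the idea; the class name and attribute are illustrative, not the actual implementation.

from functools import partial

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayerSketch(nn.Module):
    # Toggled by the parent model when gradient checkpointing is enabled;
    # defaults to the ordinary (non-checkpointed) path.
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Bind keyword arguments into the callable and let positional
            # tensors flow through torch.utils.checkpoint, which frees the
            # layer's activations and recomputes them during backward.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)

With a base class along these lines, an encoder or decoder loop can call `layer(hidden_states, attention_mask, ...)` unconditionally, which is why the explicit `_gradient_checkpointing_func` branches are deleted throughout this patch; tensors that the recomputation depends on (for example encoder_hidden_states) are presumably passed positionally so they reach checkpoint as positional inputs.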
@@ -527,7 +528,7 @@ class Data2VecVisionLayer(nn.Module): output_attentions: bool = False, relative_position_bias: Optional[torch.Tensor] = None, interpolate_pos_encoding: bool = False, - resolution: Optional[tuple[int]] = None, + resolution: Optional[tuple[int, int]] = None, ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: self_attention_outputs = self.attention( self.layernorm_before(hidden_states), # in Data2VecVision, layernorm is applied before self-attention @@ -699,25 +700,14 @@ class Data2VecVisionEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) - else: - layer_outputs = layer_module( - hidden_states, - layer_head_mask, - output_attentions, - relative_position_bias, - interpolate_pos_encoding, - resolution, - ) + layer_outputs = layer_module( + hidden_states, + head_mask=layer_head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 0b4695c1e28..94a4d3e080a 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -20,6 +20,7 @@ import torch from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ..wav2vec2.modeling_wav2vec2 import ( @@ -38,7 +39,7 @@ from ..wav2vec2.modeling_wav2vec2 import ( from .configuration_data2vec_audio import Data2VecAudioConfig -class Data2VecAudioConvLayer(nn.Module): +class Data2VecAudioConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index f1fee43a239..e53ea75f26c 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging @@ -724,7 +725,7 @@ class DbrxFFN(nn.Module): return out, weights -class DbrxBlock(nn.Module): +class DbrxBlock(GradientCheckpointingLayer): def __init__(self, config: DbrxConfig, block_idx: int): super().__init__() self.hidden_size = config.d_model @@ -947,29 +948,16 @@ class DbrxModel(DbrxPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - block_outputs = self._gradient_checkpointing_func( - 
block.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - ) - else: - block_outputs = block( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - ) + block_outputs = block( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = block_outputs[0] diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index eef11f7ec34..c6dd97736c5 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -492,7 +493,7 @@ class DebertaOutput(nn.Module): return hidden_states -class DebertaLayer(nn.Module): +class DebertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = DebertaAttention(config) @@ -580,25 +581,14 @@ class DebertaEncoder(nn.Module): rel_embeddings = self.get_rel_embedding() for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - hidden_states, att_m = self._gradient_checkpointing_func( - layer_module.__call__, - next_kv, - attention_mask, - query_states, - relative_pos, - rel_embeddings, - output_attentions, - ) - else: - hidden_states, att_m = layer_module( - next_kv, - attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - ) + hidden_states, att_m = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 5073d1de526..9089fe1f650 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -418,7 +419,7 @@ class DebertaV2Output(nn.Module): # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 -class DebertaV2Layer(nn.Module): +class DebertaV2Layer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = DebertaV2Attention(config) @@ -655,25 +656,14 @@ class DebertaV2Encoder(nn.Module): next_kv = hidden_states rel_embeddings = self.get_rel_embedding() for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - 
output_states, attn_weights = self._gradient_checkpointing_func( - layer_module.__call__, - next_kv, - attention_mask, - query_states, - relative_pos, - rel_embeddings, - output_attentions, - ) - else: - output_states, attn_weights = layer_module( - next_kv, - attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - ) + output_states, attn_weights = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) if output_attentions: all_attentions = all_attentions + (attn_weights,) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index a436ce6d2c0..9184b11bfe6 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer @@ -360,7 +361,7 @@ class DecisionTransformerGPT2MLP(nn.Module): # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2 -class DecisionTransformerGPT2Block(nn.Module): +class DecisionTransformerGPT2Block(GradientCheckpointingLayer): # Ignore copy def __init__(self, config, layer_idx=None): super().__init__() @@ -654,31 +655,17 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - past_key_value=past_key_values, - cache_position=cache_position, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + past_key_values if not (self.gradient_checkpointing and self.training) else None, + cache_position, + attention_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index e36da6da89b..26298a0f6b2 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -27,6 +27,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import 
_prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid @@ -759,7 +760,7 @@ class DeformableDetrMultiheadAttention(nn.Module): return attn_output, attn_weights_reshaped -class DeformableDetrEncoderLayer(nn.Module): +class DeformableDetrEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: DeformableDetrConfig): super().__init__() self.embed_dim = config.d_model @@ -848,7 +849,7 @@ class DeformableDetrEncoderLayer(nn.Module): return outputs -class DeformableDetrDecoderLayer(nn.Module): +class DeformableDetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DeformableDetrConfig): super().__init__() self.embed_dim = config.d_model @@ -1126,29 +1127,16 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel): for i, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - position_embeddings, - reference_points, - spatial_shapes, - spatial_shapes_list, - level_start_index, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - position_embeddings=position_embeddings, - reference_points=reference_points, - spatial_shapes=spatial_shapes, - spatial_shapes_list=spatial_shapes_list, - level_start_index=level_start_index, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1273,31 +1261,17 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings, - reference_points_input, - spatial_shapes, - spatial_shapes_list, - level_start_index, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - encoder_hidden_states=encoder_hidden_states, - reference_points=reference_points_input, - spatial_shapes=spatial_shapes, - spatial_shapes_list=spatial_shapes_list, - level_start_index=level_start_index, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings, + reference_points_input, + spatial_shapes, + spatial_shapes_list, + level_start_index, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask, + output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 6b6284995fc..4250c1180bc 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from 
...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -347,7 +348,7 @@ class DeiTOutput(nn.Module): # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->DeiT,VIT->DEIT -class DeiTLayer(nn.Module): +class DeiTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: DeiTConfig) -> None: @@ -414,15 +415,7 @@ class DeiTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/deta/modeling_deta.py b/src/transformers/models/deprecated/deta/modeling_deta.py index ae44b5e1b12..64912b512ff 100644 --- a/src/transformers/models/deprecated/deta/modeling_deta.py +++ b/src/transformers/models/deprecated/deta/modeling_deta.py @@ -39,6 +39,7 @@ from ....file_utils import ( replace_return_docstrings, ) from ....modeling_attn_mask_utils import _prepare_4d_attention_mask +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput from ....modeling_utils import PreTrainedModel from ....pytorch_utils import meshgrid @@ -909,7 +910,7 @@ class DetaEncoderLayer(nn.Module): return outputs -class DetaDecoderLayer(nn.Module): +class DetaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetaConfig): super().__init__() self.embed_dim = config.d_model @@ -1341,29 +1342,16 @@ class DetaDecoder(DetaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_embeddings, - reference_points_input, - spatial_shapes, - level_start_index, - encoder_hidden_states, - encoder_attention_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - encoder_hidden_states=encoder_hidden_states, - reference_points=reference_points_input, - spatial_shapes=spatial_shapes, - level_start_index=level_start_index, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + encoder_hidden_states=encoder_hidden_states, + reference_points=reference_points_input, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index e4852ff78f8..7bc835cf13d 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -26,6 +26,7 @@ from ....file_utils import add_code_sample_docstrings, add_start_docstrings, add from ....integrations.deepspeed import is_deepspeed_zero3_enabled from 
....integrations.fsdp import is_fsdp_managed_module from ....modeling_attn_mask_utils import _prepare_4d_attention_mask +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, CausalLMOutput from ....modeling_utils import ( PreTrainedModel, @@ -377,7 +378,7 @@ class MCTCTOutput(nn.Module): return hidden_states -class MCTCTLayer(nn.Module): +class MCTCTLayer(GradientCheckpointingLayer): def __init__(self, config: MCTCTConfig): super().__init__() @@ -591,20 +592,11 @@ class MCTCTEncoder(MCTCTPreTrainedModel): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index d1c3fd8dbaa..2ef4a560952 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -438,7 +439,7 @@ class NezhaOutput(nn.Module): return hidden_states -class NezhaLayer(nn.Module): +class NezhaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -563,27 +564,15 @@ class NezhaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 100d02e2285..848f3f971e0 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import 
ACT2FN from ....modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings @@ -339,7 +340,7 @@ class OpenLlamaAttention(nn.Module): return attn_output, attn_weights, past_key_value -class OpenLlamaDecoderLayer(nn.Module): +class OpenLlamaDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OpenLlamaConfig): super().__init__() self.hidden_size = config.hidden_size @@ -631,25 +632,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - None, - output_attentions, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index f30a0757009..df3fce3b520 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -452,7 +453,7 @@ class QDQBertOutput(nn.Module): # Based on transformers.models.bert.modeling_bert.BertLayer with Bert -> QDQBert -class QDQBertLayer(nn.Module): +class QDQBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.seq_len_dim = 1 @@ -568,32 +569,15 @@ class QDQBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index 5714bf52a0e..e88a75bd1bf 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -447,7 +448,7 @@ class RealmOutput(nn.Module): return hidden_states -class RealmLayer(nn.Module): +class RealmLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -572,27 +573,15 @@ class RealmEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index ce4fdd1bb2a..0599c3b592f 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ....activations import ACT2FN from ....modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, logging, replace_return_docstrings @@ -263,7 +264,7 @@ class Speech2Text2Attention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -class Speech2Text2DecoderLayer(nn.Module): +class Speech2Text2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: 
Speech2Text2Config): super().__init__() self.embed_dim = config.d_model @@ -612,31 +613,17 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index a06a52d9f3c..fdfbecf7fe2 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import functional as F +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_utils import PreTrainedModel from ....utils import ( ModelOutput, @@ -346,7 +347,7 @@ class CausalSelfAttention(nn.Module): return outputs -class Block(nn.Module): +class Block(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.ln1 = nn.LayerNorm(config.n_embd) @@ -540,16 +541,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - layer_past, - use_cache, - output_attentions, - ) - else: - outputs = block(hidden_states, layer_past, use_cache, output_attentions) + outputs = block(hidden_states, layer_past, use_cache, output_attentions) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py index d0083211bd3..5280248c59e 100644 --- a/src/transformers/models/deprecated/tvlt/modeling_tvlt.py +++ b/src/transformers/models/deprecated/tvlt/modeling_tvlt.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs 
import BaseModelOutput, SequenceClassifierOutput from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -483,7 +484,7 @@ class TvltOutput(nn.Module): return hidden_states -class TvltLayer(nn.Module): +class TvltLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config): @@ -546,16 +547,7 @@ class TvltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -853,15 +845,7 @@ class TvltDecoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - None, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py index 553b0a7bb3b..03bcc24beb6 100644 --- a/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/deprecated/vit_hybrid/modeling_vit_hybrid.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ....activations import ACT2FN +from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -390,7 +391,7 @@ VIT_HYBRID_ATTENTION_CLASSES = { } -class ViTHybridLayer(nn.Module): +class ViTHybridLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTHybridConfig) -> None: @@ -457,15 +458,7 @@ class ViTHybridEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index e8b23b961e5..57f2c2610e8 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from torch.nn import LayerNorm from ....activations import ACT2FN +from 
....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput from ....modeling_utils import PreTrainedModel from ....utils import ( @@ -1090,7 +1091,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module): return predict_relative_pos_embeddings -class XLMProphetNetEncoderLayer(nn.Module): +class XLMProphetNetEncoderLayer(GradientCheckpointingLayer): """ Encoder block for XLMProphetnet """ @@ -1133,7 +1134,7 @@ class XLMProphetNetEncoderLayer(nn.Module): return outputs -class XLMProphetNetDecoderLayer(nn.Module): +class XLMProphetNetDecoderLayer(GradientCheckpointingLayer): """ Decoder block for XLMProphetnet """ @@ -1320,21 +1321,12 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel): if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - extended_attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1554,41 +1546,21 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - extended_attention_mask, - encoder_hidden_states, - extended_encoder_attention_mask, - (head_mask[idx] if head_mask is not None else None), - (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - extended_predict_attention_mask, - main_relative_position_buckets, - predict_relative_position_buckets, - position_ids, - None, - use_cache, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attn_mask=extended_encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - extended_predict_attention_mask=extended_predict_attention_mask, - main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + 
past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 9f8ea167ab7..e52ab48cdf2 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -677,7 +678,7 @@ class DetrEncoderLayer(nn.Module): return outputs -class DetrDecoderLayer(nn.Module): +class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() self.embed_dim = config.d_model @@ -1045,25 +1046,15 @@ class DetrDecoder(DetrPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - combined_attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + combined_attention_mask, + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index bd35abb941a..7b023242a1c 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -382,7 +383,7 @@ class Dinov2SwiGLUFFN(nn.Module): return self.weights_out(hidden) -class Dinov2Layer(nn.Module): +class Dinov2Layer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: Dinov2Config) -> None: @@ -458,15 +459,7 @@ class Dinov2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] 
diff --git a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py index c2eeb197021..adbb34c2fd4 100644 --- a/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +++ b/src/transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py @@ -28,6 +28,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -399,7 +400,7 @@ class Dinov2WithRegistersSwiGLUFFN(nn.Module): return self.weights_out(hidden) -class Dinov2WithRegistersLayer(nn.Module): +class Dinov2WithRegistersLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__(self, config: Dinov2WithRegistersConfig) -> None: @@ -476,15 +477,7 @@ class Dinov2WithRegistersEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 28cec74fb3d..30ccc04c1f5 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -31,6 +31,7 @@ from ...configuration_utils import PretrainedConfig from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -441,7 +442,7 @@ DISTILBERT_ATTENTION_CLASSES = { } -class TransformerBlock(nn.Module): +class TransformerBlock(GradientCheckpointingLayer): def __init__(self, config: PretrainedConfig): super().__init__() @@ -537,21 +538,12 @@ class Transformer(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_state,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_state, - attn_mask, - head_mask[i], - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_state, - attn_mask, - head_mask[i], - output_attentions, - ) + layer_outputs = layer_module( + hidden_state, + attn_mask, + head_mask[i], + output_attentions, + ) hidden_state = layer_outputs[-1] diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index a63b0d3f0f5..603acec7782 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -27,6 +27,7 @@ import 
torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -706,7 +707,7 @@ class DonutSwinLayer(nn.Module): # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin -class DonutSwinStage(nn.Module): +class DonutSwinStage(GradientCheckpointingLayer): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config @@ -816,19 +817,9 @@ class DonutSwinEncoder(nn.Module): for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - always_partition, - ) - else: - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition - ) + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index d3c1703b9ee..82ff615afc2 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -29,6 +29,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -469,7 +470,7 @@ class DPTViTOutput(nn.Module): # copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput -class DPTViTLayer(nn.Module): +class DPTViTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: DPTConfig) -> None: @@ -536,15 +537,7 @@ class DPTViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index dfe5849a5a7..81eb2d894d6 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from 
...modeling_outputs import ( BaseModelOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions, @@ -436,7 +437,7 @@ class ElectraOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Electra -class ElectraLayer(nn.Module): +class ElectraLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -562,27 +563,15 @@ class ElectraEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 21a79a3fabc..dda93fb81c9 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -361,7 +362,7 @@ class ErnieOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Ernie -class ErnieLayer(nn.Module): +class ErnieLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -487,27 +488,15 @@ class ErnieEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 10db78a67cb..953a024a823 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ 
b/src/transformers/models/esm/modeling_esm.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -599,7 +600,7 @@ class EsmOutput(nn.Module): return hidden_states -class EsmLayer(nn.Module): +class EsmLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -725,27 +726,15 @@ class EsmEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index d6634662f30..d5924567922 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -556,7 +557,7 @@ FALCON_ATTENTION_CLASSES = { } -class FalconDecoderLayer(nn.Module): +class FalconDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: FalconConfig, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -836,33 +837,18 @@ class FalconModel(FalconPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - position_ids, - head_mask[i], - past_key_values, - use_cache, - output_attentions, - cache_position, - position_embeddings, - ) - else: - outputs = block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - alibi=alibi, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + alibi=alibi, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = outputs[0] if use_cache 
is True: diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py index 257f2c50cd9..2df27e390ea 100644 --- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py +++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import MambaCache from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available @@ -405,7 +406,7 @@ class FalconMambaRMSNorm(nn.Module): # Copied from transformers.models.mamba.modeling_mamba.MambaBlock with Mamba->FalconMamba,FalconMambaCache->MambaCache -class FalconMambaBlock(nn.Module): +class FalconMambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.config = config @@ -620,17 +621,12 @@ class FalconMambaModel(FalconMambaPreTrainedModel): hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 3bd7b45d0dc..1d1ce04c356 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int @@ -577,7 +578,7 @@ class FlavaOutput(nn.Module): return hidden_states -class FlavaLayer(nn.Module): +class FlavaLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: FlavaPossibleConfigs) -> None: @@ -648,16 +649,7 @@ class FlavaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 619d6c9c5ad..fed31339da7 100755 --- 
a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -31,6 +31,7 @@ if is_scipy_available(): from scipy import linalg from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -235,7 +236,7 @@ class FNetOutput(nn.Module): return hidden_states -class FNetLayer(nn.Module): +class FNetLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -276,10 +277,7 @@ class FNetEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states) - else: - layer_outputs = layer_module(hidden_states) + layer_outputs = layer_module(hidden_states) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 232f1e6ed1f..47fa9d4f2eb 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -455,7 +456,7 @@ class FocalNetLayer(nn.Module): return hidden_state -class FocalNetStage(nn.Module): +class FocalNetStage(GradientCheckpointingLayer): def __init__(self, config, index, input_resolution): super().__init__() @@ -560,14 +561,7 @@ class FocalNetEncoder(nn.Module): all_reshaped_hidden_states += (reshaped_hidden_state,) for i, stage_module in enumerate(self.stages): - if self.gradient_checkpointing and self.training: - stage_outputs = self._gradient_checkpointing_func( - stage_module.__call__, - hidden_states, - input_dimensions, - ) - else: - stage_outputs = stage_module(hidden_states, input_dimensions) + stage_outputs = stage_module(hidden_states, input_dimensions) hidden_states = stage_outputs[0] hidden_states_before_downsampling = stage_outputs[1] diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 28a57ab9090..7008538c7ab 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial from typing import Callable, Optional, Union import torch @@ -30,6 +29,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -238,7 +238,7 @@ class Gemma2Attention(nn.Module): return attn_output, attn_weights -class Gemma2DecoderLayer(nn.Module): +class Gemma2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -466,30 +466,17 @@ class Gemma2Model(Gemma2PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 9890711d3c1..b317936c776 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial from typing import Callable, Optional, Union import torch @@ -25,6 +24,7 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack @@ -303,7 +303,7 @@ class Gemma2Attention(GemmaAttention): return attn_output, attn_weights -class Gemma2DecoderLayer(nn.Module): +class Gemma2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -449,30 +449,17 @@ class Gemma2Model(GemmaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings=position_embeddings, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 16d4673053e..bed1d5310af 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -22,7 +22,6 @@ import copy from collections.abc import Callable from dataclasses import dataclass -from functools import partial from typing import Optional, Union import torch @@ -34,6 +33,7 @@ from ...configuration_utils import PretrainedConfig from ...generation import GenerationMixin from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -364,7 +364,7 @@ class Gemma3Attention(nn.Module): return attn_output, attn_weights -class Gemma3DecoderLayer(nn.Module): +class Gemma3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() self.config = config @@ -581,32 +581,18 @@ class Gemma3TextModel(Gemma3PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if 
self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings_global, - position_embeddings_local, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index d93d53c8e93..bc1db4b50a4 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -16,7 +16,6 @@ import copy from collections.abc import Callable from dataclasses import dataclass -from functools import partial from typing import Any, Optional, Union import torch @@ -27,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig, layer_type_validation from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -443,7 +443,7 @@ class Gemma3Attention(Gemma2Attention): return attn_output, attn_weights -class Gemma3DecoderLayer(nn.Module): +class Gemma3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Gemma3TextConfig, layer_idx: int): super().__init__() self.config = config @@ -632,32 +632,18 @@ class Gemma3TextModel(Gemma2Model): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - position_embeddings_global, - position_embeddings_local, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + 
position_embeddings_global=position_embeddings_global, + position_embeddings_local=position_embeddings_local, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 6068ce169da..8058c542e9d 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, @@ -343,7 +344,7 @@ class GitOutput(nn.Module): return hidden_states -class GitLayer(nn.Module): +class GitLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -441,24 +442,14 @@ class GitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - past_key_values, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - past_key_values, - output_attentions, - pixel_values_present, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + past_key_values, + output_attentions, + pixel_values_present, + ) hidden_states = layer_outputs[0] if use_cache: @@ -723,7 +714,7 @@ class GitVisionAttention(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision -class GitVisionEncoderLayer(nn.Module): +class GitVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: GitVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -840,21 +831,12 @@ class GitVisionEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 6527d3efbbc..12a1d58151f 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -31,6 +31,7 @@ import torch.nn.functional as F from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import 
GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack @@ -192,7 +193,7 @@ class GotOcr2VisionAttention(nn.Module): return outputs -class GotOcr2VisionLayer(nn.Module): +class GotOcr2VisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -463,13 +464,7 @@ class GotOcr2VisionEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index fa98bc3614e..b88290343b6 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN, get_activation from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -368,7 +369,7 @@ class GPT2MLP(nn.Module): return hidden_states -class GPT2Block(nn.Module): +class GPT2Block(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -922,32 +923,18 @@ class GPT2Model(GPT2PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - past_key_values, - cache_position, - causal_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - past_key_value=past_key_values, - cache_position=cache_position, - attention_mask=causal_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - **kwargs, - ) + outputs = block( + hidden_states, + past_key_values if not (self.gradient_checkpointing and self.training) else None, + cache_position, + causal_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + **kwargs, + ) hidden_states = outputs[0] diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b90fdfe8acc..725ddbabf7a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import 
AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -558,7 +559,7 @@ GPTBIGCODE_ATTENTION_CLASSES = { } -class GPTBigCodeBlock(nn.Module): +class GPTBigCodeBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -759,6 +760,12 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: @@ -891,29 +898,16 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + layer_past, + attention_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] if use_cache: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 8ac65c7d1ae..25542ac3129 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, @@ -431,7 +432,7 @@ class GPTNeoMLP(nn.Module): return hidden_states -class GPTNeoBlock(nn.Module): +class GPTNeoBlock(GradientCheckpointingLayer): def __init__(self, config, layer_id=None): super().__init__() hidden_size = config.hidden_size @@ -635,27 +636,15 @@ class GPTNeoModel(GPTNeoPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - causal_mask, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs 
= block( - hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 6b08c27a306..d3c5141371b 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -14,6 +14,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -190,7 +191,7 @@ class GPTNeoXAttention(nn.Module): return attn_output, attn_weights -class GPTNeoXLayer(nn.Module): +class GPTNeoXLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -415,32 +416,18 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - causal_mask, - position_ids, - head_mask[i], - use_cache, - past_key_values, - output_attentions, - cache_position, - position_embeddings, - ) - else: - outputs = layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - layer_past=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + outputs = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + layer_past=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = outputs[0] if output_attentions: diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 4922a4e3b4c..fde2677b4e2 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -9,6 +9,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -177,7 +178,7 @@ class GPTNeoXAttention(nn.Module): return attn_output, attn_weights -class GPTNeoXLayer(nn.Module): +class GPTNeoXLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -362,32 +363,18 @@ class GPTNeoXModel(LlamaModel, nn.Module): if output_hidden_states: all_hidden_states = 
all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - causal_mask, - position_ids, - head_mask[i], - use_cache, - past_key_values, - output_attentions, - cache_position, - position_embeddings, - ) - else: - outputs = layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - layer_past=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + outputs = layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + layer_past=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = outputs[0] if output_attentions: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 4388fad01f6..a8504db42c8 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -434,7 +435,7 @@ class GPTJMLP(nn.Module): return hidden_states -class GPTJBlock(nn.Module): +class GPTJBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd @@ -738,29 +739,16 @@ class GPTJModel(GPTJPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - causal_mask, - position_ids, - head_mask[i], - use_cache, - output_attentions, - cache_position, - ) - else: - outputs = block( - hidden_states=hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, - position_ids=position_ids, - head_mask=head_mask[i], - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + outputs = block( + hidden_states, + layer_past=past_key_values, + attention_mask=causal_mask, + position_ids=position_ids, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index 100fd2dd85f..362d170ffa8 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from 
...utils import ModelOutput, auto_docstring, logging, torch_int @@ -692,7 +693,7 @@ class GroupViTAttention(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GroupViT -class GroupViTEncoderLayer(nn.Module): +class GroupViTEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: GroupViTConfig): super().__init__() self.embed_dim = config.hidden_size @@ -906,21 +907,12 @@ class GroupViTTextEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/hiera/modeling_hiera.py b/src/transformers/models/hiera/modeling_hiera.py index 2fadde33211..e086b432ba5 100644 --- a/src/transformers/models/hiera/modeling_hiera.py +++ b/src/transformers/models/hiera/modeling_hiera.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BackboneOutput, BaseModelOutput, @@ -540,7 +541,7 @@ class HieraLayer(nn.Module): return (hidden_states, attn_weights) -class HieraStage(nn.Module): +class HieraStage(GradientCheckpointingLayer): def __init__( self, config, @@ -734,12 +735,7 @@ class HieraEncoder(nn.Module): for i, stage_module in enumerate(self.stages): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - stage_module.__call__, hidden_states, layer_head_mask, output_attentions - ) - else: - layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index faa0ff48c68..0fab4184bfe 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -32,6 +32,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -107,7 +108,7 @@ class HubertSamePadLayer(nn.Module): return hidden_states -class HubertNoLayerNormConvLayer(nn.Module): +class HubertNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ 
-128,7 +129,7 @@ class HubertNoLayerNormConvLayer(nn.Module): return hidden_states -class HubertLayerNormConvLayer(nn.Module): +class HubertLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -155,7 +156,7 @@ class HubertLayerNormConvLayer(nn.Module): return hidden_states -class HubertGroupNormConvLayer(nn.Module): +class HubertGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -212,13 +213,7 @@ class HubertFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -417,7 +412,7 @@ class HubertFeedForward(nn.Module): return hidden_states -class HubertEncoderLayer(nn.Module): +class HubertEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = HubertAttention( @@ -501,17 +496,9 @@ class HubertEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -579,7 +566,7 @@ class HubertAttnAdapterLayer(nn.Module): return hidden_states -class HubertEncoderLayerStableLayerNorm(nn.Module): +class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = HubertAttention( @@ -675,17 +662,9 @@ class HubertEncoderStableLayerNorm(nn.Module): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index d3bba25a564..a5a868072a6 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import 
FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel from ...processing_utils import Unpack @@ -668,7 +669,7 @@ class IdeficsAttention(nn.Module): # this was adapted from LlamaDecoderLayer -class IdeficsDecoderLayer(nn.Module): +class IdeficsDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -749,7 +750,7 @@ class IdeficsDecoderLayer(nn.Module): return outputs -class IdeficsGatedCrossAttentionLayer(nn.Module): +class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): def __init__(self, config: IdeficsConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size @@ -1185,95 +1186,32 @@ class IdeficsModel(IdeficsPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - def vblock( - main_block, - hidden_states, - attention_mask, - position_ids, - past_key_value, - image_hidden_states, - image_attention_mask, - cross_attention_gate, - output_attentions, - use_cache, - layer_idx, - cross_layer_interval, - gated_cross_attn_layers, - cache_position, - ): - # TODO(ls): Add cross attention values to respective lists - if layer_idx % cross_layer_interval == 0: - xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval] - outputs = xblock( - hidden_states, - attention_mask=attention_mask, - image_hidden_states=image_hidden_states, - image_attention_mask=image_attention_mask, - cross_attention_gate=cross_attention_gate, - output_attentions=output_attentions, - use_cache=use_cache, - past_key_value=None, # not implemented - **kwargs, - ) - hidden_states = outputs[0] - - layer_outputs = main_block( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - return layer_outputs - - if self.gradient_checkpointing and self.training: - past_key_values = None - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - layer_outputs = self._gradient_checkpointing_func( - vblock, - decoder_layer, + # TODO(ls): Add cross attention values to respective lists + if idx % self.cross_layer_interval == 0: + cross_attn_block = self.gated_cross_attn_layers[idx // self.cross_layer_interval] + outputs = cross_attn_block( hidden_states, attention_mask, - position_ids, - past_key_values, image_hidden_states, - image_attention_mask, - cross_attention_gate, - output_attentions, - use_cache, - idx, - self.cross_layer_interval, - self.gated_cross_attn_layers, - cache_position, - ) - else: - layer_outputs = vblock( - decoder_layer, - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - image_hidden_states=image_hidden_states, image_attention_mask=image_attention_mask, cross_attention_gate=cross_attention_gate, output_attentions=output_attentions, use_cache=use_cache, - layer_idx=idx, - cross_layer_interval=self.cross_layer_interval, - gated_cross_attn_layers=self.gated_cross_attn_layers, - cache_position=cache_position, + past_key_value=None, # not implemented **kwargs, ) + hidden_states = outputs[0] + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 815b902d3fb..d75d61545ec 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( @@ -283,7 +284,7 @@ class IdeficsVisionMLP(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->IdeficsVision -class IdeficsVisionEncoderLayer(nn.Module): +class IdeficsVisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: IdeficsVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -400,21 +401,12 @@ class IdeficsVisionEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index d9b5d5e6833..1f3f96de630 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import 
_prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -339,7 +340,7 @@ class Idefics2MultiheadAttentionPoolingHead(nn.Module): return hidden_state[:, 0] -class Idefics2EncoderLayer(nn.Module): +class Idefics2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Idefics2VisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -448,19 +449,11 @@ class Idefics2Encoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 53b3cc2e304..56750bc5298 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -26,6 +26,7 @@ from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -300,7 +301,7 @@ class Idefics3SimpleMLP(nn.Module): # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2EncoderLayer with Idefics2->Idefics3 -class Idefics3EncoderLayer(nn.Module): +class Idefics3EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Idefics3VisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -409,19 +410,11 @@ class Idefics3Encoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/ijepa/modeling_ijepa.py b/src/transformers/models/ijepa/modeling_ijepa.py index c7ce7c29b4c..5568b4ebcc8 100644 --- a/src/transformers/models/ijepa/modeling_ijepa.py +++ b/src/transformers/models/ijepa/modeling_ijepa.py @@ -12,6 +12,7 @@ import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils 
import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -357,7 +358,7 @@ class IJepaOutput(nn.Module): return hidden_states -class IJepaLayer(nn.Module): +class IJepaLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: IJepaConfig) -> None: @@ -423,15 +424,7 @@ class IJepaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 65d5cbc3df2..d076a193162 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -401,7 +402,7 @@ class ImageGPTMLP(nn.Module): return hidden_states -class ImageGPTBlock(nn.Module): +class ImageGPTBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size @@ -719,29 +720,16 @@ class ImageGPTModel(ImageGPTPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) + outputs = block( + hidden_states, + layer_past, + attention_mask, + head_mask[i], + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 98859e7534c..9718e8fb736 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -744,7 +745,7 @@ class InformerProbSparseAttention(nn.Module): # source: 
https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py -class InformerConvLayer(nn.Module): +class InformerConvLayer(GradientCheckpointingLayer): def __init__(self, c_in): super().__init__() self.downConv = nn.Conv1d( @@ -767,7 +768,7 @@ class InformerConvLayer(nn.Module): return x -class InformerEncoderLayer(nn.Module): +class InformerEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: InformerConfig): super().__init__() self.embed_dim = config.d_model @@ -845,7 +846,7 @@ class InformerEncoderLayer(nn.Module): return outputs -class InformerDecoderLayer(nn.Module): +class InformerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1086,27 +1087,15 @@ class InformerEncoder(InformerPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - if conv_layer is not None: - output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - if conv_layer is not None: - output = conv_layer(layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] hidden_states = layer_outputs[0] @@ -1299,35 +1288,18 @@ class InformerDecoder(InformerPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git 
a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 755fcd68853..3d46275bdc8 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, ) @@ -433,7 +434,7 @@ class InformerProbSparseAttention(nn.Module): # source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py -class InformerConvLayer(nn.Module): +class InformerConvLayer(GradientCheckpointingLayer): def __init__(self, c_in): super().__init__() self.downConv = nn.Conv1d( @@ -610,27 +611,15 @@ class InformerEncoder(TimeSeriesTransformerEncoder): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - if conv_layer is not None: - output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) - if conv_layer is not None: - output = conv_layer(layer_outputs[0]) - layer_outputs = (output,) + layer_outputs[1:] + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] hidden_states = layer_outputs[0] diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 1708a86082b..bf2c76cf9e5 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -427,7 +427,7 @@ class InstructBlipEncoder(nn.Module): layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) @@ -889,11 +889,11 @@ class InstructBlipQFormerEncoder(nn.Module): hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + query_length=query_length, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 8c0d5c05a3b..ee9cffd4f2e 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -356,7 +356,7 @@ class InstructBlipVideoEncoder(nn.Module): layer_outputs = encoder_layer( hidden_states, - attention_mask, + attention_mask=attention_mask, output_attentions=output_attentions, ) @@ -750,11 +750,11 @@ 
class InstructBlipVideoQFormerEncoder(nn.Module): hidden_states, attention_mask, layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + query_length=query_length, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 571104934ef..65b2952b39d 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -31,6 +31,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -383,7 +384,7 @@ class InternVLVisionMLP(nn.Module): NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm} -class InternVLVisionLayer(nn.Module): +class InternVLVisionLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: InternVLVisionConfig) -> None: @@ -452,12 +453,7 @@ class InternVLVisionEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, output_attentions - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 90576676b3c..a71b9fbdad8 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -24,6 +24,7 @@ import torch.utils.checkpoint from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -334,7 +335,7 @@ class InternVLVisionMLP(CLIPMLP): NORM2FN = {"layer_norm": nn.LayerNorm, "rms_norm": InternVLVisionRMSNorm} -class InternVLVisionLayer(nn.Module): +class InternVLVisionLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: InternVLVisionConfig) -> None: @@ -403,12 +404,7 @@ class InternVLVisionEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, output_attentions - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions) hidden_states 
= layer_outputs[0] diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 4b6fefbf9e9..ce17baf3281 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel @@ -763,7 +764,7 @@ JETMOE_ATTENTION_CLASSES = { } -class JetMoeBlock(nn.Module): +class JetMoeBlock(GradientCheckpointingLayer): def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): """ Initialize the JetMoeBlock module. @@ -967,28 +968,15 @@ class JetMoeModel(JetMoePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - position_ids, - past_key_values, - causal_mask, - output_attentions, - output_router_logits, - use_cache, - use_reentrant=False, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 34dc0848b79..a78ad47f2dd 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -404,7 +405,7 @@ class Kosmos2VisionMLP(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision -class Kosmos2VisionEncoderLayer(nn.Module): +class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -521,21 +522,12 @@ class Kosmos2VisionEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - 
output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -840,7 +832,7 @@ class Kosmos2TextFFN(nn.Module): return hidden_states -class Kosmos2TextBlock(nn.Module): +class Kosmos2TextBlock(GradientCheckpointingLayer): def __init__(self, config: Kosmos2TextConfig): super().__init__() self.embed_dim = config.embed_dim @@ -1138,34 +1130,18 @@ class Kosmos2TextTransformer(nn.Module): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index e4fb25523a7..372e4b89e07 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -358,7 +359,7 @@ class LayoutLMOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM -class LayoutLMLayer(nn.Module): +class LayoutLMLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -484,27 +485,15 @@ class LayoutLMEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - 
) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 7a82375d1ff..fdaa37b9e50 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -261,7 +262,7 @@ class LayoutLMv2Output(nn.Module): return hidden_states -class LayoutLMv2Layer(nn.Module): +class LayoutLMv2Layer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -436,25 +437,14 @@ class LayoutLMv2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index 83f87ec5281..1b6398a382d 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, QuestionAnsweringModelOutput, @@ -358,7 +359,7 @@ class LayoutLMv3Attention(nn.Module): # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3 -class LayoutLMv3Layer(nn.Module): +class LayoutLMv3Layer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -514,25 +515,14 @@ class LayoutLMv3Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos, - rel_2d_pos, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - rel_pos=rel_pos, - rel_2d_pos=rel_2d_pos, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + rel_pos=rel_pos, + rel_2d_pos=rel_2d_pos, 
+ ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index bc573861321..ad095cbcd47 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -900,7 +901,7 @@ class LEDDecoderAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -class LEDEncoderLayer(nn.Module): +class LEDEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: LEDConfig, layer_id: int): super().__init__() self.embed_dim = config.d_model @@ -962,7 +963,7 @@ class LEDEncoderLayer(nn.Module): return (hidden_states,) + attn_outputs[1:] -class LEDDecoderLayer(nn.Module): +class LEDDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: LEDConfig): super().__init__() self.embed_dim = config.d_model @@ -1680,27 +1681,15 @@ class LEDEncoder(LEDPreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - is_index_masked, - is_index_global_attn, - is_global_attn, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - is_index_masked=is_index_masked, - is_index_global_attn=is_index_global_attn, - is_global_attn=is_global_attn, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: @@ -1943,33 +1932,17 @@ class LEDDecoder(LEDPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - combined_attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs 
= decoder_layer( + hidden_states, + combined_attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 0c76a25a6e3..91664c32fac 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -419,7 +420,7 @@ class LiltOutput(nn.Module): return hidden_states -class LiltLayer(nn.Module): +class LiltLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -506,23 +507,13 @@ class LiltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layout_inputs, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - layout_inputs, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + layout_inputs, + attention_mask, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0][0] layout_inputs = layer_outputs[0][1] diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 38b6fde1037..fb546ad6816 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -29,6 +29,7 @@ from ...generation import GenerationMixin from ...integrations.hub_kernels import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_chunked_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -360,7 +361,7 @@ class Llama4TextAttention(nn.Module): return attn_output, attn_weights -class Llama4TextDecoderLayer(nn.Module): +class Llama4TextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.hidden_size = config.hidden_size @@ -571,31 +572,17 @@ class Llama4TextModel(Llama4PreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - False, # output_router_logits is False - use_cache, - 
cache_position, - freq_cis, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=freq_cis, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=freq_cis, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] @@ -930,7 +917,7 @@ class Llama4VisionMLP(nn.Module): return hidden_states -class Llama4VisionEncoderLayer(nn.Module): +class Llama4VisionEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Llama4VisionConfig): super().__init__() self.hidden_size = config.hidden_size @@ -1033,21 +1020,13 @@ class Llama4VisionEncoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - freqs_ci, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_state=hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - freqs_ci=freqs_ci, - ) + + layer_outputs = encoder_layer( + hidden_state=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + freqs_ci=freqs_ci, + ) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index a40d5bb0e2c..c6b16492c8d 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging @@ -1205,7 +1206,7 @@ class LongformerOutput(nn.Module): return hidden_states -class LongformerLayer(nn.Module): +class LongformerLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.attention = LongformerAttention(config, layer_id) @@ -1284,27 +1285,15 @@ class LongformerEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - is_index_masked, - is_index_global_attn, - is_global_attn, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=head_mask[idx] if head_mask is not None else None, - is_index_masked=is_index_masked, - is_index_global_attn=is_index_global_attn, - is_global_attn=is_global_attn, - 
output_attentions=output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=head_mask[idx] if head_mask is not None else None, + is_index_masked=is_index_masked, + is_index_global_attn=is_index_global_attn, + is_global_attn=is_global_attn, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 8efc2f7416c..081869ec8fc 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1143,7 +1144,7 @@ class LongT5LayerCrossAttention(nn.Module): return outputs -class LongT5Block(nn.Module): +class LongT5Block(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1501,39 +1502,21 @@ class LongT5Stack(LongT5PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index c01bb7453f9..af01cf77be4 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import 
ACT2FN, gelu +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward @@ -695,7 +696,7 @@ class LukeOutput(nn.Module): return hidden_states -class LukeLayer(nn.Module): +class LukeLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -774,23 +775,13 @@ class LukeEncoder(nn.Module): all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - word_hidden_states, - entity_hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module( - word_hidden_states, - entity_hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + word_hidden_states, + entity_hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + ) word_hidden_states = layer_outputs[0] diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 25ed975e4d5..7d5a73667ee 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import ( from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -335,7 +336,7 @@ class M2M100Attention(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 -class M2M100EncoderLayer(nn.Module): +class M2M100EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model @@ -404,7 +405,7 @@ class M2M100EncoderLayer(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 -class M2M100DecoderLayer(nn.Module): +class M2M100DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -883,21 +884,12 @@ class M2M100Encoder(M2M100PreTrainedModel): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1142,35 +1134,20 @@ class M2M100Decoder(M2M100PreTrainedModel): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if 
self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 0f6dfab8112..d771494486f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import MambaCache from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -343,7 +344,7 @@ class MambaRMSNorm(nn.Module): return f"{self.weight.shape[0]}, eps={self.variance_epsilon}" -class MambaBlock(nn.Module): +class MambaBlock(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.config = config @@ -561,17 +562,12 @@ class MambaModel(MambaPreTrainedModel): hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 17925c5acc0..7dd6ecc92d4 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -24,6 +24,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( 
ModelOutput, @@ -682,7 +683,7 @@ class Mamba2RMSNorm(nn.Module): return self.weight * hidden_states.to(input_dtype) -class Mamba2Block(nn.Module): +class Mamba2Block(GradientCheckpointingLayer): def __init__(self, config, layer_idx): super().__init__() self.config = config @@ -901,17 +902,12 @@ class Mamba2Model(Mamba2PreTrainedModel): hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask - ) - else: - hidden_states = mixer_block( - hidden_states, - cache_params=cache_params, - cache_position=cache_position, - attention_mask=attention_mask, - ) + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5630f916ee3..7319671b485 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -270,7 +271,7 @@ class MarianAttention(nn.Module): # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN -class MarianEncoderLayer(nn.Module): +class MarianEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -342,7 +343,7 @@ class MarianEncoderLayer(nn.Module): # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN -class MarianDecoderLayer(nn.Module): +class MarianDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -831,21 +832,12 @@ class MarianEncoder(MarianPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1087,35 +1079,18 @@ class MarianDecoder(MarianPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if 
cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 47e57b00172..4a34c85b3db 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -518,7 +519,7 @@ class MarkupLMAttention(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->MarkupLM -class MarkupLMLayer(nn.Module): +class MarkupLMLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -644,27 +645,15 @@ class MarkupLMEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 3eb559dfcdb..5ab37ea5358 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -25,6 +25,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...file_utils import ModelOutput, is_scipy_available, requires_backends 
+from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_accelerate_available, logging @@ -1535,7 +1536,7 @@ class Mask2FormerAttention(nn.Module): return attn_output, attn_weights_reshaped -class Mask2FormerMaskedAttentionDecoderLayer(nn.Module): +class Mask2FormerMaskedAttentionDecoderLayer(GradientCheckpointingLayer): """ The Mask2FormerMaskedAttentionDecoderLayer is made up of self-attention, cross (masked) attention as well as FFN blocks. The cross attention block used as part of `Mask2FormerMaskedAttentionDecoderLayer` is actually a `masked @@ -1858,46 +1859,35 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module): if self.training and (dropout_probability < self.layerdrop): continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - None, - None, - output_attentions, - ) + level_index = idx % self.num_feature_levels - else: - level_index = idx % self.num_feature_levels + where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) + # Multiply the attention mask instead of indexing to avoid issue in torch.export. + attention_mask = attention_mask * where.unsqueeze(-1) - where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype) - # Multiply the attention mask instead of indexing to avoid issue in torch.export. - attention_mask = attention_mask * where.unsqueeze(-1) + layer_outputs = decoder_layer( + hidden_states, + level_index, + None, # attention_mask + multi_stage_positional_embeddings, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=attention_mask, + output_attentions=output_attentions, + ) - layer_outputs = decoder_layer( - hidden_states, - level_index=level_index, - position_embeddings=multi_stage_positional_embeddings, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=attention_mask, - output_attentions=output_attentions, - ) + intermediate_hidden_states = self.layernorm(layer_outputs[0]) - intermediate_hidden_states = self.layernorm(layer_outputs[0]) + predicted_mask, attention_mask = self.mask_predictor( + intermediate_hidden_states, + pixel_embeddings, + feature_size_list[(idx + 1) % self.num_feature_levels], + ) - predicted_mask, attention_mask = self.mask_predictor( - intermediate_hidden_states, - pixel_embeddings, - feature_size_list[(idx + 1) % self.num_feature_levels], - ) + intermediate_mask_predictions += (predicted_mask,) - intermediate_mask_predictions += (predicted_mask,) - - # add intermediate hidden states with layer norm applied which will be used for predicting class logits - intermediate += (intermediate_hidden_states,) + # add intermediate hidden states with layer norm applied which will be used for predicting class logits + intermediate += (intermediate_hidden_states,) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 02f9848a8fb..18d36427d92 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -25,6 +25,7 @@ from torch import Tensor, nn 
from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -529,7 +530,7 @@ class DetrAttention(nn.Module): # Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer -class DetrDecoderLayer(nn.Module): +class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() self.embed_dim = config.d_model @@ -742,26 +743,15 @@ class DetrDecoder(nn.Module): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - None, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + None, # attention_mask + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index b7505aa6748..68c291f1b10 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...file_utils import ModelOutput +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer @@ -629,7 +630,7 @@ class MaskFormerSwinLayer(nn.Module): return outputs -class MaskFormerSwinStage(nn.Module): +class MaskFormerSwinStage(GradientCheckpointingLayer): # Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() @@ -729,21 +730,13 @@ class MaskFormerSwinEncoder(nn.Module): for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - output_hidden_states, - ) + layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + output_hidden_states, + ) input_dimensions = (output_dimensions[-2], output_dimensions[-1]) all_input_dimensions += (input_dimensions,) diff --git 
a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 18ad34026f4..2585d91a3e3 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -34,6 +34,7 @@ from ...modeling_attn_mask_utils import ( from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -279,7 +280,7 @@ class MBartAttention(nn.Module): return attn_output, attn_weights, past_key_value -class MBartEncoderLayer(nn.Module): +class MBartEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model @@ -347,7 +348,7 @@ class MBartEncoderLayer(nn.Module): return outputs -class MBartDecoderLayer(nn.Module): +class MBartDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -866,21 +867,12 @@ class MBartEncoder(MBartPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1130,35 +1122,18 @@ class MBartDecoder(MBartPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py 
b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index d22b1536081..941d62b5869 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -405,7 +406,7 @@ class MegatronBertOutput(nn.Module): # Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm. -class MegatronBertLayer(nn.Module): +class MegatronBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -535,27 +536,15 @@ class MegatronBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) # Because we moved the layer-norm at the end of the hidden layer, we have non-normali- # zed data here. If that's really needed, we must apply LN to match Transformer's BERT. 
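Every hunk in this patch applies the same two-step rewrite that the lilt, llama4, longformer and megatron_bert diffs above illustrate: the layer class is rebased from nn.Module onto GradientCheckpointingLayer, and the per-model "if self.gradient_checkpointing and self.training:" branch around _gradient_checkpointing_func is dropped, so the loop always calls the layer directly (with tensors such as encoder_hidden_states passed positionally, since checkpointing forwards positional arguments). Below is a minimal sketch of that idea, assuming a torch.utils.checkpoint based implementation; it is not the actual code in src/transformers/modeling_layers.py, and every name other than GradientCheckpointingLayer is hypothetical.

    # Illustrative sketch only -- not the real implementation shipped in this patch.
    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint


    class GradientCheckpointingLayer(nn.Module):
        # Assumed to be flipped to True by the model's gradient-checkpointing toggle.
        gradient_checkpointing = False

        def __call__(self, *args, **kwargs):
            if self.gradient_checkpointing and self.training:
                # Recompute forward() during backward instead of storing activations.
                # Non-reentrant checkpointing also forwards keyword arguments.
                return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
            return super().__call__(*args, **kwargs)


    class ToyDecoderLayer(GradientCheckpointingLayer):
        """Hypothetical layer standing in for the decoder/encoder layers in the hunks."""

        def __init__(self, hidden_size=16):
            super().__init__()
            self.proj = nn.Linear(hidden_size, hidden_size)

        def forward(self, hidden_states, attention_mask=None):
            out = self.proj(hidden_states)
            if attention_mask is not None:
                out = out * attention_mask
            return out


    layer = ToyDecoderLayer()
    layer.gradient_checkpointing = True
    layer.train()
    x = torch.randn(2, 4, 16, requires_grad=True)
    # The caller no longer branches on gradient checkpointing: it just calls the layer.
    layer(x, attention_mask=torch.ones(2, 4, 1)).sum().backward()

With the branch living in __call__, the encoder/decoder loops in the files above and below reduce to a single unconditional call per layer, which is exactly what each "- if self.gradient_checkpointing and self.training:" removal in this patch does.
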
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 45c307b6136..3bea1055362 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel @@ -804,7 +805,7 @@ MIMI_ATTENTION_CLASSES = { } -class MimiTransformerLayer(nn.Module): +class MimiTransformerLayer(GradientCheckpointingLayer): def __init__(self, config: MimiConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1019,27 +1020,15 @@ class MimiTransformerModel(nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 1bf968e0361..34e0b507f44 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -32,6 +32,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -485,7 +486,7 @@ class MiniMaxSparseMoeBlock(nn.Module): return final_hidden_states, router_logits -class MiniMaxDecoderLayer(nn.Module): +class MiniMaxDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MiniMaxConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 26007b7b18a..8f82b59e5e4 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -24,7 +24,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial from typing import Callable, Optional, Union import torch @@ -37,6 +36,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -295,7 +295,7 @@ class MixtralAttention(nn.Module): return attn_output, attn_weights -class MixtralDecoderLayer(nn.Module): +class MixtralDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -535,32 +535,18 @@ class MixtralModel(MixtralPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index bfc78597cfe..c4e4a429666 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -19,7 +19,6 @@ # limitations under the License. 
"""PyTorch Mixtral model.""" -from functools import partial from typing import Optional, Union import torch @@ -31,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import DynamicCache from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import LossKwargs, logging @@ -226,7 +226,7 @@ class MixtralAttention(MistralAttention): pass -class MixtralDecoderLayer(nn.Module): +class MixtralDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -386,32 +386,18 @@ class MixtralModel(MistralModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mlcd/modeling_mlcd.py b/src/transformers/models/mlcd/modeling_mlcd.py index b20e9f107e0..26a12cab8ba 100644 --- a/src/transformers/models/mlcd/modeling_mlcd.py +++ b/src/transformers/models/mlcd/modeling_mlcd.py @@ -25,6 +25,7 @@ import torch.nn as nn from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -299,7 +300,7 @@ class MLCDAttention(nn.Module): return attn_output, attn_weights -class MLCDEncoderLayer(nn.Module): +class MLCDEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MLCDVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -416,21 +417,12 @@ class MLCDEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + layer_outputs 
= encoder_layer( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mlcd/modular_mlcd.py b/src/transformers/models/mlcd/modular_mlcd.py index d18b2346224..412d34daa5f 100644 --- a/src/transformers/models/mlcd/modular_mlcd.py +++ b/src/transformers/models/mlcd/modular_mlcd.py @@ -356,21 +356,12 @@ class MLCDEncoder(CLIPEncoder): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - position_embeddings, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states=hidden_states, - position_embeddings=position_embeddings, - attention_mask=attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index 149eb9261ef..1b483fe958c 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithNoAttention, BaseModelOutputWithPoolingAndNoAttention, @@ -350,7 +351,7 @@ class MobileViTTransformer(nn.Module): return hidden_states -class MobileViTLayer(nn.Module): +class MobileViTLayer(GradientCheckpointingLayer): """ MobileViT block: https://huggingface.co/papers/2110.02178 """ @@ -603,13 +604,7 @@ class MobileViTEncoder(nn.Module): all_hidden_states = () if output_hidden_states else None for i, layer_module in enumerate(self.layer): - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - hidden_states = layer_module(hidden_states) + hidden_states = layer_module(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index 868c595dbac..a52aedca7cf 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithNoAttention, BaseModelOutputWithPoolingAndNoAttention, @@ -351,7 +352,7 @@ class MobileViTV2Transformer(nn.Module): return hidden_states -class MobileViTV2Layer(nn.Module): +class MobileViTV2Layer(GradientCheckpointingLayer): """ MobileViTV2 layer: https://huggingface.co/papers/2206.02680 """ @@ -556,13 +557,7 @@ class MobileViTV2Encoder(nn.Module): all_hidden_states = () if output_hidden_states else None for i, layer_module in 
enumerate(self.layer): - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - hidden_states = layer_module(hidden_states) + hidden_states = layer_module(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 05fb1af62b2..9089a8d3425 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -508,7 +509,7 @@ class ModernBertAttention(nn.Module): return (hidden_states,) + attn_outputs[1:] # add attentions if outputted -class ModernBertEncoderLayer(nn.Module): +class ModernBertEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config @@ -864,27 +865,15 @@ class ModernBertModel(ModernBertPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - sliding_window_mask, - position_ids, - cu_seqlens, - max_seqlen, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - sliding_window_mask=sliding_window_mask, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index a707c659fbf..bafbb3bf7d7 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -710,7 +711,7 @@ class ModernBertAttention(nn.Module): return (hidden_states,) + attn_outputs[1:] # add attentions if outputted -class ModernBertEncoderLayer(nn.Module): +class ModernBertEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): super().__init__() self.config = config @@ -994,27 +995,15 @@ class ModernBertModel(ModernBertPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if 
self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - sliding_window_mask, - position_ids, - cu_seqlens, - max_seqlen, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - sliding_window_mask=sliding_window_mask, - position_ids=position_ids, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + sliding_window_mask=sliding_window_mask, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: all_self_attentions = all_self_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index d8f61fbf502..2909fb386fb 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -781,9 +781,9 @@ class MoonshineDecoder(MoonshinePreTrainedModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 4ee7cd81f77..500231f3b48 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -787,9 +787,9 @@ class MoonshineDecoder(LlamaModel): layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - encoder_hidden_states=encoder_hidden_states, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index c06fd27e368..d397dc0d923 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationConfig, GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput, Seq2SeqLMOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel @@ -738,7 +739,7 @@ MOSHI_ATTENTION_CLASSES = { } -class MoshiDecoderLayer(nn.Module): +class MoshiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True): super().__init__() self.hidden_size = config.hidden_size @@ -1026,27 +1027,15 @@ class 
MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] @@ -1343,27 +1332,15 @@ class MoshiModel(MoshiPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 79ec42e2b8d..b3005728e7b 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -25,6 +25,7 @@ from torch.nn import functional as F from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -160,7 +161,7 @@ class MptMLP(nn.Module): return output -class MptBlock(nn.Module): +class MptBlock(GradientCheckpointingLayer): def __init__(self, config: MptConfig): super().__init__() hidden_size = config.hidden_size @@ -388,25 +389,14 @@ class MptModel(MptPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - outputs = self._gradient_checkpointing_func( - block.__call__, - hidden_states, - alibi, - causal_mask, - layer_past, - use_cache, - output_attentions, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=causal_mask, - use_cache=use_cache, - output_attentions=output_attentions, - position_bias=alibi, - ) + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=causal_mask, + use_cache=use_cache, + output_attentions=output_attentions, + position_bias=alibi, + ) hidden_states = outputs[0] if use_cache is True: diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 7501c4d8306..a7fd783d848 100644 --- 
a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.utils.cpp_extension import load from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, MaskedLMOutput, @@ -688,7 +689,7 @@ class MraOutput(nn.Module): return hidden_states -class MraLayer(nn.Module): +class MraLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -738,14 +739,7 @@ class MraEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask) + layer_outputs = layer_module(hidden_states, attention_mask) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 8596fbeb4f9..5584b2ee825 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -523,7 +524,7 @@ class MT5LayerCrossAttention(nn.Module): # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 -class MT5Block(nn.Module): +class MT5Block(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1088,39 +1089,21 @@ class MT5Stack(MT5PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + 
past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a7ead0a5146..54fd0b31bb3 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -43,6 +43,7 @@ from ...modeling_attn_mask_utils import ( from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -304,7 +305,7 @@ class MusicgenAttention(nn.Module): return attn_output, attn_weights, past_key_value -class MusicgenDecoderLayer(nn.Module): +class MusicgenDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MusicgenDecoderConfig): super().__init__() self.embed_dim = config.hidden_size @@ -619,33 +620,17 @@ class MusicgenDecoder(MusicgenPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.forward, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a57955a7a70..441c0e862b2 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -41,6 +41,7 @@ from ...modeling_attn_mask_utils import ( from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -320,7 +321,7 @@ class MusicgenMelodyAttention(nn.Module): return attn_output, attn_weights, 
past_key_value -class MusicgenMelodyDecoderLayer(nn.Module): +class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__() self.embed_dim = config.hidden_size @@ -596,25 +597,14 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.forward, - hidden_states, - attention_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 739ecd8f015..9e5136d27ad 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -244,7 +245,7 @@ class MvpAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -class MvpEncoderLayer(nn.Module): +class MvpEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: MvpConfig): super().__init__() self.embed_dim = config.d_model @@ -316,7 +317,7 @@ class MvpEncoderLayer(nn.Module): return outputs -class MvpDecoderLayer(nn.Module): +class MvpDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: MvpConfig): super().__init__() self.embed_dim = config.d_model @@ -682,23 +683,13 @@ class MvpEncoder(MvpPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - (self_attn_prompt[idx] if self.use_prompt else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -935,37 +926,19 @@ class MvpDecoder(MvpPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - 
encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - self_attn_prompt[idx] if self.use_prompt else None, - cross_attn_prompt[idx] if self.use_prompt else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), - cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), + cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 2cb0cdc5f0e..6fba76b55d3 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -28,6 +28,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -498,7 +499,7 @@ NEMOTRON_ATTENTION_CLASSES = { # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron # no longer copied after attention refactors -class NemotronDecoderLayer(nn.Module): +class NemotronDecoderLayer(GradientCheckpointingLayer): # Ignore copy def __init__(self, config: NemotronConfig, layer_idx: int): super().__init__() @@ -703,29 +704,16 @@ class NemotronModel(NemotronPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + 
past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 15c87871649..f498cf743fc 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -625,7 +626,7 @@ class NllbMoeAttention(nn.Module): return attn_output, attn_weights, past_key_value -class NllbMoeEncoderLayer(nn.Module): +class NllbMoeEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): super().__init__() self.embed_dim = config.d_model @@ -707,7 +708,7 @@ class NllbMoeEncoderLayer(nn.Module): return outputs -class NllbMoeDecoderLayer(nn.Module): +class NllbMoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): super().__init__() self.embed_dim = config.d_model @@ -1018,22 +1019,13 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - output_router_logits=output_router_logits, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) hidden_states = layer_outputs[0] @@ -1296,37 +1288,18 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.forward, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 17a3319de3a..f5b940157de 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, MaskedLMOutput, @@ -311,7 +312,7 @@ class NystromformerOutput(nn.Module): return hidden_states -class NystromformerLayer(nn.Module): +class NystromformerLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -363,15 +364,7 @@ class NystromformerEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index d904deb1493..61732ab1c25 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -24,6 +24,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel @@ -611,7 +612,7 @@ class OlmoeSparseMoeBlock(nn.Module): return 
final_hidden_states, router_logits -class OlmoeDecoderLayer(nn.Module): +class OlmoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OlmoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -828,31 +829,17 @@ class OlmoeModel(OlmoePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index 3380118dd9e..1007be135ed 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -31,6 +31,7 @@ from ...file_utils import ( ) from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging from ...utils.backbone_utils import load_backbone @@ -879,7 +880,7 @@ class OmDetTurboTaskEncoder(nn.Module): return x -class OmDetTurboDeformableTransformerDecoderLayer(nn.Module): +class OmDetTurboDeformableTransformerDecoderLayer(GradientCheckpointingLayer): """ A single layer of the Deformable Transformer Decoder. 
""" @@ -1376,37 +1377,19 @@ class OmDetTurboDecoder(OmDetTurboPreTrainedModel): last_refined_bbox = None reference_points = reference_points.sigmoid() for i, layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - predicted_class_features, task_features, self_attention, cross_attention = ( - self._gradient_checkpointing_func( - layer.__call__, - predicted_class_features, - task_features, - reference_points, - vision_features, - vision_shapes, - vision_shapes_list, - level_start_index=level_start_index, - attention_mask=attention_mask, - query_position=self.query_position_head(reference_points), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - ) - else: - predicted_class_features, task_features, self_attention, cross_attention = layer( - predicted_class_features, - task_features, - reference_points, - vision_features, - vision_shapes, - vision_shapes_list, - level_start_index=level_start_index, - attention_mask=attention_mask, - query_position=self.query_position_head(reference_points), - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) + predicted_class_features, task_features, self_attention, cross_attention = layer( + predicted_class_features, + task_features, + reference_points, + vision_features, + vision_shapes, + vision_shapes_list, + level_start_index=level_start_index, + attention_mask=attention_mask, + query_position=self.query_position_head(reference_points), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) if output_attentions: all_self_attns = all_self_attns + (self_attention,) all_cross_attns = all_cross_attns + (cross_attention,) diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index d400e08cd18..99dde2fabba 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -26,6 +26,7 @@ from torch import Tensor, nn from torch.cuda.amp import autocast from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -2563,7 +2564,7 @@ class OneFormerTextMLP(nn.Module): return hidden_states -class OneFormerTextTransformerLayer(nn.Module): +class OneFormerTextTransformerLayer(GradientCheckpointingLayer): def __init__(self, width: int, heads: int, attn_mask: torch.Tensor, layer_norm_eps=1e-05): super().__init__() self.self_attn = nn.MultiheadAttention(width, heads) @@ -2617,10 +2618,7 @@ class OneFormerTextTransformer(nn.Module): def forward(self, hidden_states: torch.Tensor): for layer in self.layers: - if self.use_checkpoint: - hidden_states = self._gradient_checkpointing_func(layer, hidden_states) - else: - hidden_states = layer(hidden_states) + hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 0629fe2ad19..d18378ba8e9 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from 
...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -200,7 +201,7 @@ class OPTAttention(nn.Module): return attn_output, attn_weights, past_key_value -class OPTDecoderLayer(nn.Module): +class OPTDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: OPTConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.hidden_size @@ -668,30 +669,17 @@ class OPTDecoder(OPTPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - head_mask[idx] if head_mask is not None else None, - None, - output_attentions, - use_cache, - position_ids, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index 4e838c90a9f..ee4fb714f54 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -24,6 +24,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_vision_available, logging, torch_int @@ -497,7 +498,7 @@ class Owlv2MLP(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Owlv2 -class Owlv2EncoderLayer(nn.Module): +class Owlv2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Owlv2Config): super().__init__() self.embed_dim = config.hidden_size @@ -655,21 +656,12 @@ class Owlv2Encoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index d487706611d..4e269cd4659 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -24,6 +24,7 @@ from torch import Tensor, nn from 
...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, is_vision_available, logging, torch_int @@ -485,7 +486,7 @@ class OwlViTMLP(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->OwlViT -class OwlViTEncoderLayer(nn.Module): +class OwlViTEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: OwlViTConfig): super().__init__() self.embed_dim = config.hidden_size @@ -641,21 +642,12 @@ class OwlViTEncoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 6922d3c815b..2ffb53ee9e0 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -269,7 +270,7 @@ class PegasusAttention(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS -class PegasusEncoderLayer(nn.Module): +class PegasusEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model @@ -338,7 +339,7 @@ class PegasusEncoderLayer(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS -class PegasusDecoderLayer(nn.Module): +class PegasusDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -845,21 +846,12 @@ class PegasusEncoder(PegasusPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1135,35 +1127,18 
@@ class PegasusDecoder(PegasusPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 0c48aa61412..842e365fb82 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -528,7 +529,7 @@ class PegasusXGlobalLocalAttention(nn.Module): return attn_output, attn_probs -class PegasusXEncoderLayer(nn.Module): +class PegasusXEncoderLayer(GradientCheckpointingLayer): def __init__(self, stagger_blocks_this_layer: bool, config: PegasusXConfig): super().__init__() self.embed_dim = config.d_model @@ -643,7 +644,7 @@ class PegasusXEncoderLayer(nn.Module): return padded_hidden_states[:, pad_size:-pad_size, :] -class PegasusXDecoderLayer(nn.Module): +class PegasusXDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1148,21 +1149,12 @@ class PegasusXEncoder(PegasusXPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - global_hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - global_hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + global_hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] global_hidden_states = layer_outputs[1] @@ -1388,29 +1380,16 @@ class 
PegasusXDecoder(PegasusXPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index ce142c8d6d2..f2bbef331a0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -30,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -299,7 +300,7 @@ class PersimmonAttention(nn.Module): return attn_output, attn_weights, past_key_value -class PersimmonDecoderLayer(nn.Module): +class PersimmonDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PersimmonConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -517,30 +518,17 @@ class PersimmonModel(PersimmonPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 5edceb27c0d..95164a5f5db 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -4,7 +4,6 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_phi.py file directly. One of our CI enforces this. 
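
Every hunk in this patch applies the same two-part change: the layer class swaps its base from nn.Module to GradientCheckpointingLayer, and the caller drops its `if self.gradient_checkpointing and self.training:` branch, because the checkpoint-or-not decision now lives in the layer's own `__call__`. The sketch below is for orientation only and is not the code this patch adds to src/transformers/modeling_layers.py; the kwarg names it neutralizes, the stand-in checkpoint function, and the way `_gradient_checkpointing_func` would normally be installed by `gradient_checkpointing_enable()` are assumptions.

# Orientation sketch only, not the implementation this patch adds to modeling_layers.py.
# Assumptions: gradient_checkpointing_enable() flips `gradient_checkpointing` to True on each
# layer and installs `_gradient_checkpointing_func`; a plain torch.utils.checkpoint stands in.
from functools import partial

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayer(nn.Module):
    gradient_checkpointing = False
    _gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Caching cannot survive recomputation, so neutralize it once here instead of in
            # every model's decoder loop; the real code handles several spellings of "past".
            if kwargs.get("use_cache", False):
                kwargs["use_cache"] = False
            for past_name in ("past_key_value", "past_key_values", "layer_past"):
                if kwargs.get(past_name) is not None:
                    kwargs[past_name] = None
            # Keyword arguments are frozen into the partial; only positional arguments go
            # through the checkpoint function, which is why callers in this patch pass tensors
            # such as encoder_hidden_states positionally while flags stay as keywords.
            return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
        return super().__call__(*args, **kwargs)

With the wrapper in charge, a call site such as `layer_outputs = decoder_layer(hidden_states, attention_mask=causal_mask, ...)` behaves the same whether or not checkpointing is enabled, which is the simplification every hunk above and below performs.
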
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from functools import partial from typing import Callable, Optional, Union import torch @@ -15,6 +14,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -206,7 +206,7 @@ class PhiMLP(nn.Module): return hidden_states -class PhiDecoderLayer(nn.Module): +class PhiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.self_attn = PhiAttention(config, layer_idx=layer_idx) @@ -410,30 +410,17 @@ class PhiModel(PhiPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index c515e13e723..46a367bbdb1 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,4 +1,3 @@ -from functools import partial from typing import Callable, Optional import torch @@ -7,6 +6,7 @@ import torch.nn as nn from ...cache_utils import Cache, DynamicCache from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, ) @@ -118,7 +118,7 @@ class PhiMLP(CLIPMLP): pass -class PhiDecoderLayer(nn.Module): +class PhiDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.self_attn = PhiAttention(config, layer_idx=layer_idx) @@ -261,30 +261,17 @@ class PhiModel(LlamaModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = 
decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index b1651f467b4..af4e823a98d 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_flash_attention_utils import is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel @@ -795,7 +796,7 @@ class PhimoeSparseMoeBlock(nn.Module): return final_hidden_states, router_logits -class PhimoeDecoderLayer(nn.Module): +class PhimoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PhimoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -1018,31 +1019,17 @@ class PhimoeModel(PhimoePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 4f9e42cc94c..6b90ae80d7c 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -259,7 +260,7 @@ class Pix2StructVisionMlp(nn.Module): return hidden_states -class Pix2StructVisionLayer(nn.Module): +class Pix2StructVisionLayer(GradientCheckpointingLayer): def __init__(self, config: Pix2StructConfig) -> None: super().__init__() self.chunk_size_feed_forward = 
config.chunk_size_feed_forward @@ -327,16 +328,7 @@ class Pix2StructVisionEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -925,7 +917,7 @@ class Pix2StructTextLayerCrossAttention(nn.Module): return outputs -class Pix2StructTextBlock(nn.Module): +class Pix2StructTextBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() @@ -1148,6 +1140,12 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training and use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: @@ -1241,42 +1239,20 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index 998124a8da6..f1d5ab06d54 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -23,6 +23,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel @@ -272,7 +273,7 @@ class PixtralRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class PixtralAttentionLayer(nn.Module): +class PixtralAttentionLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention_norm = PixtralRMSNorm(config.hidden_size, eps=1e-5) @@ -374,22 +375,13 @@ class PixtralTransformer(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - position_embeddings, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - position_embeddings=position_embeddings, - output_attentions=output_attentions, - **kwargs, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index a192fe70e23..327b70b5ec7 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ 
b/src/transformers/models/plbart/modeling_plbart.py @@ -36,6 +36,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -465,7 +466,7 @@ class PLBartAttention(nn.Module): return attn_output, attn_weights, past_key_value -class PLBartEncoderLayer(nn.Module): +class PLBartEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -683,21 +684,12 @@ class PLBartEncoder(PLBartPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -714,7 +706,7 @@ class PLBartEncoder(PLBartPreTrainedModel): ) -class PLBartDecoderLayer(nn.Module): +class PLBartDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -1064,35 +1056,18 @@ class PLBartDecoder(PLBartPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 48f930dacb6..5c4285afe72 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ 
b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -467,7 +468,7 @@ class Pop2PianoLayerCrossAttention(nn.Module): # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano -class Pop2PianoBlock(nn.Module): +class Pop2PianoBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -814,37 +815,20 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index d7783c48e0a..eb0b2e59471 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -27,6 +27,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -956,7 +957,7 @@ class ProphetNetNgramSelfAttention(nn.Module): return predict_relative_pos_embeddings -class ProphetNetEncoderLayer(nn.Module): +class 
ProphetNetEncoderLayer(GradientCheckpointingLayer): """ Encoder block for Prophetnet """ @@ -999,7 +1000,7 @@ class ProphetNetEncoderLayer(nn.Module): return outputs -class ProphetNetDecoderLayer(nn.Module): +class ProphetNetDecoderLayer(GradientCheckpointingLayer): """ Decoder block for Prophetnet """ @@ -1183,21 +1184,12 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - extended_attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=extended_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1395,41 +1387,21 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - extended_attention_mask, - encoder_hidden_states, - extended_encoder_attention_mask, - (head_mask[idx] if head_mask is not None else None), - (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - extended_predict_attention_mask, - main_relative_position_buckets, - predict_relative_position_buckets, - position_ids, - None, - use_cache, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=extended_attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attn_mask=extended_encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - extended_predict_attention_mask=extended_predict_attention_mask, - main_relative_position_buckets=main_relative_position_buckets, - predict_relative_position_buckets=predict_relative_position_buckets, - position_ids=position_ids, - past_key_value=past_key_value, - use_cache=use_cache, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + extended_attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attn_mask=extended_encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + extended_predict_attention_mask=extended_predict_attention_mask, + main_relative_position_buckets=main_relative_position_buckets, + predict_relative_position_buckets=predict_relative_position_buckets, + position_ids=position_ids, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/pvt_v2/modeling_pvt_v2.py b/src/transformers/models/pvt_v2/modeling_pvt_v2.py index 7c2f48bd580..b357cb5970a 100644 --- a/src/transformers/models/pvt_v2/modeling_pvt_v2.py +++ 
b/src/transformers/models/pvt_v2/modeling_pvt_v2.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -300,7 +301,7 @@ class PvtV2BlockLayer(nn.Module): return outputs -class PvtV2EncoderLayer(nn.Module): +class PvtV2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: PvtV2Config, layer_idx: int): super().__init__() self.patch_embedding = PvtV2OverlapPatchEmbeddings( @@ -367,10 +368,7 @@ class PvtV2Encoder(nn.Module): batch_size = pixel_values.shape[0] hidden_states = pixel_values for idx, layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - layer_output = self._gradient_checkpointing_func(layer.__call__, hidden_states, output_attentions) - else: - layer_output = layer(hidden_states, output_attentions) + layer_output = layer(hidden_states, output_attentions) outputs, height, width = layer_output hidden_states = outputs[0] if output_attentions: diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index dd6b6030677..e7f362c1f4a 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -727,7 +727,7 @@ QWEN2_5_OMNI_AUDIO_ATTENTION_CLASSES = { } -class Qwen2_5OmniAudioEncoderLayer(nn.Module): +class Qwen2_5OmniAudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): super().__init__() self.embed_dim = config.d_model @@ -889,19 +889,8 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): ) ).to(torch.int32) - for idx, encoder_layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - cu_seqlens, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - cu_seqlens, - ) - + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states, cu_seqlens) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -1107,7 +1096,7 @@ QWEN2_5_OMNI_VISION_ATTENTION_CLASSES = { } -class Qwen2_5OmniVisionBlock(nn.Module): +class Qwen2_5OmniVisionBlock(GradientCheckpointingLayer): def __init__(self, config: Qwen2_5OmniVisionEncoderConfig) -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) @@ -1324,16 +1313,11 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5OmniPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb - ) - else: - hidden_states = blk( - hidden_states, - cu_seqlens=cu_seqlens_now, - rotary_pos_emb=rotary_pos_emb, - ) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] @@ -1760,30 +1744,17 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): if 
output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] @@ -2331,30 +2302,17 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 4412184cd3f..0d55daf14c1 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -1900,19 +1900,8 @@ class Qwen2_5OmniAudioEncoder(Qwen2_5OmniPreTrainedModel): ) ).to(torch.int32) - for idx, encoder_layer in enumerate(self.layers): - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - cu_seqlens, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - cu_seqlens, - ) - + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states, cu_seqlens) hidden_states = layer_outputs[0] hidden_states_list = hidden_states.split(aftercnn_lens.tolist(), dim=0) @@ -2166,16 +2155,11 @@ class Qwen2_5OmniVisionEncoder(Qwen2_5_VisionTransformerPretrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, rotary_pos_emb - ) - else: - hidden_states = blk( - hidden_states, - 
cu_seqlens=cu_seqlens_now, - rotary_pos_emb=rotary_pos_emb, - ) + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb, + ) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) hidden_states = hidden_states[reverse_indices, :] diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c851244e6b7..75a5b79f8b5 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -314,7 +314,7 @@ QWEN2_5_VL_VISION_ATTENTION_CLASSES = { } -class Qwen2_5_VLVisionBlock(nn.Module): +class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) @@ -516,12 +516,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) @@ -991,30 +986,17 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 71764b24563..a6dedefa019 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -50,6 +50,7 @@ from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput from ...modeling_flash_attention_utils import is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...processing_utils import MultiModalData, ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import is_torchdynamo_compiling, logging @@ -205,7 +206,7 @@ 
QWEN2_5_VL_VISION_ATTENTION_CLASSES = { } -class Qwen2_5_VLVisionBlock(nn.Module): +class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6) @@ -395,12 +396,7 @@ class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel): cu_seqlens_now = cu_seqlens else: cu_seqlens_now = cu_window_seqlens - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings) hidden_states = self.merger(hidden_states) reverse_indices = torch.argsort(window_index) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 6569d78674e..8e331e1bd0b 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, logging @@ -200,7 +201,7 @@ class Qwen2AudioAttention(nn.Module): # Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO -class Qwen2AudioEncoderLayer(nn.Module): +class Qwen2AudioEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2AudioConfig): super().__init__() self.embed_dim = config.d_model @@ -436,21 +437,12 @@ class Qwen2AudioEncoder(Qwen2AudioPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index cc617533582..243bd12baca 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -32,6 +32,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -639,7 +640,7 @@ class 
Qwen2MoeSparseMoeBlock(nn.Module): return final_hidden_states, router_logits -class Qwen2MoeDecoderLayer(nn.Module): +class Qwen2MoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen2MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -865,31 +866,17 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index acf1e4025b9..32b9a416600 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -439,7 +439,7 @@ QWEN2_VL_VISION_ATTENTION_CLASSES = { } -class Qwen2VLVisionBlock(nn.Module): +class Qwen2VLVisionBlock(GradientCheckpointingLayer): def __init__(self, config, attn_implementation: str = "sdpa") -> None: super().__init__() self.norm1 = LayerNorm(config.embed_dim, eps=1e-6) @@ -837,12 +837,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel): cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) for blk in self.blocks: - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - blk.__call__, hidden_states, cu_seqlens, None, position_embeddings - ) - else: - hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings) return self.merger(hidden_states) @@ -959,30 +954,17 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask_mapping[decoder_layer.attention_type], - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask_mapping[decoder_layer.attention_type], - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_value=past_key_values, 
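# --- Note on the recurring pattern in these hunks ---------------------------
# Throughout this patch the per-model branch
#     if self.gradient_checkpointing and self.training:
#         layer_outputs = self._gradient_checkpointing_func(layer.__call__, ...)
#     else:
#         layer_outputs = layer(...)
# is collapsed into a single plain call, because the layer classes now inherit
# from GradientCheckpointingLayer (imported from ...modeling_layers), which
# makes the checkpointing decision itself. A minimal sketch of how such a base
# class could work is given below; the class name, attribute handling and the
# direct use of torch.utils.checkpoint are illustrative assumptions, not the
# actual implementation shipped in modeling_layers.py.
from functools import partial

import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayerSketch(nn.Module):
    # Assumed to be flipped to True by the model's gradient_checkpointing_enable hook.
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Caching is pointless when activations are recomputed, so neutralize
            # cache-related kwargs if the layer happens to accept them.
            if "use_cache" in kwargs:
                kwargs["use_cache"] = False
            if "past_key_value" in kwargs:
                kwargs["past_key_value"] = None
            # Freeze keyword arguments into a partial and checkpoint the positional
            # ones: the forward pass is re-run during backward instead of storing
            # intermediate activations.
            return checkpoint(partial(super().__call__, **kwargs), *args, use_reentrant=False)
        return super().__call__(*args, **kwargs)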
+ output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 67f21d1b836..1f74f5e5589 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial from typing import Callable, Optional, Union import torch @@ -32,6 +31,7 @@ from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -286,7 +286,7 @@ class Qwen3MoeRMSNorm(nn.Module): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Qwen3MoeDecoderLayer(nn.Module): +class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3MoeConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -541,32 +541,18 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - partial(decoder_layer.__call__, **flash_attn_kwargs), - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 50d19a33b73..4b82c490cc9 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -25,6 +25,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -467,7 +468,7 @@ class RecurrentGemmaMlp(nn.Module): return self.down_proj(gate * self.up_proj(hidden_states)) -class RecurrentGemmaDecoderLayer(nn.Module): +class 
RecurrentGemmaDecoderLayer(GradientCheckpointingLayer): """Griffin and Hawk's residual block.""" def __init__(self, config, layer_idx): @@ -644,12 +645,7 @@ class RecurrentGemmaModel(RecurrentGemmaPreTrainedModel): for i, residual_block in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - residual_block.__call__, hidden_states, position_ids, causal_mask, cache_position, use_cache - ) - else: - hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache) + hidden_states = residual_block(hidden_states, position_ids, causal_mask, cache_position, use_cache) hidden_states = self.final_norm(hidden_states) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 22a3ba7aeb9..774fc46ef7f 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -399,7 +400,7 @@ class RemBertOutput(nn.Module): return hidden_states -class RemBertLayer(nn.Module): +class RemBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -528,27 +529,15 @@ class RemBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 9381c8f9ab0..ecf3a6cc531 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -477,7 +478,7 @@ class RobertaOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta -class RobertaLayer(nn.Module): +class RobertaLayer(GradientCheckpointingLayer): def __init__(self, config): 
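# --- Note on the "# as a positional argument for gradient checkpointing" comments ---
# Several call sites in this patch (RobertaEncoder just below, RoCBertEncoder,
# the SeamlessM4T and SpeechT5 decoders, ...) switch encoder_hidden_states from
# a keyword argument to a positional one. With a base class like the sketch in
# the earlier note, keyword arguments get frozen into functools.partial while
# only the positional arguments are handed to the checkpoint call as explicit
# inputs, so tensors that must be differentiated through the checkpoint are
# safer passed positionally. The snippet below only illustrates that structural
# difference; the helper and layer names are made up for the example and are
# not part of this patch.
from functools import partial


def _route(fn, *args, **kwargs):
    # Mirror of the assumed dispatch: kwargs are bound into a partial,
    # positional args stay visible to the checkpoint as inputs.
    return partial(fn, **kwargs), args


def fake_layer_forward(hidden_states, attention_mask=None, encoder_hidden_states=None):
    return hidden_states


# encoder_hidden_states passed as a keyword: it ends up inside the partial.
_, checkpoint_inputs = _route(fake_layer_forward, "hidden", encoder_hidden_states="enc")
assert checkpoint_inputs == ("hidden",)

# encoder_hidden_states passed positionally: it stays an explicit checkpoint input.
_, checkpoint_inputs = _route(fake_layer_forward, "hidden", None, "enc")
assert checkpoint_inputs == ("hidden", None, "enc")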
super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -603,27 +604,15 @@ class RobertaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 31d459e7d8b..e8636281e39 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -365,7 +366,7 @@ class RobertaPreLayerNormOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RobertaPreLayerNorm -class RobertaPreLayerNormLayer(nn.Module): +class RobertaPreLayerNormLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -491,27 +492,15 @@ class RobertaPreLayerNormEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 06721ae7719..3985e86e0b3 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, 
MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -488,7 +489,7 @@ class RoCBertOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RoCBert -class RoCBertLayer(nn.Module): +class RoCBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -614,27 +615,15 @@ class RoCBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 3f9b2875c20..8439fed19cf 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -423,7 +424,7 @@ class RoFormerOutput(nn.Module): return hidden_states -class RoFormerLayer(nn.Module): +class RoFormerLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -558,29 +559,16 @@ class RoFormerEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - sinusoidal_pos, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + sinusoidal_pos, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 
b7362000b57..8f045398102 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import nn from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -344,7 +345,7 @@ class RwkvFeedForward(nn.Module): return receptance * value, state -class RwkvBlock(nn.Module): +class RwkvBlock(GradientCheckpointingLayer): def __init__(self, config, layer_id): super().__init__() self.config = config @@ -604,14 +605,9 @@ class RwkvModel(RwkvPreTrainedModel): all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): - if self.gradient_checkpointing and self.training: - hidden_states, state, attentions = self._gradient_checkpointing_func( - block.__call__, hidden_states, state, use_cache, output_attentions - ) - else: - hidden_states, state, attentions = block( - hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions - ) + hidden_states, state, attentions = block( + hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions + ) if ( self.layers_are_rescaled diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index a9088958a8f..31cdec6d5f7 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -969,7 +970,7 @@ SAM_VISION_ATTENTION_CLASSES = { } -class SamVisionLayer(nn.Module): +class SamVisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -1145,13 +1146,7 @@ class SamVisionEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/sam_hq/modeling_sam_hq.py b/src/transformers/models/sam_hq/modeling_sam_hq.py index 20391169855..982bbbb47e0 100644 --- a/src/transformers/models/sam_hq/modeling_sam_hq.py +++ b/src/transformers/models/sam_hq/modeling_sam_hq.py @@ -29,6 +29,7 @@ import torch.nn.functional as F from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging @@ -364,7 +365,7 @@ SAM_HQ_VISION_ATTENTION_CLASSES = { } -class SamHQVisionLayer(nn.Module): +class SamHQVisionLayer(GradientCheckpointingLayer): def __init__(self, config, window_size): super().__init__() self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ 
-539,18 +540,10 @@ class SamHQVisionEncoder(nn.Module): all_self_attentions = () if output_attentions else None intermediate_embeddings = [] - for i, layer_module in enumerate(self.layers): + for layer_module in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) - + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] # Collect embeddings from non-windowed blocks diff --git a/src/transformers/models/sam_hq/modular_sam_hq.py b/src/transformers/models/sam_hq/modular_sam_hq.py index a78ce712cc0..55f475880ca 100644 --- a/src/transformers/models/sam_hq/modular_sam_hq.py +++ b/src/transformers/models/sam_hq/modular_sam_hq.py @@ -151,18 +151,10 @@ class SamHQVisionEncoder(SamVisionEncoder): all_self_attentions = () if output_attentions else None intermediate_embeddings = [] - for i, layer_module in enumerate(self.layers): + for layer_module in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) - + layer_outputs = layer_module(hidden_states, output_attentions=output_attentions) hidden_states = layer_outputs[0] # Collect embeddings from non-windowed blocks diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 24f02c3e6b4..65feeb2d222 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -32,6 +32,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -610,7 +611,7 @@ class SeamlessM4TConformerSelfAttention(nn.Module): return scores -class SeamlessM4TConformerEncoderLayer(nn.Module): +class SeamlessM4TConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4T, attention_dropout->speech_encoder_dropout, torch.nn->nn @@ -743,23 +744,13 @@ class SeamlessM4TConformerEncoder(nn.Module): ) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + 
output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -1173,7 +1164,7 @@ class SeamlessM4TFeedForwardNetwork(nn.Module): return hidden_states -class SeamlessM4TEncoderLayer(nn.Module): +class SeamlessM4TEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4TConfig, encoder_ffn_dim=None, encoder_attention_heads=None): super().__init__() encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim @@ -1236,7 +1227,7 @@ class SeamlessM4TEncoderLayer(nn.Module): return outputs -class SeamlessM4TDecoderLayer(nn.Module): +class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4TConfig, decoder_ffn_dim=None, decoder_attention_heads=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim @@ -1691,19 +1682,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.forward, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1866,27 +1849,15 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 95c586bfd76..7427f1dfab2 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -29,6 +29,7 @@ from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -489,7 +490,7 @@ class SeamlessM4Tv2ConformerSelfAttention(nn.Module): return attn_output, attn_weights -class SeamlessM4Tv2ConformerEncoderLayer(nn.Module): +class 
SeamlessM4Tv2ConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" # Copied from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer.Wav2Vec2ConformerEncoderLayer.__init__ with Wav2Vec2->SeamlessM4Tv2, attention_dropout->speech_encoder_dropout, torch.nn->nn @@ -645,21 +646,12 @@ class SeamlessM4Tv2ConformerEncoder(nn.Module): ) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -1031,7 +1023,7 @@ class SeamlessM4Tv2FeedForwardNetwork(nn.Module): # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TEncoderLayer with SeamlessM4T->SeamlessM4Tv2 -class SeamlessM4Tv2EncoderLayer(nn.Module): +class SeamlessM4Tv2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4Tv2Config, encoder_ffn_dim=None, encoder_attention_heads=None): super().__init__() encoder_ffn_dim = config.encoder_ffn_dim if encoder_ffn_dim is None else encoder_ffn_dim @@ -1095,7 +1087,7 @@ class SeamlessM4Tv2EncoderLayer(nn.Module): # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TDecoderLayer with SeamlessM4T->SeamlessM4Tv2 -class SeamlessM4Tv2DecoderLayer(nn.Module): +class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim @@ -1210,7 +1202,7 @@ class SeamlessM4Tv2DecoderLayer(nn.Module): return outputs -class SeamlessM4Tv2TextToUnitDecoderLayer(nn.Module): +class SeamlessM4Tv2TextToUnitDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SeamlessM4Tv2Config, decoder_ffn_dim=None, decoder_attention_heads=None): super().__init__() decoder_ffn_dim = config.decoder_ffn_dim if decoder_ffn_dim is None else decoder_ffn_dim @@ -1760,19 +1752,11 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.forward, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1936,27 +1920,15 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - output_attentions, - use_cache, - ) - else: - 
layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: @@ -2137,21 +2109,12 @@ class SeamlessM4Tv2TextToUnitDecoder(SeamlessM4Tv2PreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - padding_mask, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - padding_mask=padding_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + padding_mask=padding_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/seggpt/modeling_seggpt.py b/src/transformers/models/seggpt/modeling_seggpt.py index 69c5ce88f7a..cf6b6db3f2a 100644 --- a/src/transformers/models/seggpt/modeling_seggpt.py +++ b/src/transformers/models/seggpt/modeling_seggpt.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import functional as F from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_seggpt import SegGptConfig @@ -395,7 +396,7 @@ class SegGptDropPath(nn.Module): return f"p={self.drop_prob}" -class SegGptLayer(nn.Module): +class SegGptLayer(GradientCheckpointingLayer): def __init__(self, config: SegGptConfig, drop_path_rate: float) -> None: super().__init__() self.attention = SegGptAttention(config) @@ -470,16 +471,7 @@ class SegGptEncoder(nn.Module): # Condition to check if we have the appropriate number of prompts to ensemble ensemble_cond = 2 if self.config.merge_index > i else 1 - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - ensemble_cond, - feature_ensemble, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions) + layer_outputs = layer_module(hidden_states, ensemble_cond, feature_ensemble, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 30949092bd0..da4a54b39fc 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -32,6 +32,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, 
PreTrainedModel from ...processing_utils import Unpack @@ -42,7 +43,7 @@ from .configuration_sew import SEWConfig logger = logging.get_logger(__name__) -class SEWNoLayerNormConvLayer(nn.Module): +class SEWNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -63,7 +64,7 @@ class SEWNoLayerNormConvLayer(nn.Module): return hidden_states -class SEWLayerNormConvLayer(nn.Module): +class SEWLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -90,7 +91,7 @@ class SEWLayerNormConvLayer(nn.Module): return hidden_states -class SEWGroupNormConvLayer(nn.Module): +class SEWGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -223,13 +224,7 @@ class SEWFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -410,7 +405,7 @@ class SEWFeedForward(nn.Module): return hidden_states -class SEWEncoderLayer(nn.Module): +class SEWEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = SEWAttention( @@ -521,17 +516,9 @@ class SEWEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 0b151b05a2c..2d56fea3bc6 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -230,17 +230,9 @@ class SEWEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index f2d682884c4..5e00ddcd1f3 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ 
b/src/transformers/models/sew_d/modeling_sew_d.py @@ -27,6 +27,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import softmax_backward_data @@ -242,7 +243,7 @@ def get_mask(input, local_context): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEWD -class SEWDNoLayerNormConvLayer(nn.Module): +class SEWDNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -264,7 +265,7 @@ class SEWDNoLayerNormConvLayer(nn.Module): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEWD -class SEWDLayerNormConvLayer(nn.Module): +class SEWDLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -292,7 +293,7 @@ class SEWDLayerNormConvLayer(nn.Module): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEWD -class SEWDGroupNormConvLayer(nn.Module): +class SEWDGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -429,13 +430,7 @@ class SEWDFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -930,7 +925,7 @@ class SEWDOutput(nn.Module): return hidden_states -class SEWDLayer(nn.Module): +class SEWDLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = SEWDAttention(config) @@ -1087,25 +1082,14 @@ class SEWDTransformerEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (output_states,) - if self.gradient_checkpointing and self.training: - output_states = self._gradient_checkpointing_func( - layer_module.__call__, - next_kv, - attention_mask, - query_states, - relative_pos, - rel_embeddings, - output_attentions, - ) - else: - output_states = layer_module( - next_kv, - attention_mask, - query_states=query_states, - relative_pos=relative_pos, - rel_embeddings=rel_embeddings, - output_attentions=output_attentions, - ) + output_states = layer_module( + next_kv, + attention_mask, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + output_attentions=output_attentions, + ) if output_attentions: output_states, att_m = output_states diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 383450aae1f..3cf16a992c9 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -30,6 +30,7 @@ from ...cache_utils import DynamicCache from ...generation import GenerationMixin from 
...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -239,7 +240,7 @@ class SmolVLMVisionMLP(nn.Module): return hidden_states -class SmolVLMEncoderLayer(nn.Module): +class SmolVLMEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SmolVLMVisionConfig): super().__init__() self.embed_dim = config.hidden_size @@ -346,19 +347,11 @@ class SmolVLMEncoder(nn.Module): for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 375392077c6..aaff8d90fec 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -328,7 +329,7 @@ class Speech2TextAttention(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT -class Speech2TextEncoderLayer(nn.Module): +class Speech2TextEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model @@ -398,7 +399,7 @@ class Speech2TextEncoderLayer(nn.Module): # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT # TODO: change copy when applying cache class -class Speech2TextDecoderLayer(nn.Module): +class Speech2TextDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model @@ -693,21 +694,12 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -941,33 +933,17 @@ class 
Speech2TextDecoder(Speech2TextPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index c63426468d2..9dfb2653828 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -28,6 +28,7 @@ from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -207,7 +208,7 @@ def _compute_mask_indices( # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5 -class SpeechT5NoLayerNormConvLayer(nn.Module): +class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -229,7 +230,7 @@ class SpeechT5NoLayerNormConvLayer(nn.Module): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5 -class SpeechT5LayerNormConvLayer(nn.Module): +class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -257,7 +258,7 @@ class SpeechT5LayerNormConvLayer(nn.Module): # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5 -class SpeechT5GroupNormConvLayer(nn.Module): +class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -487,13 +488,7 @@ class SpeechT5FeatureEncoder(nn.Module): hidden_states.requires_grad = 
True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -1032,7 +1027,7 @@ class SpeechT5FeedForward(nn.Module): return hidden_states -class SpeechT5EncoderLayer(nn.Module): +class SpeechT5EncoderLayer(GradientCheckpointingLayer): def __init__(self, config: SpeechT5Config): super().__init__() self.attention = SpeechT5Attention( @@ -1093,7 +1088,7 @@ class SpeechT5EncoderLayer(nn.Module): return outputs -class SpeechT5DecoderLayer(nn.Module): +class SpeechT5DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: SpeechT5Config): super().__init__() self.self_attn = SpeechT5Attention( @@ -1338,23 +1333,13 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - position_bias, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -1636,33 +1621,17 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 06d2917b6a2..1d65ec5b954 100755 --- 
a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -330,7 +331,7 @@ class SplinterOutput(nn.Module): # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Splinter -class SplinterLayer(nn.Module): +class SplinterLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -456,27 +457,15 @@ class SplinterEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index df4e41bcd21..0dc1d00890e 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -519,7 +520,7 @@ ATTENTION_CLASSES = { } -class StableLmDecoderLayer(nn.Module): +class StableLmDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: StableLmConfig, layer_idx: int): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -744,29 +745,16 @@ class StableLmModel(StableLmPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - 
position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index a8c29e84785..c62c2e4fc95 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer @@ -739,7 +740,7 @@ class SwinLayer(nn.Module): return layer_outputs -class SwinStage(nn.Module): +class SwinStage(GradientCheckpointingLayer): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config @@ -848,19 +849,9 @@ class SwinEncoder(nn.Module): for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - always_partition, - ) - else: - layer_outputs = layer_module( - hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition - ) + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index c63579a014f..ae6e0a6e795 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -24,6 +24,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer @@ -592,7 +593,7 @@ class Swin2SRLayer(nn.Module): return layer_outputs -class Swin2SRStage(nn.Module): +class Swin2SRStage(GradientCheckpointingLayer): """ This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation. 
""" @@ -705,12 +706,7 @@ class Swin2SREncoder(nn.Module): for i, stage_module in enumerate(self.stages): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions - ) - else: - layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] output_dimensions = layer_outputs[1] diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 050e8d3fd27..67657f0111a 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import Tensor, nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer @@ -787,7 +788,7 @@ class Swinv2Layer(nn.Module): return layer_outputs -class Swinv2Stage(nn.Module): +class Swinv2Stage(GradientCheckpointingLayer): def __init__( self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0 ): @@ -902,17 +903,12 @@ class Swinv2Encoder(nn.Module): for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, hidden_states, input_dimensions, layer_head_mask - ) - else: - layer_outputs = layer_module( - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0] hidden_states_before_downsampling = layer_outputs[1] diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index cf613ee5b82..b0273c8a4a3 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -661,7 +662,7 @@ class SwitchTransformersLayerCrossAttention(nn.Module): return outputs -class SwitchTransformersBlock(nn.Module): +class SwitchTransformersBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1024,41 +1025,22 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - 
layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - output_router_logits, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) router_probs = layer_outputs[-1] layer_outputs = layer_outputs[:-1] diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b099d6cea02..2a1a84b8152 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -645,7 +646,7 @@ class T5LayerCrossAttention(nn.Module): return outputs -class T5Block(nn.Module): +class T5Block(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1101,39 +1102,21 @@ class T5Stack(T5PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - position_bias, - encoder_hidden_states, - encoder_extended_attention_mask, - encoder_decoder_position_bias, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - return_dict, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - 
return_dict=return_dict, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index 4938ea378df..d55a1be0c55 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -23,6 +23,7 @@ from torch import Tensor, nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -618,7 +619,7 @@ class TableTransformerEncoderLayer(nn.Module): return outputs -class TableTransformerDecoderLayer(nn.Module): +class TableTransformerDecoderLayer(GradientCheckpointingLayer): # Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer.__init__ with Detr->TableTransformer def __init__(self, config: TableTransformerConfig): super().__init__() @@ -989,25 +990,15 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - combined_attention_mask, - encoder_hidden_states, - encoder_attention_mask, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=combined_attention_mask, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = decoder_layer( + hidden_states, + combined_attention_mask, + object_queries, + query_position_embeddings, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index f6463d95a24..c2660b6895a 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, 
find_pruneable_heads_and_indices, prune_linear_layer @@ -475,7 +476,7 @@ class TapasOutput(nn.Module): return hidden_states -class TapasLayer(nn.Module): +class TapasLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -591,27 +592,15 @@ class TapasEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_values, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_values, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 8672900e76a..778a0485b4e 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -30,6 +30,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -435,7 +436,7 @@ class TimeSeriesTransformerAttention(nn.Module): # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER -class TimeSeriesTransformerEncoderLayer(nn.Module): +class TimeSeriesTransformerEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -507,7 +508,7 @@ class TimeSeriesTransformerEncoderLayer(nn.Module): # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER -class TimeSeriesTransformerDecoderLayer(nn.Module): +class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -857,21 +858,12 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + 
layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -1066,35 +1058,18 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 00592039a92..191a65f9b13 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -288,7 +289,7 @@ class TimesformerOutput(nn.Module): # Adapted from https://github.com/facebookresearch/TimeSformer/blob/a5ef29a7b7264baff199a30b3306ac27de901133/timesformer/models/vit.py#L89 -class TimesformerLayer(nn.Module): +class TimesformerLayer(GradientCheckpointingLayer): def __init__(self, config: TimesformerConfig, layer_index: int) -> None: super().__init__() @@ -432,14 +433,7 @@ class TimesformerEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, output_attentions) + layer_outputs = layer_module(hidden_states, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 11b5b4a415f..fdc0cae068a 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ 
b/src/transformers/models/trocr/modeling_trocr.py @@ -28,6 +28,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -288,7 +289,7 @@ class TrOCRAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -class TrOCRDecoderLayer(nn.Module): +class TrOCRDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: TrOCRConfig): super().__init__() self.embed_dim = config.hidden_size @@ -643,33 +644,17 @@ class TrOCRDecoder(TrOCRPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/tvp/modeling_tvp.py b/src/transformers/models/tvp/modeling_tvp.py index 01932573f01..cd6e88df846 100644 --- a/src/transformers/models/tvp/modeling_tvp.py +++ b/src/transformers/models/tvp/modeling_tvp.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import prune_linear_layer @@ -455,7 +456,7 @@ class TvpOutputLayer(nn.Module): return hidden_states -class TvpEncodeLayer(nn.Module): +class TvpEncodeLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = TvpAttention(config) @@ -511,16 +512,7 @@ class TvpEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - (head_mask[i] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], 
output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index ba3c0080a27..c78ee0dc5b8 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -38,6 +38,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -743,7 +744,7 @@ class UdopLayerCrossAttention(nn.Module): # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop -class UdopBlock(nn.Module): +class UdopBlock(GradientCheckpointingLayer): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -1295,11 +1296,11 @@ class UdopStack(UdopPreTrainedModel): layer_outputs = layer_module( hidden_states, - attention_mask=causal_mask, - position_bias=position_bias, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - encoder_decoder_position_bias=encoder_decoder_position_bias, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=head_mask[i], past_key_value=past_key_values, use_cache=use_cache, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index d5bb6718baf..2b1f650c678 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -406,7 +407,7 @@ class UMT5LayerCrossAttention(nn.Module): return outputs -class UMT5Block(nn.Module): +class UMT5Block(GradientCheckpointingLayer): def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder @@ -765,35 +766,20 @@ class UMT5Stack(UMT5PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.forward, - hidden_states, - causal_mask, - encoder_hidden_states, - encoder_extended_attention_mask, - layer_head_mask, - cross_attn_layer_head_mask, - None, # past_key_value is always None with gradient checkpointing - use_cache, - output_attentions, - cache_position, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - layer_head_mask=layer_head_mask, - cross_attn_layer_head_mask=cross_attn_layer_head_mask, - 
past_key_value=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - cache_position=cache_position, - ) + layer_outputs = layer_module( + hidden_states, + causal_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_extended_attention_mask, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) - hidden_states = layer_outputs[0] + hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[1] diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 02fa44891f5..af8b151ac6b 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -34,6 +34,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -146,7 +147,7 @@ class UniSpeechPositionalConvEmbedding(nn.Module): return hidden_states -class UniSpeechNoLayerNormConvLayer(nn.Module): +class UniSpeechNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -167,7 +168,7 @@ class UniSpeechNoLayerNormConvLayer(nn.Module): return hidden_states -class UniSpeechLayerNormConvLayer(nn.Module): +class UniSpeechLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -194,7 +195,7 @@ class UniSpeechLayerNormConvLayer(nn.Module): return hidden_states -class UniSpeechGroupNormConvLayer(nn.Module): +class UniSpeechGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -254,13 +255,7 @@ class UniSpeechFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -456,7 +451,7 @@ class UniSpeechFeedForward(nn.Module): return hidden_states -class UniSpeechEncoderLayer(nn.Module): +class UniSpeechEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechAttention( @@ -540,17 +535,9 @@ class UniSpeechEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = 
layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -618,7 +605,7 @@ class UniSpeechAttnAdapterLayer(nn.Module): return hidden_states -class UniSpeechEncoderLayerStableLayerNorm(nn.Module): +class UniSpeechEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechAttention( @@ -714,17 +701,9 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 1ecb418c140..c6cc7561aa3 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -34,6 +34,7 @@ from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -149,7 +150,7 @@ class UniSpeechSatPositionalConvEmbedding(nn.Module): return hidden_states -class UniSpeechSatNoLayerNormConvLayer(nn.Module): +class UniSpeechSatNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -170,7 +171,7 @@ class UniSpeechSatNoLayerNormConvLayer(nn.Module): return hidden_states -class UniSpeechSatLayerNormConvLayer(nn.Module): +class UniSpeechSatLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -197,7 +198,7 @@ class UniSpeechSatLayerNormConvLayer(nn.Module): return hidden_states -class UniSpeechSatGroupNormConvLayer(nn.Module): +class UniSpeechSatGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -257,13 +258,7 @@ class UniSpeechSatFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -459,7 +454,7 @@ class 
UniSpeechSatFeedForward(nn.Module): return hidden_states -class UniSpeechSatEncoderLayer(nn.Module): +class UniSpeechSatEncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechSatAttention( @@ -543,17 +538,9 @@ class UniSpeechSatEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -621,7 +608,7 @@ class UniSpeechSatAttnAdapterLayer(nn.Module): return hidden_states -class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): +class UniSpeechSatEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = UniSpeechSatAttention( @@ -717,17 +704,9 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index c418a3b49c5..a8278f0b892 100755 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -26,6 +26,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -389,7 +390,7 @@ class VideoMAEOutput(nn.Module): # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE -class VideoMAELayer(nn.Module): +class VideoMAELayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: VideoMAEConfig) -> None: @@ -456,15 +457,7 @@ class VideoMAEEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, 
output_attentions) hidden_states = layer_outputs[0] @@ -698,15 +691,7 @@ class VideoMAEDecoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - None, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 6ce00d9f397..d42ac12605c 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -456,7 +457,7 @@ class ViltOutput(nn.Module): return hidden_states -class ViltLayer(nn.Module): +class ViltLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config): @@ -519,16 +520,7 @@ class ViltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index ee228e5f428..34adc5dd74d 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -330,7 +331,7 @@ class VisualBertOutput(nn.Module): return hidden_states -class VisualBertLayer(nn.Module): +class VisualBertLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -394,16 +395,7 @@ class VisualBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 6e85320a0fa..dbad9ef41f3 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ 
b/src/transformers/models/vit/modeling_vit.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -346,7 +347,7 @@ class ViTOutput(nn.Module): return hidden_states -class ViTLayer(nn.Module): +class ViTLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTConfig) -> None: @@ -412,15 +413,7 @@ class ViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 61cf29f8569..32b7151169a 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -25,6 +25,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -531,7 +532,7 @@ class ViTMAEOutput(nn.Module): # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE -class ViTMAELayer(nn.Module): +class ViTMAELayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTMAEConfig) -> None: @@ -598,15 +599,7 @@ class ViTMAEEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] @@ -864,15 +857,7 @@ class ViTMAEDecoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - None, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) + layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index 25efc1ac4de..11155d2d081 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -23,6 +23,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN 
+from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -349,7 +350,7 @@ class ViTMSNOutput(nn.Module): # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN -class ViTMSNLayer(nn.Module): +class ViTMSNLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: ViTMSNConfig) -> None: @@ -416,15 +417,7 @@ class ViTMSNEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index e13e36d08e2..b74bc1008f7 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -439,7 +440,7 @@ def window_unpartition(windows, window_size, pad_height_width, height_width): return hidden_state -class VitDetLayer(nn.Module): +class VitDetLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the original implementation.""" def __init__( @@ -560,15 +561,7 @@ class VitDetEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py index 7579ecb9fbb..fb22d215996 100644 --- a/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +++ b/src/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -27,6 +27,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BackboneOutput, BaseModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -302,7 +303,7 @@ class VitPoseBackboneMLP(nn.Module): return hidden_state -class VitPoseBackboneLayer(nn.Module): +class VitPoseBackboneLayer(GradientCheckpointingLayer): def __init__(self, config: VitPoseBackboneConfig) -> None: super().__init__() self.num_experts = config.num_experts @@ -377,16 +378,7 @@ 
class VitPoseBackboneEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - dataset_index, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index 65b77a4ccf2..e202f98070c 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -1067,7 +1068,7 @@ class VitsFeedForward(nn.Module): return hidden_states -class VitsEncoderLayer(nn.Module): +class VitsEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: VitsConfig): super().__init__() self.attention = VitsAttention(config) @@ -1145,21 +1146,12 @@ class VitsEncoder(nn.Module): skip_the_layer = self.training and (dropout_probability < self.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - padding_mask, - attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask=attention_mask, - padding_mask=padding_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + padding_mask=padding_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 0617e20de3b..7011552db82 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -22,6 +22,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -342,7 +343,7 @@ class VivitOutput(nn.Module): return hidden_states -class VivitLayer(nn.Module): +class VivitLayer(GradientCheckpointingLayer): """This corresponds to the EncoderBlock class in the scenic/vivit implementation.""" def __init__(self, config): @@ -405,15 +406,7 @@ class VivitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - 
output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 6057e0c9fb2..153f067782c 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -271,7 +272,7 @@ def _sample_negative_indices( return sampled_negative_indices -class Wav2Vec2NoLayerNormConvLayer(nn.Module): +class Wav2Vec2NoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -292,7 +293,7 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Module): return hidden_states -class Wav2Vec2LayerNormConvLayer(nn.Module): +class Wav2Vec2LayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -319,7 +320,7 @@ class Wav2Vec2LayerNormConvLayer(nn.Module): return hidden_states -class Wav2Vec2GroupNormConvLayer(nn.Module): +class Wav2Vec2GroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -434,13 +435,7 @@ class Wav2Vec2FeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -648,7 +643,7 @@ class Wav2Vec2FeedForward(nn.Module): return hidden_states -class Wav2Vec2EncoderLayer(nn.Module): +class Wav2Vec2EncoderLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = Wav2Vec2Attention( @@ -684,7 +679,7 @@ class Wav2Vec2EncoderLayer(nn.Module): return outputs -class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): +class Wav2Vec2EncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.attention = Wav2Vec2Attention( @@ -778,17 +773,9 @@ class Wav2Vec2Encoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: @@ -882,17 +869,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): if not 
skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py index f938fa20bfd..e41d7f32ffc 100644 --- a/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -17,6 +17,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -394,7 +395,7 @@ class Wav2Vec2BertSelfAttention(nn.Module): return scores -class Wav2Vec2BertEncoderLayer(nn.Module): +class Wav2Vec2BertEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -520,23 +521,13 @@ class Wav2Vec2BertEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py index 3427a01808a..d0f375332b3 100644 --- a/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +++ b/src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py @@ -9,6 +9,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -292,7 +293,7 @@ class Wav2Vec2BertSelfAttention(Wav2Vec2ConformerSelfAttention, nn.Module): return hidden_states, probs -class Wav2Vec2BertEncoderLayer(nn.Module): +class Wav2Vec2BertEncoderLayer(GradientCheckpointingLayer): 
"""Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -418,23 +419,13 @@ class Wav2Vec2BertEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index eb28f7f9554..d5987ca172b 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -17,6 +17,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -216,7 +217,7 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module): return relative_position_embeddings -class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): +class Wav2Vec2ConformerNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -237,7 +238,7 @@ class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): return hidden_states -class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): +class Wav2Vec2ConformerLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -264,7 +265,7 @@ class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): return hidden_states -class Wav2Vec2ConformerGroupNormConvLayer(nn.Module): +class Wav2Vec2ConformerGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -324,13 +325,7 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: - hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states @@ -582,7 +577,7 @@ class Wav2Vec2ConformerSelfAttention(nn.Module): return scores -class Wav2Vec2ConformerEncoderLayer(nn.Module): +class Wav2Vec2ConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on 
https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -709,21 +704,12 @@ class Wav2Vec2ConformerEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py index fc3444a545f..3436563c0db 100644 --- a/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py @@ -8,6 +8,7 @@ from torch import nn from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging @@ -384,7 +385,7 @@ class Wav2Vec2ConformerSelfAttention(nn.Module): return scores -class Wav2Vec2ConformerEncoderLayer(nn.Module): +class Wav2Vec2ConformerEncoderLayer(GradientCheckpointingLayer): """Conformer block based on https://huggingface.co/papers/2005.08100.""" def __init__(self, config): @@ -511,21 +512,12 @@ class Wav2Vec2ConformerEncoder(nn.Module): skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if skip_the_layer: diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index bb9c15002b2..5904f05dcbf 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -17,6 +17,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -294,7 
+295,7 @@ class WavLMFeedForward(nn.Module): return hidden_states -class WavLMEncoderLayer(nn.Module): +class WavLMEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -335,7 +336,7 @@ class WavLMEncoderLayer(nn.Module): return outputs -class WavLMEncoderLayerStableLayerNorm(nn.Module): +class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -418,22 +419,13 @@ class WavLMEncoder(nn.Module): skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - index=i, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=i, + ) hidden_states, position_bias = layer_outputs[:2] @@ -504,21 +496,12 @@ class WavLMEncoderStableLayerNorm(nn.Module): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - position_bias=position_bias, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + position_bias=position_bias, + ) hidden_states, position_bias = layer_outputs[:2] if skip_the_layer: @@ -696,7 +679,7 @@ class WavLMPreTrainedModel(PreTrainedModel): return attention_mask -class WavLMNoLayerNormConvLayer(nn.Module): +class WavLMNoLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -717,7 +700,7 @@ class WavLMNoLayerNormConvLayer(nn.Module): return hidden_states -class WavLMLayerNormConvLayer(nn.Module): +class WavLMLayerNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -744,7 +727,7 @@ class WavLMLayerNormConvLayer(nn.Module): return hidden_states -class WavLMGroupNormConvLayer(nn.Module): +class WavLMGroupNormConvLayer(GradientCheckpointingLayer): def __init__(self, config, layer_id=0): super().__init__() self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 @@ -801,13 +784,7 @@ class WavLMFeatureEncoder(nn.Module): hidden_states.requires_grad = True for conv_layer in self.conv_layers: - if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - conv_layer.__call__, - hidden_states, - ) - else: 
- hidden_states = conv_layer(hidden_states) + hidden_states = conv_layer(hidden_states) return hidden_states diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index 53d29edc0e7..aac25ff262b 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import logging @@ -205,7 +206,7 @@ class WavLMFeedForward(Wav2Vec2FeedForward): pass -class WavLMEncoderLayer(nn.Module): +class WavLMEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -246,7 +247,7 @@ class WavLMEncoderLayer(nn.Module): return outputs -class WavLMEncoderLayerStableLayerNorm(nn.Module): +class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer): def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True): super().__init__() self.attention = WavLMAttention( @@ -329,22 +330,13 @@ class WavLMEncoder(nn.Module): skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop) if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_bias=position_bias, - output_attentions=output_attentions, - index=i, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + output_attentions=output_attentions, + index=i, + ) hidden_states, position_bias = layer_outputs[:2] @@ -415,21 +407,12 @@ class WavLMEncoderStableLayerNorm(nn.Module): if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - position_bias, - output_attentions, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - output_attentions=output_attentions, - position_bias=position_bias, - ) + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + position_bias=position_bias, + ) hidden_states, position_bias = layer_outputs[:2] if skip_the_layer: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 613f5fb45a7..14cbaafe47d 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -30,6 +30,7 @@ from ...masking_utils import create_causal_mask from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, ) +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutput, 
BaseModelOutputWithPastAndCrossAttentions, @@ -366,7 +367,7 @@ class WhisperAttention(nn.Module): # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER -class WhisperEncoderLayer(nn.Module): +class WhisperEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: WhisperConfig): super().__init__() self.embed_dim = config.d_model @@ -434,7 +435,7 @@ class WhisperEncoderLayer(nn.Module): return outputs -class WhisperDecoderLayer(nn.Module): +class WhisperDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -713,21 +714,12 @@ class WhisperEncoder(WhisperPreTrainedModel): if to_drop: layer_outputs = (None, None) else: - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - None, - (head_mask[idx] if head_mask is not None else None), - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - None, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -964,34 +956,17 @@ class WhisperDecoder(WhisperPreTrainedModel): if dropout_probability < self.layerdrop: continue - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - encoder_hidden_states, - None, # encoder attention mask - head_mask[idx] if head_mask is not None else None, - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, # past_key_value - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - encoder_hidden_states=encoder_hidden_states, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_values if use_cache else None, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + encoder_hidden_states=encoder_hidden_states, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_values if use_cache else None, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 41db6f5ce85..90f49571946 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -24,6 +24,7 @@ from torch import nn from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from 
...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( @@ -338,7 +339,7 @@ class XCLIPMLP(nn.Module): # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->XCLIP -class XCLIPEncoderLayer(nn.Module): +class XCLIPEncoderLayer(GradientCheckpointingLayer): def __init__(self, config: XCLIPConfig): super().__init__() self.embed_dim = config.hidden_size @@ -424,7 +425,7 @@ class XCLIPDropPath(nn.Module): return f"p={self.drop_prob}" -class XCLIPVisionEncoderLayer(nn.Module): +class XCLIPVisionEncoderLayer(GradientCheckpointingLayer): """ This corresponds to the `CrossFramelAttentionBlock` class in the original implementation. """ @@ -625,21 +626,12 @@ class XCLIPEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] @@ -842,21 +834,12 @@ class XCLIPVisionEncoder(nn.Module): for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions, - ) - else: - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - causal_attention_mask, - output_attentions=output_attentions, - ) + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 08520e5d3ab..65e3e628404 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -24,6 +24,7 @@ from torch import nn from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -253,7 +254,7 @@ class XGLMAttention(nn.Module): return attn_output, attn_weights_reshaped, past_key_value -class XGLMDecoderLayer(nn.Module): +class XGLMDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: XGLMConfig): super().__init__() self.embed_dim = config.d_model @@ -547,33 +548,17 @@ class XGLMModel(XGLMPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - encoder_hidden_states, - encoder_attention_mask, - head_mask[idx] if head_mask is not None else None, - 
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, - None, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - layer_head_mask=(head_mask[idx] if head_mask is not None else None), - cross_attn_layer_head_mask=( - cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None - ), - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 323e97a3e53..f66396b135c 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -478,7 +479,7 @@ class XLMRobertaOutput(nn.Module): # Copied from transformers.models.roberta.modeling_roberta.RobertaLayer with Roberta->XLMRoberta -class XLMRobertaLayer(nn.Module): +class XLMRobertaLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -604,27 +605,15 @@ class XLMRobertaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, # as a positional argument for gradient checkpointing + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index a0162b8252b..7cbeaadb184 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ 
b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -469,7 +470,7 @@ class XLMRobertaXLOutput(nn.Module): return hidden_states -class XLMRobertaXLLayer(nn.Module): +class XLMRobertaXLLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -596,27 +597,15 @@ class XLMRobertaXLEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 2d794aa6580..84cf9f6d534 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu from ...generation import GenerationMixin +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -423,7 +424,7 @@ class XmodOutput(nn.Module): return hidden_states -class XmodLayer(nn.Module): +class XmodLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -560,29 +561,16 @@ class XmodEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - lang_ids, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - else: - layer_outputs = layer_module( - hidden_states, - lang_ids, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + lang_ids, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) hidden_states = layer_outputs[0] if use_cache: diff --git a/src/transformers/models/yolos/modeling_yolos.py 
b/src/transformers/models/yolos/modeling_yolos.py index f1a7f4fab82..c61ff8cb85a 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -23,6 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer @@ -403,7 +404,7 @@ class YolosOutput(nn.Module): # Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->Yolos,VIT->YOLOS -class YolosLayer(nn.Module): +class YolosLayer(GradientCheckpointingLayer): """This corresponds to the Block class in the timm implementation.""" def __init__(self, config: YolosConfig) -> None: @@ -492,15 +493,7 @@ class YolosEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index be705885988..da35490c59e 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -24,6 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import ( BaseModelOutputWithCrossAttentions, MaskedLMOutput, @@ -507,7 +508,7 @@ class YosoOutput(nn.Module): return hidden_states -class YosoLayer(nn.Module): +class YosoLayer(GradientCheckpointingLayer): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -559,15 +560,7 @@ class YosoEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) hidden_states = layer_outputs[0] if output_attentions:
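
Every hunk above follows the same recipe: the layer class switches its base from nn.Module to GradientCheckpointingLayer, and the caller's `if self.gradient_checkpointing and self.training` branch collapses into a single plain call. The sketch below illustrates how such a base class can work; it is a minimal, assumed reconstruction, not the contents of src/transformers/modeling_layers.py. The names GradientCheckpointingLayerSketch and ToyEncoderLayer are invented for illustration, and the exact set of cache-related keyword arguments that get neutralized is an assumption. The one mechanism grounded in the hunks themselves is that keyword arguments are bound outside the checkpoint call, so tensors that must flow through torch.utils.checkpoint are passed positionally, which is why several callers carry the comment "# as a positional argument for gradient checkpointing".

# Minimal sketch of a gradient-checkpointing base layer (assumed shape, not the real
# transformers implementation). GradientCheckpointingLayerSketch and ToyEncoderLayer
# are hypothetical names used only for this illustration.
from functools import partial

import torch
import torch.utils.checkpoint
from torch import nn


class GradientCheckpointingLayerSketch(nn.Module):
    """Intercepts __call__ and reroutes the forward pass through a checkpointing function."""

    gradient_checkpointing = False  # assumed to be flipped on by the model's checkpointing setup

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            # Caching does not mix with activation recomputation, so cache-related kwargs
            # are neutralized here (the exact names handled are an assumption).
            if kwargs.get("use_cache"):
                kwargs["use_cache"] = False
            for name in ("past_key_value", "past_key_values", "layer_past"):
                if kwargs.get(name) is not None:
                    kwargs[name] = None
            # Keyword arguments are bound via partial; only positional arguments flow through
            # the checkpoint call, hence tensors such as hidden_states and encoder_hidden_states
            # are passed positionally by the callers in the hunks above.
            return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
        return super().__call__(*args, **kwargs)


class ToyEncoderLayer(GradientCheckpointingLayerSketch):
    """Toy layer: its caller needs no `if self.gradient_checkpointing and self.training` branch."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        hidden_states = torch.nn.functional.gelu(self.linear(hidden_states))
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask
        return hidden_states


if __name__ == "__main__":
    layer = ToyEncoderLayer()
    layer.gradient_checkpointing = True
    layer._gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
    layer.train()
    hidden_states = torch.randn(2, 4, 32, requires_grad=True)
    mask = torch.ones(2, 4, 1)
    out = layer(hidden_states, attention_mask=mask)  # single plain call, checkpointed under the hood
    out.sum().backward()

With a base class along these lines, encoder and decoder loops reduce to one call per layer, as in the hunks above, and checkpointing is toggled on the layer itself instead of being re-implemented in every forward loop.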