From ffff9e70abf90347760191711a1a2a7e04299a10 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Fri, 27 Oct 2023 16:15:22 +0200
Subject: [PATCH] [`core`/ `gradient_checkpointing`] Refactor GC - part 2 (#27073)

* fix

* more fixes

* fix other models

* fix long t5

* use `gradient_checkpointing_func` instead

* fix copies

* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* replace it with `is_gradient_checkpointing_set`

* remove default

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
---
 src/transformers/modeling_utils.py | 28 +++++++++++++++++--
 .../models/align/modeling_align.py | 8 ++----
 .../models/altclip/modeling_altclip.py | 12 ++------
 .../modeling_audio_spectrogram_transformer.py | 8 +-----
 .../models/autoformer/modeling_autoformer.py | 9 ++----
 src/transformers/models/bark/modeling_bark.py | 7 +----
 src/transformers/models/bart/modeling_bart.py | 9 ++----
 src/transformers/models/beit/modeling_beit.py | 7 +----
 src/transformers/models/bert/modeling_bert.py | 7 +----
 .../modeling_bert_generation.py | 7 +----
 .../models/big_bird/modeling_big_bird.py | 7 +----
 .../modeling_bigbird_pegasus.py | 9 ++----
 .../models/biogpt/modeling_biogpt.py | 7 +----
 src/transformers/models/bit/modeling_bit.py | 6 ----
 .../models/blenderbot/modeling_blenderbot.py | 9 ++----
 .../modeling_blenderbot_small.py | 9 ++----
 src/transformers/models/blip/modeling_blip.py | 9 ++----
 .../models/blip/modeling_blip_text.py | 2 +-
 .../models/blip_2/modeling_blip_2.py | 13 ++-------
 .../models/bloom/modeling_bloom.py | 7 +----
 .../bridgetower/modeling_bridgetower.py | 2 +-
 src/transformers/models/bros/modeling_bros.py | 2 +-
 .../models/camembert/modeling_camembert.py | 7 +----
 .../models/canine/modeling_canine.py | 7 +----
 .../chinese_clip/modeling_chinese_clip.py | 9 ++----
 src/transformers/models/clap/modeling_clap.py | 9 ++----
 src/transformers/models/clip/modeling_clip.py | 7 +----
 .../models/clipseg/modeling_clipseg.py | 7 +----
 .../models/codegen/modeling_codegen.py | 7 +----
 .../modeling_conditional_detr.py | 7 +----
 .../models/convbert/modeling_convbert.py | 7 +----
 .../models/convnext/modeling_convnext.py | 6 ----
 .../models/convnextv2/modeling_convnextv2.py | 6 ----
 .../models/cpmant/modeling_cpmant.py | 6 ----
 .../data2vec/modeling_data2vec_audio.py | 9 ++----
 .../models/data2vec/modeling_data2vec_text.py | 7 +----
 .../data2vec/modeling_data2vec_vision.py | 7 +----
 .../models/deberta/modeling_deberta.py | 7 +----
 .../models/deberta_v2/modeling_deberta_v2.py | 7 +----
 .../modeling_decision_transformer.py | 7 +----
 .../modeling_deformable_detr.py | 7 +----
 src/transformers/models/deit/modeling_deit.py | 7 +----
 .../models/deprecated/mctct/modeling_mctct.py | 7 +----
 .../open_llama/modeling_open_llama.py | 7 +----
 .../modeling_trajectory_transformer.py | 7 +----
 .../models/deprecated/van/modeling_van.py | 5 ----
 src/transformers/models/deta/modeling_deta.py | 7 +----
 src/transformers/models/detr/modeling_detr.py | 7 +----
 .../models/dinat/modeling_dinat.py | 3 --
 .../models/dinov2/modeling_dinov2.py | 7 +----
 .../models/distilbert/modeling_distilbert.py | 7 +----
 .../models/donut/modeling_donut_swin.py | 7 +----
 src/transformers/models/dpr/modeling_dpr.py | 7 +----
 src/transformers/models/dpt/modeling_dpt.py | 7 +----
 .../efficientnet/modeling_efficientnet.py | 6 ----
 .../models/electra/modeling_electra.py | 7 +----
 .../models/encodec/modeling_encodec.py | 6 ----
 .../modeling_encoder_decoder.py | 5 ----
 .../models/ernie/modeling_ernie.py | 7 +----
 .../models/ernie_m/modeling_ernie_m.py | 6 ----
 src/transformers/models/esm/modeling_esm.py | 7 +----
 .../models/falcon/modeling_falcon.py | 8 +-----
 .../models/flava/modeling_flava.py | 7 +----
 src/transformers/models/fnet/modeling_fnet.py | 7 +----
 .../models/focalnet/modeling_focalnet.py | 7 +----
 src/transformers/models/fuyu/modeling_fuyu.py | 5 ----
 src/transformers/models/git/modeling_git.py | 9 ++----
 src/transformers/models/gpt2/modeling_gpt2.py | 7 +----
 .../gpt_bigcode/modeling_gpt_bigcode.py | 8 +-----
 .../models/gpt_neo/modeling_gpt_neo.py | 7 +----
 .../models/gpt_neox/modeling_gpt_neox.py | 7 +----
 .../modeling_gpt_neox_japanese.py | 6 ----
 src/transformers/models/gptj/modeling_gptj.py | 7 +----
 .../modeling_gptsan_japanese.py | 5 ----
 .../models/graphormer/modeling_graphormer.py | 6 ----
 .../models/groupvit/modeling_groupvit.py | 7 +----
 .../models/hubert/modeling_hubert.py | 11 ++------
 .../models/idefics/modeling_idefics.py | 9 ++----
 src/transformers/models/idefics/vision.py | 2 +-
 .../models/imagegpt/modeling_imagegpt.py | 7 +----
 .../models/informer/modeling_informer.py | 11 ++------
 .../instructblip/modeling_instructblip.py | 13 ++-------
 .../models/layoutlm/modeling_layoutlm.py | 7 +----
 .../models/layoutlmv2/modeling_layoutlmv2.py | 7 +----
 .../models/layoutlmv3/modeling_layoutlmv3.py | 2 +-
 src/transformers/models/led/modeling_led.py | 9 ++----
 .../models/levit/modeling_levit.py | 6 ----
 src/transformers/models/lilt/modeling_lilt.py | 7 +----
 .../models/llama/modeling_llama.py | 7 +----
 .../models/longformer/modeling_longformer.py | 7 +----
 .../models/longt5/modeling_longt5.py | 12 ++------
 src/transformers/models/luke/modeling_luke.py | 7 +----
 .../models/m2m_100/modeling_m2m_100.py | 9 ++----
 .../models/marian/modeling_marian.py | 9 ++----
 .../models/markuplm/modeling_markuplm.py | 2 +-
 .../mask2former/modeling_mask2former.py | 2 +-
 .../models/maskformer/modeling_maskformer.py | 10 +------
 .../maskformer/modeling_maskformer_swin.py | 7 +----
 .../models/mbart/modeling_mbart.py | 9 ++----
 .../megatron_bert/modeling_megatron_bert.py | 7 +----
 .../models/mgp_str/modeling_mgp_str.py | 5 ----
 .../models/mistral/modeling_mistral.py | 7 +----
 .../models/mobilevit/modeling_mobilevit.py | 7 +----
 .../mobilevitv2/modeling_mobilevitv2.py | 7 +----
 src/transformers/models/mpt/modeling_mpt.py | 7 +----
 src/transformers/models/mra/modeling_mra.py | 7 +----
 src/transformers/models/mt5/modeling_mt5.py | 8 +-----
 .../models/musicgen/modeling_musicgen.py | 13 +--------
 src/transformers/models/mvp/modeling_mvp.py | 9 ++----
 src/transformers/models/nat/modeling_nat.py | 4 +--
 .../models/nezha/modeling_nezha.py | 7 +----
 .../models/nllb_moe/modeling_nllb_moe.py | 10 ++-----
 .../nystromformer/modeling_nystromformer.py | 7 +----
 .../models/oneformer/modeling_oneformer.py | 2 +-
 src/transformers/models/opt/modeling_opt.py | 7 +----
 .../models/owlv2/modeling_owlv2.py | 7 +----
 .../models/owlvit/modeling_owlvit.py | 7 +----
 .../models/pegasus/modeling_pegasus.py | 9 ++----
 .../models/pegasus_x/modeling_pegasus_x.py | 9 ++----
 .../models/persimmon/modeling_persimmon.py | 7 +----
 .../models/pix2struct/modeling_pix2struct.py | 15 ++--------
 .../models/plbart/modeling_plbart.py | 9 ++----
 .../models/poolformer/modeling_poolformer.py | 6 ----
 .../models/pop2piano/modeling_pop2piano.py | 8 +-----
 .../models/prophetnet/modeling_prophetnet.py | 9 ++----
 src/transformers/models/pvt/modeling_pvt.py | 5 ----
 .../models/qdqbert/modeling_qdqbert.py | 7 +----
 .../models/realm/modeling_realm.py | 2 +-
 .../models/regnet/modeling_regnet.py | 6 ----
 .../models/rembert/modeling_rembert.py | 7 +----
 .../models/resnet/modeling_resnet.py | 6 ----
 .../models/roberta/modeling_roberta.py | 7 +----
 .../modeling_roberta_prelayernorm.py | 7 +----
 .../models/roc_bert/modeling_roc_bert.py | 7 +----
 .../models/roformer/modeling_roformer.py | 7 +----
 src/transformers/models/rwkv/modeling_rwkv.py | 7 +----
 src/transformers/models/sam/modeling_sam.py | 2 +-
 .../seamless_m4t/modeling_seamless_m4t.py | 9 ++----
 src/transformers/models/sew/modeling_sew.py | 9 ++----
 .../models/sew_d/modeling_sew_d.py | 9 ++----
 .../modeling_speech_encoder_decoder.py | 5 ----
 .../speech_to_text/modeling_speech_to_text.py | 9 ++----
 .../modeling_speech_to_text_2.py | 7 +----
 .../models/speecht5/modeling_speecht5.py | 11 ++------
 .../models/splinter/modeling_splinter.py | 7 +----
 .../swiftformer/modeling_swiftformer.py | 5 ----
 src/transformers/models/swin/modeling_swin.py | 7 +----
 .../models/swin2sr/modeling_swin2sr.py | 7 +----
 .../models/swinv2/modeling_swinv2.py | 7 +----
 .../modeling_switch_transformers.py | 8 +-----
 src/transformers/models/t5/modeling_t5.py | 8 +-----
 .../modeling_table_transformer.py | 7 +----
 .../models/tapas/modeling_tapas.py | 7 +----
 .../modeling_time_series_transformer.py | 9 ++----
 .../timesformer/modeling_timesformer.py | 7 +----
 .../models/trocr/modeling_trocr.py | 7 +----
 src/transformers/models/tvlt/modeling_tvlt.py | 9 ++----
 src/transformers/models/umt5/modeling_umt5.py | 8 +-----
 .../models/unispeech/modeling_unispeech.py | 11 ++------
 .../unispeech_sat/modeling_unispeech_sat.py | 11 ++------
 .../models/upernet/modeling_upernet.py | 7 -----
 .../models/videomae/modeling_videomae.py | 9 ++----
 src/transformers/models/vilt/modeling_vilt.py | 7 +----
 .../modeling_vision_encoder_decoder.py | 5 ----
 .../visual_bert/modeling_visual_bert.py | 7 +----
 src/transformers/models/vit/modeling_vit.py | 7 +----
 .../models/vit_hybrid/modeling_vit_hybrid.py | 7 +----
 .../models/vit_mae/modeling_vit_mae.py | 9 ++----
 .../models/vit_msn/modeling_vit_msn.py | 7 +----
 .../models/vitdet/modeling_vitdet.py | 7 +----
 .../models/vitmatte/modeling_vitmatte.py | 11 --------
 src/transformers/models/vits/modeling_vits.py | 7 +----
 .../models/vivit/modeling_vivit.py | 7 +----
 .../models/wav2vec2/modeling_wav2vec2.py | 11 ++------
 .../modeling_wav2vec2_conformer.py | 9 ++----
 .../models/wavlm/modeling_wavlm.py | 11 ++------
 .../models/whisper/modeling_whisper.py | 9 ++----
 .../models/x_clip/modeling_x_clip.py | 9 ++----
 src/transformers/models/xglm/modeling_xglm.py | 7 +----
 .../xlm_prophetnet/modeling_xlm_prophetnet.py | 9 ++----
 .../xlm_roberta/modeling_xlm_roberta.py | 7 +----
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 2 +-
 src/transformers/models/xmod/modeling_xmod.py | 8 +-----
 .../models/yolos/modeling_yolos.py | 7 +----
 src/transformers/models/yoso/modeling_yoso.py | 7 +----
 ...ng_{{cookiecutter.lowercase_modelname}}.py | 16 ++---------
 186 files changed, 242 insertions(+), 1145 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6eb14b8270a..0df52e0f45e 100644
--- a/src/transformers/modeling_utils.py
+++
b/src/transformers/modeling_utils.py @@ -1873,7 +1873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs ) - self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func)) + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) if getattr(self, "_hf_peft_config_loaded", False): # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True @@ -1882,6 +1882,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # the gradients to make sure the gradient flows. self.enable_input_require_grads() + def _set_gradient_checkpointing( + self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint + ): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." + ) + def gradient_checkpointing_disable(self): """ Deactivates gradient checkpointing for the current model. @@ -1890,7 +1914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix activations". 
""" if self.supports_gradient_checkpointing: - self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None)) + self._set_gradient_checkpointing(enable=False) if getattr(self, "_hf_peft_config_loaded", False): self.disable_input_require_grads() diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py index 58dc2a89200..f48fcbace12 100644 --- a/src/transformers/models/align/modeling_align.py +++ b/src/transformers/models/align/modeling_align.py @@ -1095,7 +1095,7 @@ class AlignTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -1192,11 +1192,6 @@ class AlignPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @add_start_docstrings( """The text model from ALIGN without any head or projection on top.""", @@ -1331,6 +1326,7 @@ class AlignTextModel(AlignPreTrainedModel): class AlignVisionModel(AlignPreTrainedModel): config_class = AlignVisionConfig main_input_name = "pixel_values" + supports_gradient_checkpointing = False def __init__(self, config: AlignVisionConfig): super().__init__(config) diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index e6229165aac..048c18edcc8 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -646,7 +646,7 @@ class AltRobertaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -955,7 +955,7 @@ class AltCLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1078,14 +1078,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, AltCLIPEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - if isinstance(module, AltRobertaEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING class 
AltCLIPVisionTransformer(nn.Module): diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index a1f85e2a09e..3fddccdea75 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -336,7 +336,7 @@ class ASTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -388,12 +388,6 @@ class ASTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST - def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, ASTEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 40e30023108..12a0951c88f 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -946,11 +946,6 @@ class AutoformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (AutoformerDecoder, AutoformerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - AUTOFORMER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1208,7 +1203,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1420,7 +1415,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 8e5cf0d849d..9b3870e8250 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -317,11 +317,6 @@ class BarkPreTrainedModel(PreTrainedModel): return get_parameter_device(self) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BARK_MODEL_START_DOCSTRING = """ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -642,7 +637,7 @@ class BarkCausalModel(BarkPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 73eca72e5d1..20ada97627d 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -521,11 +521,6 @@ class BartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (BartDecoder, BartEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -855,7 +850,7 @@ class BartEncoder(BartPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1105,7 +1100,7 @@ class BartDecoder(BartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 3ba3d4911b0..abc0d3158f5 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -510,7 +510,7 @@ class BeitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -566,11 +566,6 @@ class BeitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BeitEncoder): - 
module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BEIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 91380e13a05..c6764c771e7 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -593,7 +593,7 @@ class BertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -757,11 +757,6 @@ class BertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class BertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 123cb2212e1..b7250f6f7b9 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -401,7 +401,7 @@ class BertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -602,11 +602,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BERT_GENERATION_START_DOCSTRING = r""" diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 0ba2119e684..4383d210cd8 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1617,7 +1617,7 @@ class BigBirdEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -1779,11 +1779,6 @@ class BigBirdPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BigBirdEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BIG_BIRD_START_DOCSTRING = r""" This model is a PyTorch 
[torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 98ff51032ba..e7841d8f592 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1609,11 +1609,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1944,7 +1939,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -2284,7 +2279,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 2bbdbed348a..157607b73ce 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -376,11 +376,6 @@ class BioGptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BioGptModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BIOGPT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use @@ -591,7 +586,7 @@ class BioGptModel(BioGptPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/bit/modeling_bit.py b/src/transformers/models/bit/modeling_bit.py index d02861d6343..49bc75b5f0a 100644 --- a/src/transformers/models/bit/modeling_bit.py +++ b/src/transformers/models/bit/modeling_bit.py @@ -660,7 +660,6 @@ class BitPreTrainedModel(PreTrainedModel): config_class = BitConfig base_model_prefix = "bit" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): if isinstance(module, nn.Conv2d): @@ -669,11 +668,6 @@ class BitPreTrainedModel(PreTrainedModel): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BitModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 51a947af0a8..4a1a7d07320 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -483,11 +483,6 @@ class BlenderbotPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -778,7 +773,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1027,7 +1022,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 88a9b52de90..5755576b4d6 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -480,11 +480,6 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): - 
module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -776,7 +771,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1024,7 +1019,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index efd986299c2..b6173bcdad1 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -34,7 +34,7 @@ from ...utils import ( replace_return_docstrings, ) from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig -from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel +from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel logger = logging.get_logger(__name__) @@ -461,11 +461,6 @@ class BlipPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (BlipEncoder, BlipTextEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - BLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -623,7 +618,7 @@ class BlipEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index e0aa4e17f14..00c6a85ee61 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -422,7 +422,7 @@ class BlipTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 2f7f00b3dd5..10a37c79b86 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -297,15 +297,6 @@ class Blip2PreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - - # Enable / disable GC for the language model as well - if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): - self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) - BLIP_2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -478,7 +469,7 @@ class Blip2Encoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -943,7 +934,7 @@ class Blip2QFormerEncoder(nn.Module): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 583367c9ab5..f94e371256a 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -496,11 +496,6 @@ class BloomPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): - if isinstance(module, BloomModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @staticmethod def _convert_to_standard_cache( past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int @@ -762,7 +757,7 @@ class BloomModel(BloomPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, alibi, diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 0f272a21e21..89655db7f04 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -804,7 +804,7 @@ class BridgeTowerTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/bros/modeling_bros.py b/src/transformers/models/bros/modeling_bros.py index c10f8350567..d3a17b23c94 100755 --- a/src/transformers/models/bros/modeling_bros.py +++ b/src/transformers/models/bros/modeling_bros.py @@ -651,7 +651,7 @@ class BrosEncoder(nn.Module): "`use_cache=False`..." 
) use_cache = False - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, bbox_pos_emb, diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 2e0a6c12fe6..50fac0efd00 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -524,7 +524,7 @@ class CamembertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -620,11 +620,6 @@ class CamembertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, CamembertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CAMEMBERT_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index adc87591032..ead9619d926 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -795,7 +795,7 @@ class CanineEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -913,11 +913,6 @@ class CaninePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, CanineEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CANINE_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py index ef1c265723b..ec2086bf67c 100644 --- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py +++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py @@ -742,11 +742,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CHINESE_CLIP_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it @@ -910,7 +905,7 @@ class ChineseCLIPTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -1014,7 +1009,7 @@ class ChineseCLIPVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, output_attentions, diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 025b59ae4b9..bea7cf2b93c 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -939,7 +939,7 @@ class ClapAudioEncoder(nn.Module): input_dimensions = self.input_resolutions[i] if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: @@ -1588,7 +1588,7 @@ class ClapTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -1689,11 +1689,6 @@ class ClapPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ClapTextEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class ClapAudioModel(ClapPreTrainedModel): config_class = ClapAudioConfig diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 56f24c157f8..4d2c96ecec4 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -467,11 +467,6 @@ class CLIPPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, CLIPEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -640,7 +635,7 @@ class CLIPEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index 7a0e5292698..a07ceedd726 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -479,11 +479,6 @@ class CLIPSegPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, CLIPSegEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CLIPSEG_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it @@ -649,7 +644,7 @@ class CLIPSegEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 9a5509a9ed8..6fc054254a4 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -339,11 +339,6 @@ class CodeGenPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, CodeGenModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CODEGEN_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use @@ -543,7 +538,7 @@ class CodeGenModel(CodeGenPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 01dbf8ecd59..b964d72704f 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -1171,11 +1171,6 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ConditionalDetrDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CONDITIONAL_DETR_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -1519,7 +1514,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): # apply transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, combined_attention_mask, diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index da577a58961..2a7901f2f35 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -264,11 +264,6 @@ class ConvBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ConvBertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class SeparableConv1D(nn.Module): """This class implements separable convolution, i.e. a depthwise and a pointwise layer""" @@ -633,7 +628,7 @@ class ConvBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/convnext/modeling_convnext.py b/src/transformers/models/convnext/modeling_convnext.py index e11112b5322..a0102b47ce8 100755 --- a/src/transformers/models/convnext/modeling_convnext.py +++ b/src/transformers/models/convnext/modeling_convnext.py @@ -282,7 +282,6 @@ class ConvNextPreTrainedModel(PreTrainedModel): config_class = ConvNextConfig base_model_prefix = "convnext" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -296,11 +295,6 @@ class ConvNextPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ConvNextEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CONVNEXT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/convnextv2/modeling_convnextv2.py b/src/transformers/models/convnextv2/modeling_convnextv2.py index f1ff89bb124..07580731ea1 100644 --- a/src/transformers/models/convnextv2/modeling_convnextv2.py +++ b/src/transformers/models/convnextv2/modeling_convnextv2.py @@ -303,7 +303,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): config_class = ConvNextV2Config base_model_prefix = "convnextv2" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -317,11 +316,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ConvNextV2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CONVNEXTV2_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py index 8a6c744ed69..405d892c70e 100755 --- a/src/transformers/models/cpmant/modeling_cpmant.py +++ b/src/transformers/models/cpmant/modeling_cpmant.py @@ -536,7 +536,6 @@ class CpmAntPreTrainedModel(PreTrainedModel): config_class = CpmAntConfig base_model_prefix = "cpmant" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -556,11 +555,6 @@ class CpmAntPreTrainedModel(PreTrainedModel): elif isinstance(module, CpmAntSegmentPositionEmbedding): module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, CpmAntEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - CPMANT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index a99b6f3a6dc..cf15d8508d5 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -293,7 +293,7 @@ class Data2VecAudioFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -586,7 +586,7 @@ class Data2VecAudioEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -748,11 +748,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DATA2VEC_AUDIO_START_DOCSTRING = r""" Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 507c2fc464d..567cc7b5c34 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -510,7 +510,7 @@ class Data2VecTextEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -608,11 +608,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): if hasattr(module, "weight") and module.weight is not None: module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Data2VecTextEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DATA2VECTEXT_START_DOCSTRING = r""" Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 2742d5ffc37..49f8c411c33 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -522,7 +522,7 @@ class Data2VecVisionEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -579,11 +579,6 @@ class 
Data2VecVisionPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Data2VecVisionEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DATA2VEC_VISION_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 65ec497cecd..b5136bcb88c 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -457,7 +457,7 @@ class DebertaEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( layer_module.__call__, next_kv, attention_mask, @@ -833,11 +833,6 @@ class DebertaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DebertaEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 2245ac549ad..0d3ed94aeab 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -501,7 +501,7 @@ class DebertaV2Encoder(nn.Module): all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - output_states = self.gradient_checkpointing_func( + output_states = self._gradient_checkpointing_func( layer_module.__call__, next_kv, attention_mask, @@ -932,11 +932,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DebertaV2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 19c2731a50a..d07a25c8915 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -469,11 +469,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, 
DecisionTransformerGPT2Model): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): def __init__(self, config): @@ -632,7 +627,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 220fcf0d066..ec1f60343ca 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -1088,11 +1088,6 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DeformableDetrDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DEFORMABLE_DETR_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1384,7 +1379,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 6e97e932b53..b8bd9d6ce62 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -357,7 +357,7 @@ class DeiTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -409,11 +409,6 @@ class DeiTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, DeiTEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DEIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
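The call-site changes above (Data2Vec, DeBERTa, Decision Transformer, Deformable DETR, DeiT, and the rest of this patch) all converge on one pattern: the encoder keeps a boolean `gradient_checkpointing` flag and routes layer calls through the private `_gradient_checkpointing_func` instead of a per-model public attribute. A self-contained sketch of that pattern with toy classes, not the actual transformers code; the `functools.partial(checkpoint, use_reentrant=False)` default is an assumption for illustration:

from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(hidden_states))


class ToyEncoder(nn.Module):
    def __init__(self, dim: int = 16, num_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(dim) for _ in range(num_layers))
        self.gradient_checkpointing = False
        # In transformers this callable is assigned centrally when checkpointing is
        # enabled; a plain torch.utils.checkpoint wrapper stands in for it here.
        self._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Activations inside `layer` are recomputed during backward.
                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


encoder = ToyEncoder()
encoder.train()
encoder.gradient_checkpointing = True
out = encoder(torch.randn(4, 16, requires_grad=True))
out.sum().backward()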
Use it diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 9e7a73c5880..525dfec2ab9 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -504,11 +504,6 @@ class MCTCTPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MCTCTEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MCTCT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use @@ -617,7 +612,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index fb1cc7f0fb8..8d99c79ef8e 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -456,11 +456,6 @@ class OpenLlamaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, OpenLlamaModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - OPEN_LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -666,7 +661,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py index c9f31c71444..40c08e4d1d4 100644 --- a/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py @@ -163,11 +163,6 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): main_input_name = "trajectories" supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TrajectoryTransformerModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) @@ -551,7 +546,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel): all_hidden_states = 
all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, layer_past, diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py index 52c9e124242..e0f88467e1e 100644 --- a/src/transformers/models/deprecated/van/modeling_van.py +++ b/src/transformers/models/deprecated/van/modeling_van.py @@ -387,11 +387,6 @@ class VanPreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, VanModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VAN_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index a6f979eaeea..94e2ea8dcd5 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -979,11 +979,6 @@ class DetaPreTrainedModel(PreTrainedModel): if hasattr(module, "level_embed"): nn.init.normal_(module.level_embed) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DetaDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DETA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1276,7 +1271,7 @@ class DetaDecoder(DetaPreTrainedModel): all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, encoder_hidden_states, diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 1c09e3e3d7b..08e11e06e61 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -927,11 +927,6 @@ class DetrPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DetrDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DETR_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
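Each file above also deletes its `_set_gradient_checkpointing` override, which only repeated an `isinstance` check against that model's own encoder or decoder class. A hedged sketch of the generic replacement idea, walking submodules and configuring anything that declares a `gradient_checkpointing` flag; this illustrates the concept and is not the exact `modeling_utils` implementation:

from functools import partial

from torch import nn
from torch.utils.checkpoint import checkpoint


def set_gradient_checkpointing(model: nn.Module, enable: bool = True, checkpointing_func=None) -> None:
    # Flip the flag on every submodule that supports it; no per-model isinstance checks.
    if checkpointing_func is None:
        checkpointing_func = partial(checkpoint, use_reentrant=False)
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = enable
            module._gradient_checkpointing_func = checkpointing_func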
Check the superclass documentation for the generic methods the @@ -1254,7 +1249,7 @@ class DetrDecoder(DetrPreTrainedModel): continue if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, combined_attention_mask, diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py index eb4d3f2ff29..aae79e0452a 100644 --- a/src/transformers/models/dinat/modeling_dinat.py +++ b/src/transformers/models/dinat/modeling_dinat.py @@ -660,9 +660,6 @@ class DinatPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None: - pass - DINAT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py index 1440b6d615f..1215b23480b 100644 --- a/src/transformers/models/dinov2/modeling_dinov2.py +++ b/src/transformers/models/dinov2/modeling_dinov2.py @@ -447,7 +447,7 @@ class Dinov2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -510,11 +510,6 @@ class Dinov2PreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, Dinov2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DINOV2_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 3768dd6e91c..c66519a7245 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -358,7 +358,7 @@ class Transformer(nn.Module): all_hidden_states = all_hidden_states + (hidden_state,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_state, attn_mask, @@ -424,11 +424,6 @@ class DistilBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Transformer): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DISTILBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/donut/modeling_donut_swin.py b/src/transformers/models/donut/modeling_donut_swin.py index 76d525717f8..4e02c320a72 100644 --- a/src/transformers/models/donut/modeling_donut_swin.py +++ b/src/transformers/models/donut/modeling_donut_swin.py @@ -749,7 +749,7 @@ class DonutSwinEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: @@ -819,11 +819,6 @@ class DonutSwinPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DonutSwinEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SWIN_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index c258343f6cf..cc0d0a1fcb6 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -30,7 +30,7 @@ from ...utils import ( logging, replace_return_docstrings, ) -from ..bert.modeling_bert import BertEncoder, BertModel +from ..bert.modeling_bert import BertModel from .configuration_dpr import DPRConfig @@ -164,11 +164,6 @@ class DPRPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class DPREncoder(DPRPreTrainedModel): base_model_prefix = "bert_model" diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 2621fa33801..63796d0168f 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -528,7 +528,7 @@ class DPTViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -812,11 +812,6 @@ class DPTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, DPTViTEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - DPT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/efficientnet/modeling_efficientnet.py b/src/transformers/models/efficientnet/modeling_efficientnet.py index d1b2c994034..2513f9b2fde 100644 --- a/src/transformers/models/efficientnet/modeling_efficientnet.py +++ b/src/transformers/models/efficientnet/modeling_efficientnet.py @@ -486,7 +486,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel): config_class = EfficientNetConfig base_model_prefix = "efficientnet" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -500,11 +499,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, EfficientNetBlock): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @add_start_docstrings( "The bare EfficientNet model outputting raw features without any specific head on top.", diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index fde5632c09c..a30d0a69642 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -571,7 +571,7 @@ class ElectraEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -687,11 +687,6 @@ class ElectraPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ElectraEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class ElectraForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/encodec/modeling_encodec.py b/src/transformers/models/encodec/modeling_encodec.py index 28c20da3d5e..441f4a27d83 100644 --- a/src/transformers/models/encodec/modeling_encodec.py +++ b/src/transformers/models/encodec/modeling_encodec.py @@ -446,7 +446,6 @@ class EncodecPreTrainedModel(PreTrainedModel): config_class = EncodecConfig base_model_prefix = "encodec" main_input_name = "input_values" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -473,11 +472,6 @@ class EncodecPreTrainedModel(PreTrainedModel): elif "bias" in name: nn.init.constant_(param, 0.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (EncodecEncoder, EncodecDecoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ENCODEC_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index a13fd19a900..ff5a56749fa 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -265,11 +265,6 @@ class EncoderDecoderModel(PreTrainedModel): self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 330cb503316..291ab6c54d1 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -506,7 +506,7 @@ class ErnieEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -675,11 +675,6 @@ class ErniePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ErnieEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie diff --git a/src/transformers/models/ernie_m/modeling_ernie_m.py b/src/transformers/models/ernie_m/modeling_ernie_m.py index b26ee0fcafd..c1be3cfba14 100755 --- a/src/transformers/models/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/ernie_m/modeling_ernie_m.py @@ -411,7 +411,6 @@ class ErnieMPreTrainedModel(PreTrainedModel): config_class = ErnieMConfig base_model_prefix = "ernie_m" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -429,11 +428,6 @@ class ErnieMPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ErnieMEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ERNIE_M_START_DOCSTRING = r""" diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index 86bd20a4648..b7d0253fc4c 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -605,7 +605,7 @@ class EsmEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, 
attention_mask, @@ -705,11 +705,6 @@ class EsmPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, EsmEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ESM_START_DOCSTRING = r""" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 49307cf52ec..6f4f0838e6c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1101,12 +1101,6 @@ class FalconPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel - def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): - if isinstance(module, FalconModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @staticmethod def _convert_cache_to_standard_format( past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int @@ -1282,7 +1276,7 @@ class FalconModel(FalconPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, alibi, diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 1fbf49f9e12..de5ec177ae4 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -663,7 +663,7 @@ class FlavaEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -873,11 +873,6 @@ class FlavaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, FlavaEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @add_start_docstrings( "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.", diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index b84761536ba..2784880f3c7 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -292,7 +292,7 @@ class FNetEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) + layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states) else: layer_outputs = layer_module(hidden_states) @@ -424,11 +424,6 @@ class FNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def 
_set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, FNetEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class FNetForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/focalnet/modeling_focalnet.py b/src/transformers/models/focalnet/modeling_focalnet.py index 87ec9816962..b0033c85598 100644 --- a/src/transformers/models/focalnet/modeling_focalnet.py +++ b/src/transformers/models/focalnet/modeling_focalnet.py @@ -586,7 +586,7 @@ class FocalNetEncoder(nn.Module): for i, stage_module in enumerate(self.stages): if self.gradient_checkpointing and self.training: - stage_outputs = self.gradient_checkpointing_func( + stage_outputs = self._gradient_checkpointing_func( stage_module.__call__, hidden_states, input_dimensions, @@ -652,11 +652,6 @@ class FocalNetPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, FocalNetEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - FOCALNET_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 37f9890ee3d..89127843bef 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -70,11 +70,6 @@ class FuyuPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, FuyuForCausalLM): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - FUYU_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 293b9c789d5..120576bfab9 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -452,7 +452,7 @@ class GitEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -528,11 +528,6 @@ class GitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (GitEncoder, GitVisionEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - GIT_START_DOCSTRING = r""" @@ -874,7 +869,7 @@ class GitVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py 
b/src/transformers/models/gpt2/modeling_gpt2.py index 24826a76bc0..bbae7a6c555 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -480,11 +480,6 @@ class GPT2PreTrainedModel(PreTrainedModel): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GPT2Model): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class GPT2DoubleHeadsModelOutput(ModelOutput): @@ -878,7 +873,7 @@ class GPT2Model(GPT2PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 37c51b40c9a..f8e52b6510a 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -404,12 +404,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GPTBigCodeModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - GPT_BIGCODE_START_DOCSTRING = r""" @@ -651,7 +645,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index ed1e62bf175..90ca265a822 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -384,11 +384,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GPTNeoModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - GPT_NEO_START_DOCSTRING = r""" @@ -605,7 +600,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 6025d827798..ac59011b281 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -78,11 +78,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): 
module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GPTNeoXModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class GPTNeoXAttention(nn.Module): def __init__(self, config): @@ -642,7 +637,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index c1c5527a465..4f0841e3c34 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -48,7 +48,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): config_class = GPTNeoXJapaneseConfig base_model_prefix = "gpt_neox_japanese" - supports_gradient_checkpointing = True _no_split_modules = ["GPTNeoXJapaneseLayer"] _skip_keys_device_placement = "past_key_values" @@ -66,11 +65,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GPTNeoXJapaneseModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class GPTNeoXJapaneseAttention(nn.Module): def __init__(self, config, use_bias=False): diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 2910f9535f6..45c7114943b 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -363,11 +363,6 @@ class GPTJPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GPTJModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - GPTJ_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use @@ -670,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 84d956c9f57..1232d24730c 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -759,11 +759,6 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (GPTSanJapaneseAttention,)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id diff --git a/src/transformers/models/graphormer/modeling_graphormer.py b/src/transformers/models/graphormer/modeling_graphormer.py index 68ed6d265e7..ec56d8eda0d 100755 --- a/src/transformers/models/graphormer/modeling_graphormer.py +++ b/src/transformers/models/graphormer/modeling_graphormer.py @@ -712,7 +712,6 @@ class GraphormerPreTrainedModel(PreTrainedModel): config_class = GraphormerConfig base_model_prefix = "graphormer" - supports_gradient_checkpointing = True main_input_name_nodes = "input_nodes" main_input_name_edges = "input_edges" @@ -772,11 +771,6 @@ class GraphormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, GraphormerModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class GraphormerModel(GraphormerPreTrainedModel): """The Graphormer model is a graph-encoder model. diff --git a/src/transformers/models/groupvit/modeling_groupvit.py b/src/transformers/models/groupvit/modeling_groupvit.py index a9de6714384..680fe78f5c0 100644 --- a/src/transformers/models/groupvit/modeling_groupvit.py +++ b/src/transformers/models/groupvit/modeling_groupvit.py @@ -804,11 +804,6 @@ class GroupViTPreTrainedModel(PreTrainedModel): nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc2.weight, std=in_proj_std) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - GROUPVIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it @@ -1031,7 +1026,7 @@ class GroupViTTextEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 732e6be2f8d..ddb80f56723 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -346,7 +346,7 @@ class HubertFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -724,7 +724,7 @@ class HubertEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -808,7 +808,7 @@ class HubertEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -876,11 +876,6 @@ class HubertPreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 28841903a1a..cb8945ea54d 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -40,7 +40,7 @@ from ...utils import ( ) from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler -from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer +from .vision import IdeficsVisionTransformer logger = logging.get_logger(__name__) @@ -978,11 +978,6 @@ class IdeficsPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -1339,7 +1334,7 @@ class IdeficsModel(IdeficsPreTrainedModel): ) use_cache = False - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = 
self._gradient_checkpointing_func( vblock, decoder_layer, hidden_states, diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 24dc3e9396a..04b2894c4af 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -401,7 +401,7 @@ class IdeficsVisionEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index a365731ed53..33f7ee99c4f 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -525,11 +525,6 @@ class ImageGPTPreTrainedModel(PreTrainedModel): # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ImageGPTModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - IMAGEGPT_START_DOCSTRING = r""" @@ -817,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, None, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 53518760cc0..5959c8538d3 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -924,11 +924,6 @@ class InformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (InformerDecoder, InformerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - INFORMER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -1216,7 +1211,7 @@ class InformerEncoder(InformerPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1224,7 +1219,7 @@ class InformerEncoder(InformerPreTrainedModel): output_attentions, ) if conv_layer is not None: - output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) + output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) layer_outputs = (output,) + layer_outputs[1:] else: layer_outputs = encoder_layer( @@ -1433,7 +1428,7 @@ class InformerDecoder(InformerPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index d4cb7a1fa00..4e0173bd997 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -304,15 +304,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - - # Enable / disable GC for the language model as well - if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"): - self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func) - INSTRUCTBLIP_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -467,7 +458,7 @@ class InstructBlipEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -938,7 +929,7 @@ class InstructBlipQFormerEncoder(nn.Module): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index ce6d4302bcc..c2ecede73d3 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -487,7 +487,7 @@ class LayoutLMEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -633,11 +633,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LayoutLMEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LAYOUTLM_START_DOCSTRING = r""" The LayoutLM model was proposed in [LayoutLM: Pre-training of Text and Layout for Document Image diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 8f6260fdda4..4a85923cb9b 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -439,7 +439,7 @@ class LayoutLMv2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -508,11 +508,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LayoutLMv2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def my_convert_sync_batchnorm(module, process_group=None): # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d` diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py index e387707e52d..fe1cbcc2c5c 100644 --- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py +++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py @@ -657,7 +657,7 @@ class LayoutLMv3Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 61bbd4156b4..e757ef6a7b9 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1155,11 +1155,6 @@ class LEDPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def 
_set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (LEDDecoder, LEDEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1877,7 +1872,7 @@ class LEDEncoder(LEDPreTrainedModel): layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -2138,7 +2133,7 @@ class LEDDecoder(LEDPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, combined_attention_mask, diff --git a/src/transformers/models/levit/modeling_levit.py b/src/transformers/models/levit/modeling_levit.py index 5acaaeba900..38a9ee1abc5 100644 --- a/src/transformers/models/levit/modeling_levit.py +++ b/src/transformers/models/levit/modeling_levit.py @@ -493,7 +493,6 @@ class LevitPreTrainedModel(PreTrainedModel): config_class = LevitConfig base_model_prefix = "levit" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -507,11 +506,6 @@ class LevitPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LevitModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LEVIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/lilt/modeling_lilt.py b/src/transformers/models/lilt/modeling_lilt.py index 4fd7a85affd..e21f8ab2ce6 100644 --- a/src/transformers/models/lilt/modeling_lilt.py +++ b/src/transformers/models/lilt/modeling_lilt.py @@ -514,7 +514,7 @@ class LiltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layout_inputs, @@ -601,11 +601,6 @@ class LiltPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LiltEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LILT_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
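One interaction worth keeping in mind when training with checkpointing enabled shows up in the InstructBLIP Q-Former hunk above: `use_cache=True` is incompatible with gradient checkpointing, so the model warns and forces `use_cache=False`. A short usage sketch, with an illustrative tiny checkpoint name, that passes the flag explicitly:

import torch
from transformers import AutoModelForCausalLM

# Illustrative tiny model; any causal LM that supports checkpointing works.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.tensor([[1, 2, 3, 4]])
# Pass use_cache=False up front; otherwise the model logs a warning and disables it itself.
outputs = model(input_ids, labels=input_ids, use_cache=False)
outputs.loss.backward()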
Check the superclass documentation for the generic methods the diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 4c31729337d..a330ff62e53 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -851,11 +851,6 @@ class LlamaPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LlamaModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -1036,7 +1031,7 @@ class LlamaModel(LlamaPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index b4f20b45255..62b7ac4a1dc 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1304,7 +1304,7 @@ class LongformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -1434,11 +1434,6 @@ class LongformerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LongformerEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 9abbfa2f200..91e584d80d3 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -23,7 +23,6 @@ from typing import Any, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -1339,11 +1338,6 @@ class LongT5PreTrainedModel(PreTrainedModel): mean=0.0, std=factor * ((d_model) ** -0.5) ) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id @@ -1391,11 +1385,11 @@ class LongT5Stack(LongT5PreTrainedModel): self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + self.gradient_checkpointing = False + # 
Initialize weights and apply final processing self.post_init() - self.gradient_checkpointing = False - # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings def get_input_embeddings(self): return self.embed_tokens @@ -1509,7 +1503,7 @@ class LongT5Stack(LongT5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index 3b5f4d0bf71..6343867353f 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -788,7 +788,7 @@ class LukeEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, word_hidden_states, entity_hidden_states, @@ -914,11 +914,6 @@ class LukePreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, LukeEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 4ebe11f3f3b..07338731cd3 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -552,11 +552,6 @@ class M2M100PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (M2M100Decoder, M2M100Encoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - M2M_100_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -821,7 +816,7 @@ class M2M100Encoder(M2M100PreTrainedModel): # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1061,7 +1056,7 @@ class M2M100Decoder(M2M100PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, combined_attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index e2e09b564b0..68d7fe53bd7 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -500,11 +500,6 @@ class MarianPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MarianDecoder, MarianEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -789,7 +784,7 @@ class MarianEncoder(MarianPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1032,7 +1027,7 @@ class MarianDecoder(MarianPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 80498efb3ca..24ca0c4972a 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -648,7 +648,7 @@ class MarkupLMEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 86eccc47875..6b3e901d71d 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -1864,7 +1864,7 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module): continue if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 
7df8b60792a..0fda64fa49f 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -848,7 +848,7 @@ class DetrDecoder(nn.Module): continue if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, combined_attention_mask, @@ -1613,14 +1613,6 @@ class MaskFormerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MaskFormerPixelLevelModule): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None - if isinstance(module, DetrDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @add_start_docstrings( "The bare MaskFormer Model outputting raw hidden-states without any specific head on top.", diff --git a/src/transformers/models/maskformer/modeling_maskformer_swin.py b/src/transformers/models/maskformer/modeling_maskformer_swin.py index 89c6a0c0e0b..b4714860e6b 100644 --- a/src/transformers/models/maskformer/modeling_maskformer_swin.py +++ b/src/transformers/models/maskformer/modeling_maskformer_swin.py @@ -688,7 +688,7 @@ class MaskFormerSwinEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( + layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -748,11 +748,6 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MaskFormerSwinEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): def __init__(self, config, add_pooling_layer=True): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 7c4c9bdf959..e3ec189012c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -516,11 +516,6 @@ class MBartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MBartDecoder, MBartEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -829,7 +824,7 @@ class MBartEncoder(MBartPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1081,7 +1076,7 @@ class 
MBartDecoder(MBartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index c23666f10b7..9111f937bc2 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -551,7 +551,7 @@ class MegatronBertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -723,11 +723,6 @@ class MegatronBertPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MegatronBertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py index 1257b4df39c..8914e59a207 100644 --- a/src/transformers/models/mgp_str/modeling_mgp_str.py +++ b/src/transformers/models/mgp_str/modeling_mgp_str.py @@ -333,11 +333,6 @@ class MgpstrPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, MgpstrEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MGP_STR_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index f98eb4de884..94ebc690aa2 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -826,11 +826,6 @@ class MistralPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MistralModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MISTRAL_INPUTS_DOCSTRING = r""" Args: @@ -1029,7 +1024,7 @@ class MistralModel(MistralPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/mobilevit/modeling_mobilevit.py b/src/transformers/models/mobilevit/modeling_mobilevit.py index c664c02a883..1de0f6adbf0 100755 --- a/src/transformers/models/mobilevit/modeling_mobilevit.py +++ b/src/transformers/models/mobilevit/modeling_mobilevit.py @@ -626,7 +626,7 @@ class MobileViTEncoder(nn.Module): for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, ) @@ -665,11 +665,6 @@ class MobileViTPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MobileViTEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MOBILEVIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py index b88925f41b8..842e78946e9 100644 --- a/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py +++ b/src/transformers/models/mobilevitv2/modeling_mobilevitv2.py @@ -582,7 +582,7 @@ class MobileViTV2Encoder(nn.Module): for i, layer_module in enumerate(self.layer): if self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, ) @@ -622,11 +622,6 @@ class MobileViTV2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MobileViTV2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MOBILEVITV2_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index ede306e71b8..5fa6698d34a 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -294,11 +294,6 @@ class MptPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None): - if isinstance(module, MptModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @staticmethod def _convert_to_mpt_cache( past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]] @@ -524,7 +519,7 @@ class MptModel(MptPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - outputs = self.gradient_checkpointing_func( + outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, alibi, diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index f6cb65889a3..7e81f2a46c2 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -766,7 +766,7 @@ class MraEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -864,11 +864,6 @@ class MraPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MraEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MRA_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 2951ffc889d..ba977ad6ae6 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -845,11 +844,6 @@ class MT5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MT5Attention, MT5Stack)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1074,7 +1068,7 @@ class MT5Stack(MT5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index a740ed47074..a1f2c589309 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...generation.configuration_utils import GenerationConfig @@ -475,11 +474,6 @@ class MusicgenPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, MusicgenDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - MUSICGEN_START_DOCSTRING = r""" @@ -827,7 +821,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, hidden_states, attention_mask, @@ -1557,11 +1551,6 @@ class MusicgenForConditionalGeneration(PreTrainedModel): self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix ) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - # call both encoder and decoder function on gradient checkpointing - self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - def get_audio_encoder(self): return self.audio_encoder diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 
122b4928787..62e61f84ea8 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -563,11 +563,6 @@ class MvpPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -950,7 +945,7 @@ class MvpEncoder(MvpPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1222,7 +1217,7 @@ class MvpDecoder(MvpPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/nat/modeling_nat.py b/src/transformers/models/nat/modeling_nat.py index 4f7206a5e8e..278ed3d4b6b 100644 --- a/src/transformers/models/nat/modeling_nat.py +++ b/src/transformers/models/nat/modeling_nat.py @@ -639,9 +639,6 @@ class NatPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None: - pass - NAT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use @@ -654,6 +651,7 @@ NAT_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
""" + NAT_INPUTS_DOCSTRING = r""" Args: pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index cd43688e3f7..b6d024b9d66 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -577,7 +577,7 @@ class NezhaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -747,11 +747,6 @@ class NezhaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, NezhaEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class NezhaForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index cbed1e1b153..418f493acfd 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled @@ -874,11 +873,6 @@ class NllbMoePreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - NLLB_MOE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1154,7 +1148,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel): layer_outputs = (None, None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1421,7 +1415,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, hidden_states, combined_attention_mask, diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 9b2052eb6ca..950f8d27fa8 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -370,7 +370,7 @@ class NystromformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -471,11 +471,6 @@ class NystromformerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, NystromformerEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - NYSTROMFORMER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 165684542d8..33095e53e2f 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -2616,7 +2616,7 @@ class OneFormerTextTransformer(nn.Module): def forward(self, hidden_states: torch.Tensor): for layer in self.layers: if self.use_checkpoint: - hidden_states = self.gradient_checkpointing_func(layer, hidden_states) + hidden_states = self._gradient_checkpointing_func(layer, hidden_states) else: hidden_states = layer(hidden_states) return hidden_states diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 9925e7b4a46..a7efc1a7670 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -411,11 +411,6 @@ class OPTPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (OPTDecoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - OPT_INPUTS_DOCSTRING = r""" Args: @@ -692,7 +687,7 @@ class OPTDecoder(OPTPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_attention_mask, diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py index a1491d15ea5..53252529805 100644 --- a/src/transformers/models/owlv2/modeling_owlv2.py +++ b/src/transformers/models/owlv2/modeling_owlv2.py @@ -584,11 +584,6 @@ class Owlv2PreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, 
gradient_checkpointing_func=None): - if isinstance(module, Owlv2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - OWLV2_START_DOCSTRING = r""" @@ -765,7 +760,7 @@ class Owlv2Encoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py index 68037d13950..2880100d5c5 100644 --- a/src/transformers/models/owlvit/modeling_owlvit.py +++ b/src/transformers/models/owlvit/modeling_owlvit.py @@ -576,11 +576,6 @@ class OwlViTPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, OwlViTEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - OWLVIT_START_DOCSTRING = r""" @@ -754,7 +749,7 @@ class OwlViTEncoder(nn.Module): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 058ecd1775a..f5dddbe588e 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -500,11 +500,6 @@ class PegasusPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (PegasusDecoder, PegasusEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - PEGASUS_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -804,7 +799,7 @@ class PegasusEncoder(PegasusPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1082,7 +1077,7 @@ class PegasusDecoder(PegasusPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 6eaddf642a8..60a9714facd 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -780,11 +780,6 @@ class PegasusXPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (PegasusXDecoder, PegasusXEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - PEGASUS_X_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1072,7 +1067,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, global_hidden_states, @@ -1326,7 +1321,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8043fc8699a..c3a25c1030b 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -467,11 +467,6 @@ class PersimmonPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, PersimmonModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - PERSIMMON_INPUTS_DOCSTRING = r""" Args: @@ -669,7 +664,7 @@ class PersimmonModel(PersimmonPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index cfc2b137c57..42f3002ac63 
100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -20,7 +20,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -343,7 +342,7 @@ class Pix2StructVisionEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -557,11 +556,6 @@ class Pix2StructVisionModel(Pix2StructPreTrainedModel): # Initialize weights and apply final processing self.post_init() - def _set_gradient_checkpointing(self, module: Pix2StructVisionEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, (Pix2StructVisionEncoder, Pix2StructVisionAttention)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def get_input_embeddings(self): return self.embeddings.patch_projection @@ -1315,11 +1309,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] supports_gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Pix2StructTextAttention, Pix2StructTextModel)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) @@ -1491,7 +1480,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 1e047fd3726..d828cd8e5bd 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -517,11 +517,6 @@ class PLBartPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (PLBartDecoder, PLBartEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - PLBART_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the @@ -808,7 +803,7 @@ class PLBartEncoder(PLBartPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1059,7 +1054,7 @@ class PLBartDecoder(PLBartPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/poolformer/modeling_poolformer.py b/src/transformers/models/poolformer/modeling_poolformer.py index 209533e3199..c5a8c7a0d27 100755 --- a/src/transformers/models/poolformer/modeling_poolformer.py +++ b/src/transformers/models/poolformer/modeling_poolformer.py @@ -270,7 +270,6 @@ class PoolFormerPreTrainedModel(PreTrainedModel): config_class = PoolFormerConfig base_model_prefix = "poolformer" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -282,11 +281,6 @@ class PoolFormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, PoolFormerEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - POOLFORMER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 5cf7039e9f0..d9f9ee3aa11 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -22,7 +22,6 @@ from typing import Optional, Tuple, Union import torch from torch import nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from transformers.generation import GenerationConfig @@ -739,11 +738,6 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -903,7 +897,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index e4c28659cb4..9c84a85f1cf 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -557,11 +557,6 @@ class ProphetNetPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1330,7 +1325,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel): encoder_hidden_states = encoder_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, extended_attention_mask, @@ -1564,7 +1559,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py index 356b7c14afa..58ed0ae68fe 100755 --- a/src/transformers/models/pvt/modeling_pvt.py +++ b/src/transformers/models/pvt/modeling_pvt.py @@ -489,11 +489,6 @@ class PvtPreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ) - def _set_gradient_checkpointing(self, module: PvtEncoder, gradient_checkpointing_func=None): - if isinstance(module, PvtEncoder): - 
module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - PVT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 0a2546a9b64..33d6d6b2088 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -581,7 +581,7 @@ class QDQBertEncoder(nn.Module): "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -752,11 +752,6 @@ class QDQBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, QDQBertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - QDQBERT_START_DOCSTRING = r""" diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index 86b37b21560..1b202ffd09b 100644 --- a/src/transformers/models/realm/modeling_realm.py +++ b/src/transformers/models/realm/modeling_realm.py @@ -586,7 +586,7 @@ class RealmEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py index 21050f07fda..2e6da1eaa38 100644 --- a/src/transformers/models/regnet/modeling_regnet.py +++ b/src/transformers/models/regnet/modeling_regnet.py @@ -283,7 +283,6 @@ class RegNetPreTrainedModel(PreTrainedModel): config_class = RegNetConfig base_model_prefix = "regnet" main_input_name = "pixel_values" - supports_gradient_checkpointing = True # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights def _init_weights(self, module): @@ -293,11 +292,6 @@ class RegNetPreTrainedModel(PreTrainedModel): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RegNetModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - REGNET_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index e5e662a9b55..b53464cdeca 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -543,7 +543,7 @@ class RemBertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -668,11 +668,6 @@ class RemBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RemBertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - REMBERT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py index e6b1d85b2a4..df460d58f04 100644 --- a/src/transformers/models/resnet/modeling_resnet.py +++ b/src/transformers/models/resnet/modeling_resnet.py @@ -274,7 +274,6 @@ class ResNetPreTrainedModel(PreTrainedModel): config_class = ResNetConfig base_model_prefix = "resnet" main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): if isinstance(module, nn.Conv2d): @@ -283,11 +282,6 @@ class ResNetPreTrainedModel(PreTrainedModel): nn.init.constant_(module.weight, 1) nn.init.constant_(module.bias, 0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ResNetEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - RESNET_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 32a19c08831..8f34098f7bb 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -510,7 +510,7 @@ class RobertaEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -607,11 +607,6 @@ class RobertaPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RobertaEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ROBERTA_START_DOCSTRING = r""" diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 78ca2068454..cb22bbe14a0 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -512,7 +512,7 @@ class RobertaPreLayerNormEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -610,11 +610,6 @@ class RobertaPreLayerNormPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RobertaPreLayerNormEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ROBERTA_PRELAYERNORM_START_DOCSTRING = r""" diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 3a58efa9140..ff2900774fa 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -644,7 +644,7 @@ class RoCBertEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -791,11 +791,6 @@ class RoCBertPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RoCBertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ROC_BERT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 3893e27b028..95dc0c99394 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -578,7 +578,7 @@ class RoFormerEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -710,11 +710,6 @@ class RoFormerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RoFormerEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - ROFORMER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 27523332137..35fd7976ccf 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -466,11 +466,6 @@ class RwkvPreTrainedModel(PreTrainedModel): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, RwkvModel): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class RwkvOutput(ModelOutput): @@ -677,7 +672,7 @@ class RwkvModel(RwkvPreTrainedModel): all_hidden_states = () if output_hidden_states else None for idx, block in enumerate(self.blocks): if self.gradient_checkpointing and self.training: - hidden_states, state, attentions = self.gradient_checkpointing_func( + hidden_states, state, attentions = self._gradient_checkpointing_func( block.__call__, hidden_states, state, use_cache, output_attentions ) else: diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py index 1bd6fcdc2a8..5b459f64695 100644 --- a/src/transformers/models/sam/modeling_sam.py +++ b/src/transformers/models/sam/modeling_sam.py @@ -1042,7 +1042,7 @@ class SamVisionEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, ) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index c6538d4facc..a48c5191667 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -892,7 +892,7 @@ class SeamlessM4TConformerEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( 
layer.__call__, hidden_states, attention_mask, @@ -1540,11 +1540,6 @@ class SeamlessM4TPreTrainedModel(PreTrainedModel): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SeamlessM4TDecoder, SeamlessM4TEncoder, SeamlessM4TConformerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask): kernel_size, stride = self.config.adaptor_kernel_size, self.config.adaptor_stride pad = kernel_size // 2 @@ -2118,7 +2113,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 36416c168c3..b98e093f8cc 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -360,7 +360,7 @@ class SEWFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -666,7 +666,7 @@ class SEWEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -743,11 +743,6 @@ class SEWPreTrainedModel(PreTrainedModel): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SEWEncoder, SEWFeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 39c9641b948..cbb74dcfa24 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -453,7 +453,7 @@ class SEWDFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -1127,7 +1127,7 @@ class SEWDTransformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (output_states,) if self.gradient_checkpointing and self.training: - output_states = self.gradient_checkpointing_func( + output_states = self._gradient_checkpointing_func( layer_module.__call__, next_kv, attention_mask, @@ -1309,11 +1309,6 @@ class 
SEWDPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SEWDEncoder, SEWDFeatureEncoder, SEWDTransformerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SEWD_START_DOCSTRING = r""" SEW-D was proposed in [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index ec255fab9bc..78a652e91d0 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -249,11 +249,6 @@ class SpeechEncoderDecoderModel(PreTrainedModel): f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 73a02fe66df..24b97c028ba 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -559,11 +559,6 @@ class Speech2TextPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers @@ -818,7 +813,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1060,7 +1055,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index acee2b15a44..c870d7a2f72 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -437,11 +437,6 @@ class Speech2Text2PreTrainedModel(PreTrainedModel): if 
module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Speech2Text2Decoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SPEECH_TO_TEXT_2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -670,7 +665,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index b8fea796647..e9382b1beb3 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -520,7 +520,7 @@ class SpeechT5FeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -1274,11 +1274,6 @@ class SpeechT5PreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - class SpeechT5Encoder(SpeechT5PreTrainedModel): """ @@ -1380,7 +1375,7 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1700,7 +1695,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 1bdf8f3f5f9..75187c36b93 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -459,7 +459,7 @@ class SplinterEncoder(nn.Module): past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -539,11 +539,6 @@ class SplinterPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if 
isinstance(module, SplinterEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SPLINTER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/swiftformer/modeling_swiftformer.py b/src/transformers/models/swiftformer/modeling_swiftformer.py index 4170ce153bb..0c59c6b5b2d 100644 --- a/src/transformers/models/swiftformer/modeling_swiftformer.py +++ b/src/transformers/models/swiftformer/modeling_swiftformer.py @@ -442,11 +442,6 @@ class SwiftFormerPreTrainedModel(PreTrainedModel): nn.init.constant_(module.bias, 0) nn.init.constant_(module.weight, 1.0) - def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, SwiftFormerEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SWIFTFORMER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index c2f15dbbf27..4fe4be5ac79 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -825,7 +825,7 @@ class SwinEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: @@ -894,11 +894,6 @@ class SwinPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, SwinEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SWIN_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 47ce01d1691..1884a4a2c44 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -746,7 +746,7 @@ class Swin2SREncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: @@ -795,11 +795,6 @@ class Swin2SRPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Swin2SREncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SWIN2SR_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. 
Use diff --git a/src/transformers/models/swinv2/modeling_swinv2.py b/src/transformers/models/swinv2/modeling_swinv2.py index 6daad938a62..ebe9426689c 100644 --- a/src/transformers/models/swinv2/modeling_swinv2.py +++ b/src/transformers/models/swinv2/modeling_swinv2.py @@ -906,7 +906,7 @@ class Swinv2Encoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions ) else: @@ -976,11 +976,6 @@ class Swinv2PreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, Swinv2Encoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - SWINV2_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 32d030728de..e00a0147e42 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import CrossEntropyLoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -865,11 +864,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1040,7 +1034,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index c796a9cf24c..3748e5af778 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -873,11 +872,6 @@ class T5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - 
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (T5Attention, T5Stack)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1101,7 +1095,7 @@ class T5Stack(T5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index e1da557b001..e44b7cdd360 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -837,11 +837,6 @@ class TableTransformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TableTransformerDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - TABLE_TRANSFORMER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -1150,7 +1145,7 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel): continue if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, combined_attention_mask, diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index de05d77ec94..1e7a4372bb0 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -646,7 +646,7 @@ class TapasEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -773,11 +773,6 @@ class TapasPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TapasEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - TAPAS_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 1fa6a963f58..30fce5fa6d9 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -663,11 +663,6 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - TIME_SERIES_TRANSFORMER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -947,7 +942,7 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, @@ -1158,7 +1153,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/timesformer/modeling_timesformer.py b/src/transformers/models/timesformer/modeling_timesformer.py index 044705c35e5..1f201b6a5e4 100644 --- a/src/transformers/models/timesformer/modeling_timesformer.py +++ b/src/transformers/models/timesformer/modeling_timesformer.py @@ -439,7 +439,7 @@ class TimesformerEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, output_attentions, @@ -488,11 +488,6 @@ class TimesformerPreTrainedModel(PreTrainedModel): nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range) module.patch_embeddings.apply(self._init_weights) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TimesformerEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - TIMESFORMER_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index ada8638a03b..d9f1f7c915a 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -454,11 +454,6 @@ class TrOCRPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, TrOCRDecoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - TROCR_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -702,7 +697,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel): past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, diff --git a/src/transformers/models/tvlt/modeling_tvlt.py b/src/transformers/models/tvlt/modeling_tvlt.py index a37265f37c7..ec8b29634a9 100644 --- a/src/transformers/models/tvlt/modeling_tvlt.py +++ b/src/transformers/models/tvlt/modeling_tvlt.py @@ -560,7 +560,7 @@ class TvltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -610,11 +610,6 @@ class TvltPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (TvltEncoder, TvltDecoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - TVLT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it @@ -872,7 +867,7 @@ class TvltDecoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, None, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index a5b58444fe4..bfcbfb13eb4 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union import torch from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN from ...modeling_outputs import ( @@ -556,11 +555,6 @@ class UMT5PreTrainedModel(PreTrainedModel): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (UMT5Attention, UMT5Stack)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -710,7 +704,7 @@ class UMT5Stack(UMT5PreTrainedModel): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = checkpoint( + layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, extended_attention_mask, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index db14d5bca51..057a9579c12 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -384,7 +384,7 @@ class UniSpeechFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -760,7 +760,7 @@ class UniSpeechEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -844,7 +844,7 @@ class UniSpeechEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -1020,11 +1020,6 @@ class UniSpeechPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)): - module.gradient_checkpointing_func = 
gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - UNISPEECH_START_DOCSTRING = r""" UniSpeech was proposed in [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 8a9a63804b5..c2889299574 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -398,7 +398,7 @@ class UniSpeechSatFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -774,7 +774,7 @@ class UniSpeechSatEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -858,7 +858,7 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -1034,11 +1034,6 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - UNISPEECH_SAT_START_DOCSTRING = r""" UniSpeechSat was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech diff --git a/src/transformers/models/upernet/modeling_upernet.py b/src/transformers/models/upernet/modeling_upernet.py index 04b8c94e135..2ad8e8c372f 100644 --- a/src/transformers/models/upernet/modeling_upernet.py +++ b/src/transformers/models/upernet/modeling_upernet.py @@ -24,7 +24,6 @@ from ... 
import AutoBackbone from ...modeling_outputs import SemanticSegmenterOutput from ...modeling_utils import PreTrainedModel from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings -from ...utils.backbone_utils import BackboneMixin from .configuration_upernet import UperNetConfig @@ -299,7 +298,6 @@ class UperNetPreTrainedModel(PreTrainedModel): config_class = UperNetConfig main_input_name = "pixel_values" - supports_gradient_checkpointing = True def _init_weights(self, module): if isinstance(module, UperNetPreTrainedModel): @@ -315,11 +313,6 @@ class UperNetPreTrainedModel(PreTrainedModel): if self.auxiliary_head is not None: self.auxiliary_head.init_weights() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BackboneMixin): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - UPERNET_START_DOCSTRING = r""" Parameters: diff --git a/src/transformers/models/videomae/modeling_videomae.py b/src/transformers/models/videomae/modeling_videomae.py index 277280954fd..f78198451d0 100644 --- a/src/transformers/models/videomae/modeling_videomae.py +++ b/src/transformers/models/videomae/modeling_videomae.py @@ -434,7 +434,7 @@ class VideoMAEEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -483,11 +483,6 @@ class VideoMAEPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VIDEOMAE_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it @@ -721,7 +716,7 @@ class VideoMAEDecoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, None, diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 482bd08359b..29ee9566e7c 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -531,7 +531,7 @@ class ViltEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -585,11 +585,6 @@ class ViltPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, ViltEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VILT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. 
Use diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 84275cc33a7..60646809a62 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -225,11 +225,6 @@ class VisionEncoderDecoderModel(PreTrainedModel): f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" ) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - # call both encoder and decoder function on gradient checkpointing - self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 425a125a0b8..30fe60ef7a1 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -418,7 +418,7 @@ class VisualBertEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask, @@ -541,11 +541,6 @@ class VisualBertPreTrainedModel(PreTrainedModel): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, VisualBertEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - @dataclass class VisualBertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 67dbddf8766..734ccf6a9e8 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -397,7 +397,7 @@ class ViTEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -461,11 +461,6 @@ class ViTPreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, ViTEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index 959522843f7..24b133e27af 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -415,7 +415,7 @@ class ViTHybridEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -480,11 +480,6 @@ class ViTHybridPreTrainedModel(PreTrainedModel): std=self.config.initializer_range, ).to(module.cls_token.dtype) - def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, ViTHybridEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index e156fdc3292..910353217fa 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -536,7 +536,7 @@ class ViTMAEEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -585,11 +585,6 @@ class ViTMAEPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VIT_MAE_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it @@ -788,7 +783,7 @@ class ViTMAEDecoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, None, diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index b727c331cfb..6b10eb9f245 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -387,7 +387,7 @@ class ViTMSNEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -438,11 +438,6 @@ class ViTMSNPreTrainedModel(PreTrainedModel): module.bias.data.zero_() module.weight.data.fill_(1.0) - def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, ViTMSNEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VIT_MSN_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it diff --git a/src/transformers/models/vitdet/modeling_vitdet.py b/src/transformers/models/vitdet/modeling_vitdet.py index 9bb3991fabf..4015875f0c7 100644 --- a/src/transformers/models/vitdet/modeling_vitdet.py +++ b/src/transformers/models/vitdet/modeling_vitdet.py @@ -565,7 +565,7 @@ class VitDetEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -660,11 +660,6 @@ class VitDetPreTrainedModel(PreTrainedModel): module.norm3.weight.data.zero_() module.norm3.bias.data.zero_() - def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None: - if isinstance(module, VitDetEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VITDET_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py index f5025a37e71..01e6ed5aa0a 100644 --- a/src/transformers/models/vitmatte/modeling_vitmatte.py +++ b/src/transformers/models/vitmatte/modeling_vitmatte.py @@ -28,7 +28,6 @@ from ...utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...utils.backbone_utils import BackboneMixin from .configuration_vitmatte import VitMatteConfig @@ -86,16 +85,6 @@ class VitMattePreTrainedModel(PreTrainedModel): if module.bias is not None: module.bias.data.zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, BackboneMixin): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - - for backbone_module in module.modules(): - if hasattr(backbone_module, "gradient_checkpointing"): - backbone_module.gradient_checkpointing_func = gradient_checkpointing_func - backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None - class VitMatteBasicConv3x3(nn.Module): """ diff --git a/src/transformers/models/vits/modeling_vits.py b/src/transformers/models/vits/modeling_vits.py index b621bde35e6..a58437fa5aa 100644 --- a/src/transformers/models/vits/modeling_vits.py +++ b/src/transformers/models/vits/modeling_vits.py @@ -1167,7 +1167,7 @@ class VitsEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, padding_mask, @@ -1290,11 +1290,6 @@ class VitsPreTrainedModel(PreTrainedModel): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, VitsEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VITS_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py index 50cb82fb4e1..a9c3f5fd651 100755 --- a/src/transformers/models/vivit/modeling_vivit.py +++ b/src/transformers/models/vivit/modeling_vivit.py @@ -338,7 +338,7 @@ class VivitEncoder(nn.Module): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer_module.__call__, hidden_states, layer_head_mask, @@ -408,11 +408,6 @@ class VivitPreTrainedModel(PreTrainedModel): elif isinstance(module, nn.Parameter): module.data.normal_(mean=0.0, std=self.config.initializer_range) - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, VivitEncoder): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - VIVIT_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
Use it diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index f5bd292da0e..9a2235cb2fd 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -451,7 +451,7 @@ class Wav2Vec2FeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -796,7 +796,7 @@ class Wav2Vec2Encoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -879,7 +879,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -1154,11 +1154,6 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - def _get_adapters(self): if self.config.adapter_attn_dim is None: raise ValueError(f"{self.__class__} has no adapter layers. 
Make sure to define `config.adapter_attn_dim`.") diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 5fba773ee0c..f7b519185b7 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -518,7 +518,7 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -904,7 +904,7 @@ class Wav2Vec2ConformerEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -1165,11 +1165,6 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not None - WAV2VEC2_CONFORMER_START_DOCSTRING = r""" Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 55b19e4c414..0fda7d75da1 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -354,7 +354,7 @@ class WavLMFeatureEncoder(nn.Module): for conv_layer in self.conv_layers: if self._requires_grad and self.gradient_checkpointing and self.training: - hidden_states = self.gradient_checkpointing_func( + hidden_states = self._gradient_checkpointing_func( conv_layer.__call__, hidden_states, ) @@ -706,7 +706,7 @@ class WavLMEncoder(nn.Module): if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -791,7 +791,7 @@ class WavLMEncoderStableLayerNorm(nn.Module): # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: - layer_outputs = self.gradient_checkpointing_func( + layer_outputs = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, @@ -1033,11 +1033,6 @@ class WavLMPreTrainedModel(PreTrainedModel): attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() return attention_mask - def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None): - if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): - module.gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = gradient_checkpointing_func is not 
None
-
 
 WAVLM_START_DOCSTRING = r"""
     WavLM was proposed in [WavLM: Unified Speech Representation Learning with Labeled and Unlabeled
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index d6d0302727c..cd75e397fd8 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -685,11 +685,6 @@ class WhisperPreTrainedModel(PreTrainedModel):
                 embed_positions = module.embed_positions.weight
                 embed_positions.copy_(sinusoids(*embed_positions.shape))
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (WhisperDecoder, WhisperEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
         """
         Computes the output length of the convolutional layers
@@ -943,7 +938,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    layer_outputs = self.gradient_checkpointing_func(
+                    layer_outputs = self._gradient_checkpointing_func(
                         encoder_layer.__call__,
                         hidden_states,
                         None,
@@ -1169,7 +1164,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py
index 6c9cc02db9c..ce5a88d7d54 100644
--- a/src/transformers/models/x_clip/modeling_x_clip.py
+++ b/src/transformers/models/x_clip/modeling_x_clip.py
@@ -534,11 +534,6 @@ class XCLIPPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 X_CLIP_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@@ -704,7 +699,7 @@ class XCLIPEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
@@ -945,7 +940,7 @@ class XCLIPVisionEncoder(nn.Module):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index 1880a783219..b6c21676521 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -503,11 +503,6 @@ class XGLMPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, XGLMModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 @add_start_docstrings(
     "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.",
@@ -675,7 +670,7 @@ class XGLMModel(XGLMPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
index 9a9f02b74a6..faa5080b2d9 100644
--- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
+++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py
@@ -570,11 +570,6 @@ class XLMProphetNetPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
@@ -1350,7 +1345,7 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
                 encoder_hidden_states = encoder_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     extended_attention_mask,
@@ -1587,7 +1582,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     extended_attention_mask,
diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
index da99b2806fb..95ea2e7dca7 100644
--- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
+++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
@@ -511,7 +511,7 @@ class XLMRobertaEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -609,11 +609,6 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, XLMRobertaEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 XLM_ROBERTA_START_DOCSTRING = r"""
 
diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
index 49f7c075172..582f3733d6e 100644
--- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
+++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -499,7 +499,7 @@ class XLMRobertaXLEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py
index 5f7b42f266f..cb048fb85e2 100644
--- a/src/transformers/models/xmod/modeling_xmod.py
+++ b/src/transformers/models/xmod/modeling_xmod.py
@@ -573,7 +573,7 @@ class XmodEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     lang_ids,
@@ -674,12 +674,6 @@ class XmodPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, XmodEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     def set_default_language(self, language: str):
         """
         Set the default language code for the model. This is used when the language is not specified in the input.
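For reference, the pattern all of these hunks converge on can be sketched in a few lines of Python. This is a minimal illustration only, not the actual code added to src/transformers/modeling_utils.py by this patch; the helper name `set_gradient_checkpointing` and the module-walking logic below are assumptions made for the example.

# Sketch (assumed shape, not copied from modeling_utils.py): one generic routine
# replaces the per-model `_set_gradient_checkpointing` overrides removed above.
import torch
from torch.utils.checkpoint import checkpoint


def set_gradient_checkpointing(model: torch.nn.Module, enable: bool = True) -> None:
    gradient_checkpointing_func = checkpoint if enable else None
    for module in model.modules():
        # Only encoder/decoder stacks that expose a `gradient_checkpointing` flag take part.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = enable
            module._gradient_checkpointing_func = gradient_checkpointing_func


# Each forward pass then reads the private attribute, exactly as in the hunks above:
#     if self.gradient_checkpointing and self.training:
#         layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states, ...)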
diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py
index f6cbaecd014..08e7f0777c6 100755
--- a/src/transformers/models/yolos/modeling_yolos.py
+++ b/src/transformers/models/yolos/modeling_yolos.py
@@ -492,7 +492,7 @@ class YolosEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
@@ -545,11 +545,6 @@ class YolosPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None:
-        if isinstance(module, YolosEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 YOLOS_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py
index 8db66d22106..6666c0a5aa8 100644
--- a/src/transformers/models/yoso/modeling_yoso.py
+++ b/src/transformers/models/yoso/modeling_yoso.py
@@ -561,7 +561,7 @@ class YosoEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -662,11 +662,6 @@ class YosoPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, YosoEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 YOSO_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 0b5af845c9a..7a6867bba9c 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -544,7 +544,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -675,11 +675,6 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -2021,11 +2016,6 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic
@@ -2310,7 +2300,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    layer_outputs = self.gradient_checkpointing_func(
+                    layer_outputs = self._gradient_checkpointing_func(
                         encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
@@ -2543,7 +2533,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
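From the user's perspective nothing changes: gradient checkpointing is still toggled through the shared PreTrainedModel API, which after this refactor is responsible for populating the private `_gradient_checkpointing_func` attribute on the sub-modules that support it. A short usage sketch follows; the checkpoint name is only an example.

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # example checkpoint
model.gradient_checkpointing_enable()   # supporting sub-modules receive `_gradient_checkpointing_func`
model.train()                           # checkpointing is only used when `self.training` is True
model.gradient_checkpointing_disable()  # clears the flag again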