[core/ gradient_checkpointing] Refactor GC - part 2 (#27073)

* fix

* more fixes

* fix other models

* fix long t5

* use `gradient_checkpointing_func` instead

* fix copies

* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* replace it with `is_gradient_checkpointing_set`

* remove default

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Younes Belkada 2023-10-27 16:15:22 +02:00 committed by GitHub
parent 5be1fb6d1f
commit ffff9e70ab
186 changed files with 242 additions and 1145 deletions
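
The diff below only touches internals; the public entry points keep their names. A minimal sketch of how the refactored API is driven from user code, assuming the part-1 refactor that introduced `gradient_checkpointing_kwargs` (the model id and the kwargs are illustrative, not part of this commit):

```python
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any architecture whose checkpointed modules expose a
# boolean `gradient_checkpointing` attribute is handled the same way.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.train()

# Extra kwargs are forwarded to torch.utils.checkpoint.checkpoint via functools.partial
# (see the `**gradient_checkpointing_kwargs` context line in the first hunk below).
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Disabling now calls the new `_set_gradient_checkpointing(enable=False)` helper
# instead of `self.apply(partial(...))`.
model.gradient_checkpointing_disable()
```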

View File

@@ -1873,7 +1873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
         )
 
-        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
        if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@@ -1882,6 +1882,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # the gradients to make sure the gradient flows.
             self.enable_input_require_grads()
 
+    def _set_gradient_checkpointing(
+        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+    ):
+        is_gradient_checkpointing_set = False
+
+        # Apply it on the top-level module in case the top-level modules supports it
+        # for example, LongT5Stack inherits from `PreTrainedModel`.
+        if hasattr(self, "gradient_checkpointing"):
+            self._gradient_checkpointing_func = gradient_checkpointing_func
+            self.gradient_checkpointing = enable
+            is_gradient_checkpointing_set = True
+
+        for module in self.modules():
+            if hasattr(module, "gradient_checkpointing"):
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
+                " `gradient_checkpointing` to modules of the model that uses checkpointing."
+            )
+
     def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.
@@ -1890,7 +1914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         activations".
         """
         if self.supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
+            self._set_gradient_checkpointing(enable=False)
 
         if getattr(self, "_hf_peft_config_loaded", False):
             self.disable_input_require_grads()
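
In effect, the base class now discovers checkpointed submodules by attribute instead of relying on per-model `_set_gradient_checkpointing` overrides. A minimal sketch of what a custom model needs in order to work with this mechanism; the `Toy*` names are invented for illustration and are not part of the commit:

```python
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy"


class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(2))
        # The base-class helper looks for exactly this boolean attribute and, when
        # checkpointing is enabled, also assigns `self._gradient_checkpointing_func`.
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


class ToyModel(PreTrainedModel):
    config_class = ToyConfig
    # Opt in; without this flag (or without any `gradient_checkpointing` attribute on a
    # submodule) enabling checkpointing raises the ValueError added above.
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.encoder = ToyEncoder()

    def forward(self, hidden_states):
        return self.encoder(hidden_states)
```

With this layout, `ToyModel(ToyConfig()).gradient_checkpointing_enable()` flips `encoder.gradient_checkpointing` and installs the checkpointing wrapper, with no model-specific `_set_gradient_checkpointing` override.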

View File

@@ -1095,7 +1095,7 @@ class AlignTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -1192,11 +1192,6 @@ class AlignPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 @add_start_docstrings(
     """The text model from ALIGN without any head or projection on top.""",
@@ -1331,6 +1326,7 @@ class AlignTextModel(AlignPreTrainedModel):
 class AlignVisionModel(AlignPreTrainedModel):
     config_class = AlignVisionConfig
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = False
 
     def __init__(self, config: AlignVisionConfig):
         super().__init__(config)

View File

@@ -646,7 +646,7 @@ class AltRobertaEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -955,7 +955,7 @@ class AltCLIPEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
@@ -1078,14 +1078,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, AltCLIPEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-        if isinstance(module, AltRobertaEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
 class AltCLIPVisionTransformer(nn.Module):

View File

@@ -336,7 +336,7 @@ class ASTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
@@ -388,12 +388,6 @@ class ASTPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
-    def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None:
-        if isinstance(module, ASTEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@@ -946,11 +946,6 @@ class AutoformerPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (AutoformerDecoder, AutoformerEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 AUTOFORMER_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -1208,7 +1203,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -1420,7 +1415,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -317,11 +317,6 @@ class BarkPreTrainedModel(PreTrainedModel):
         return get_parameter_device(self)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BARK_MODEL_START_DOCSTRING = """
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -642,7 +637,7 @@ class BarkCausalModel(BarkPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                outputs = self.gradient_checkpointing_func(
+                outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
                     None,

View File

@@ -521,11 +521,6 @@ class BartPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (BartDecoder, BartEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -855,7 +850,7 @@ class BartEncoder(BartPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -1105,7 +1100,7 @@ class BartDecoder(BartPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -510,7 +510,7 @@ class BeitEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
@@ -566,11 +566,6 @@ class BeitPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BeitEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BEIT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@@ -593,7 +593,7 @@ class BertEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -757,11 +757,6 @@ class BertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 @dataclass
 class BertForPreTrainingOutput(ModelOutput):

View File

@@ -401,7 +401,7 @@ class BertEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -602,11 +602,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BERT_GENERATION_START_DOCSTRING = r"""

View File

@@ -1617,7 +1617,7 @@ class BigBirdEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -1779,11 +1779,6 @@ class BigBirdPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BigBirdEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BIG_BIRD_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@@ -1609,11 +1609,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1944,7 +1939,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -2284,7 +2279,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -376,11 +376,6 @@ class BioGptPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BioGptModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BIOGPT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@@ -591,7 +586,7 @@ class BioGptModel(BioGptPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -660,7 +660,6 @@ class BitPreTrainedModel(PreTrainedModel):
     config_class = BitConfig
     base_model_prefix = "bit"
     main_input_name = "pixel_values"
-    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
@@ -669,11 +668,6 @@ class BitPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.weight, 1)
             nn.init.constant_(module.bias, 0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BitModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BIT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@@ -483,11 +483,6 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -778,7 +773,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -1027,7 +1022,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -480,11 +480,6 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -776,7 +771,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -1024,7 +1019,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -34,7 +34,7 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
-from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel
+from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
 
 
 logger = logging.get_logger(__name__)
@@ -461,11 +461,6 @@ class BlipPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (BlipEncoder, BlipTextEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BLIP_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -623,7 +618,7 @@ class BlipEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -422,7 +422,7 @@ class BlipTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -297,15 +297,6 @@ class Blip2PreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
-        # Enable / disable GC for the language model as well
-        if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
-            self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
-
 
 BLIP_2_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -478,7 +469,7 @@ class Blip2Encoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
@@ -943,7 +934,7 @@ class Blip2QFormerEncoder(nn.Module):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
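
Worth noting in the Blip2 hunk above: the removed block used to forward checkpointing manually to the nested `language_model`. That special case is no longer needed because the new base-class helper iterates `self.modules()`, which is recursive. A quick, illustrative way to check this (the checkpoint id is only an example):

```python
from transformers import Blip2ForConditionalGeneration

model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
model.gradient_checkpointing_enable()

# Both the vision/Q-Former encoders and the nested language model's decoder carry
# the flag, since `nn.Module.modules()` walks submodules recursively.
flagged = [name for name, mod in model.named_modules() if getattr(mod, "gradient_checkpointing", False)]
print(flagged)
```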

View File

@@ -496,11 +496,6 @@ class BloomPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
-        if isinstance(module, BloomModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     @staticmethod
     def _convert_to_standard_cache(
         past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
@@ -762,7 +757,7 @@ class BloomModel(BloomPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                outputs = self.gradient_checkpointing_func(
+                outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
                     alibi,

View File

@@ -804,7 +804,7 @@ class BridgeTowerTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -651,7 +651,7 @@ class BrosEncoder(nn.Module):
                         "`use_cache=False`..."
                     )
                     use_cache = False
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     bbox_pos_emb,

View File

@@ -524,7 +524,7 @@ class CamembertEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -620,11 +620,6 @@ class CamembertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, CamembertEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CAMEMBERT_INPUTS_DOCSTRING = r"""
     Args:

View File

@@ -795,7 +795,7 @@ class CanineEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -913,11 +913,6 @@ class CaninePreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, CanineEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CANINE_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@@ -742,11 +742,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CHINESE_CLIP_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@@ -910,7 +905,7 @@ class ChineseCLIPTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -1014,7 +1009,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     output_attentions,

View File

@@ -939,7 +939,7 @@ class ClapAudioEncoder(nn.Module):
             input_dimensions = self.input_resolutions[i]
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
                 )
             else:
@@ -1588,7 +1588,7 @@ class ClapTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -1689,11 +1689,6 @@ class ClapPreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ClapTextEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 class ClapAudioModel(ClapPreTrainedModel):
     config_class = ClapAudioConfig

View File

@@ -467,11 +467,6 @@ class CLIPPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, CLIPEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CLIP_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -640,7 +635,7 @@ class CLIPEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -479,11 +479,6 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, CLIPSegEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CLIPSEG_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@@ -649,7 +644,7 @@ class CLIPSegEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -339,11 +339,6 @@ class CodeGenPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, CodeGenModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CODEGEN_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@@ -543,7 +538,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                outputs = self.gradient_checkpointing_func(
+                outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
                     None,

View File

@@ -1171,11 +1171,6 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ConditionalDetrDecoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CONDITIONAL_DETR_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -1519,7 +1514,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
                 # apply transformation
                 query_sine_embed = query_sine_embed_before_transformation * pos_transformation
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     combined_attention_mask,

View File

@@ -264,11 +264,6 @@ class ConvBertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ConvBertEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 class SeparableConv1D(nn.Module):
     """This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
@@ -633,7 +628,7 @@ class ConvBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,

View File

@@ -282,7 +282,6 @@ class ConvNextPreTrainedModel(PreTrainedModel):
     config_class = ConvNextConfig
     base_model_prefix = "convnext"
     main_input_name = "pixel_values"
-    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -296,11 +295,6 @@ class ConvNextPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ConvNextEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CONVNEXT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@@ -303,7 +303,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
     config_class = ConvNextV2Config
     base_model_prefix = "convnextv2"
     main_input_name = "pixel_values"
-    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -317,11 +316,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, ConvNextV2Encoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CONVNEXTV2_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@@ -536,7 +536,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
     config_class = CpmAntConfig
     base_model_prefix = "cpmant"
-    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -556,11 +555,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
         elif isinstance(module, CpmAntSegmentPositionEmbedding):
             module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, CpmAntEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 CPMANT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@@ -293,7 +293,7 @@ class Data2VecAudioFeatureEncoder(nn.Module):
         for conv_layer in self.conv_layers:
             if self._requires_grad and self.gradient_checkpointing and self.training:
-                hidden_states = self.gradient_checkpointing_func(
+                hidden_states = self._gradient_checkpointing_func(
                     conv_layer.__call__,
                     hidden_states,
                 )
@@ -586,7 +586,7 @@ class Data2VecAudioEncoder(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
-                    layer_outputs = self.gradient_checkpointing_func(
+                    layer_outputs = self._gradient_checkpointing_func(
                         layer.__call__,
                         hidden_states,
                         attention_mask,
@@ -748,11 +748,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 DATA2VEC_AUDIO_START_DOCSTRING = r"""
     Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and

View File

@@ -510,7 +510,7 @@ class Data2VecTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -608,11 +608,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
             if hasattr(module, "weight") and module.weight is not None:
                 module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, Data2VecTextEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 DATA2VECTEXT_START_DOCSTRING = r"""
     Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and

View File

@@ -522,7 +522,7 @@ class Data2VecVisionEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
@@ -579,11 +579,6 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, Data2VecVisionEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 DATA2VEC_VISION_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@@ -457,7 +457,7 @@ class DebertaEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                hidden_states = self.gradient_checkpointing_func(
+                hidden_states = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     next_kv,
                     attention_mask,
@@ -833,11 +833,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, DebertaEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 DEBERTA_START_DOCSTRING = r"""
     The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled

View File

@@ -501,7 +501,7 @@ class DebertaV2Encoder(nn.Module):
                 all_hidden_states = all_hidden_states + (output_states,)
 
             if self.gradient_checkpointing and self.training:
-                output_states = self.gradient_checkpointing_func(
+                output_states = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     next_kv,
                     attention_mask,
@@ -932,11 +932,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, DebertaV2Encoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 DEBERTA_START_DOCSTRING = r"""
     The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled

View File

@@ -469,11 +469,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                 p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, DecisionTransformerGPT2Model):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
     def __init__(self, config):
@@ -632,7 +627,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                outputs = self.gradient_checkpointing_func(
+                outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
                     None,

View File

@@ -1088,11 +1088,6 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
         if hasattr(module, "level_embed"):
             nn.init.normal_(module.level_embed)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, DeformableDetrDecoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 DEFORMABLE_DETR_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -1384,7 +1379,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
                 all_hidden_states += (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     encoder_hidden_states,

View File

@ -357,7 +357,7 @@ class DeiTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
@ -409,11 +409,6 @@ class DeiTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, DeiTEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEIT_START_DOCSTRING = r""" DEIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -504,11 +504,6 @@ class MCTCTPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MCTCTEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
MCTCT_START_DOCSTRING = r""" MCTCT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@ -617,7 +612,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -456,11 +456,6 @@ class OpenLlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, OpenLlamaModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
OPEN_LLAMA_INPUTS_DOCSTRING = r""" OPEN_LLAMA_INPUTS_DOCSTRING = r"""
Args: Args:
@ -666,7 +661,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -163,11 +163,6 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
main_input_name = "trajectories" main_input_name = "trajectories"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TrajectoryTransformerModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _init_weights(self, module): def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@ -551,7 +546,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
layer_past, layer_past,

View File

@ -387,11 +387,6 @@ class VanPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, VanModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VAN_START_DOCSTRING = r""" VAN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -979,11 +979,6 @@ class DetaPreTrainedModel(PreTrainedModel):
if hasattr(module, "level_embed"): if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed) nn.init.normal_(module.level_embed)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DetaDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DETA_START_DOCSTRING = r""" DETA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1276,7 +1271,7 @@ class DetaDecoder(DetaPreTrainedModel):
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
encoder_hidden_states, encoder_hidden_states,

View File

@ -927,11 +927,6 @@ class DetrPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DETR_START_DOCSTRING = r""" DETR_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1254,7 +1249,7 @@ class DetrDecoder(DetrPreTrainedModel):
continue continue
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,

View File

@ -660,9 +660,6 @@ class DinatPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None:
pass
DINAT_START_DOCSTRING = r""" DINAT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -447,7 +447,7 @@ class Dinov2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
@ -510,11 +510,6 @@ class Dinov2PreTrainedModel(PreTrainedModel):
std=self.config.initializer_range, std=self.config.initializer_range,
).to(module.cls_token.dtype) ).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, Dinov2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DINOV2_START_DOCSTRING = r""" DINOV2_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -358,7 +358,7 @@ class Transformer(nn.Module):
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_state, hidden_state,
attn_mask, attn_mask,
@ -424,11 +424,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Transformer):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DISTILBERT_START_DOCSTRING = r""" DISTILBERT_START_DOCSTRING = r"""

View File

@ -749,7 +749,7 @@ class DonutSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
) )
else: else:
@ -819,11 +819,6 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DonutSwinEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWIN_START_DOCSTRING = r""" SWIN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -30,7 +30,7 @@ from ...utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from ..bert.modeling_bert import BertEncoder, BertModel from ..bert.modeling_bert import BertModel
from .configuration_dpr import DPRConfig from .configuration_dpr import DPRConfig
@ -164,11 +164,6 @@ class DPRPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class DPREncoder(DPRPreTrainedModel): class DPREncoder(DPRPreTrainedModel):
base_model_prefix = "bert_model" base_model_prefix = "bert_model"

View File

@ -528,7 +528,7 @@ class DPTViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
@ -812,11 +812,6 @@ class DPTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DPTViTEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DPT_START_DOCSTRING = r""" DPT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -486,7 +486,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
config_class = EfficientNetConfig config_class = EfficientNetConfig
base_model_prefix = "efficientnet" base_model_prefix = "efficientnet"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
@ -500,11 +499,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EfficientNetBlock):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings( @add_start_docstrings(
"The bare EfficientNet model outputting raw features without any specific head on top.", "The bare EfficientNet model outputting raw features without any specific head on top.",

View File

@ -571,7 +571,7 @@ class ElectraEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -687,11 +687,6 @@ class ElectraPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ElectraEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
class ElectraForPreTrainingOutput(ModelOutput): class ElectraForPreTrainingOutput(ModelOutput):

View File

@ -446,7 +446,6 @@ class EncodecPreTrainedModel(PreTrainedModel):
config_class = EncodecConfig config_class = EncodecConfig
base_model_prefix = "encodec" base_model_prefix = "encodec"
main_input_name = "input_values" main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
@ -473,11 +472,6 @@ class EncodecPreTrainedModel(PreTrainedModel):
elif "bias" in name: elif "bias" in name:
nn.init.constant_(param, 0.0) nn.init.constant_(param, 0.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (EncodecEncoder, EncodecDecoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ENCODEC_START_DOCSTRING = r""" ENCODEC_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

View File

@ -265,11 +265,6 @@ class EncoderDecoderModel(PreTrainedModel):
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
) )
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
# call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
def get_encoder(self): def get_encoder(self):
return self.encoder return self.encoder
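With the per-model overrides gone, composite models such as `EncoderDecoderModel` no longer forward the call to their sub-models by hand; enabling checkpointing on the wrapper now reaches every submodule that declares a `gradient_checkpointing` flag. A hedged usage sketch (the checkpoint names are placeholders, any encoder/decoder pair works):

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
model.gradient_checkpointing_enable()

# sanity check: the flag propagated to both the encoder and the decoder stacks
assert all(
    m.gradient_checkpointing for m in model.modules() if hasattr(m, "gradient_checkpointing")
)
```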

View File

@ -506,7 +506,7 @@ class ErnieEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -675,11 +675,6 @@ class ErniePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie

View File

@ -411,7 +411,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):
config_class = ErnieMConfig config_class = ErnieMConfig
base_model_prefix = "ernie_m" base_model_prefix = "ernie_m"
supports_gradient_checkpointing = True
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
@ -429,11 +428,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieMEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ERNIE_M_START_DOCSTRING = r""" ERNIE_M_START_DOCSTRING = r"""

View File

@ -605,7 +605,7 @@ class EsmEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -705,11 +705,6 @@ class EsmPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EsmEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ESM_START_DOCSTRING = r""" ESM_START_DOCSTRING = r"""

View File

@ -1101,12 +1101,6 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
# Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
if isinstance(module, FalconModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@staticmethod @staticmethod
def _convert_cache_to_standard_format( def _convert_cache_to_standard_format(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
@ -1282,7 +1276,7 @@ class FalconModel(FalconPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
alibi, alibi,

View File

@ -663,7 +663,7 @@ class FlavaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -873,11 +873,6 @@ class FlavaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, FlavaEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings( @add_start_docstrings(
"The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.", "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",

View File

@ -292,7 +292,7 @@ class FNetEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states) layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
else: else:
layer_outputs = layer_module(hidden_states) layer_outputs = layer_module(hidden_states)
@ -424,11 +424,6 @@ class FNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FNetEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
class FNetForPreTrainingOutput(ModelOutput): class FNetForPreTrainingOutput(ModelOutput):

View File

@ -586,7 +586,7 @@ class FocalNetEncoder(nn.Module):
for i, stage_module in enumerate(self.stages): for i, stage_module in enumerate(self.stages):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
stage_outputs = self.gradient_checkpointing_func( stage_outputs = self._gradient_checkpointing_func(
stage_module.__call__, stage_module.__call__,
hidden_states, hidden_states,
input_dimensions, input_dimensions,
@ -652,11 +652,6 @@ class FocalNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FocalNetEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
FOCALNET_START_DOCSTRING = r""" FOCALNET_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -70,11 +70,6 @@ class FuyuPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FuyuForCausalLM):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
FUYU_INPUTS_DOCSTRING = r""" FUYU_INPUTS_DOCSTRING = r"""
Args: Args:

View File

@ -452,7 +452,7 @@ class GitEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -528,11 +528,6 @@ class GitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GitEncoder, GitVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GIT_START_DOCSTRING = r""" GIT_START_DOCSTRING = r"""
@ -874,7 +869,7 @@ class GitVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -480,11 +480,6 @@ class GPT2PreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPT2Model):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput): class GPT2DoubleHeadsModelOutput(ModelOutput):
@ -878,7 +873,7 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
None, None,
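For callers, the benefit of routing every model through one configurable callable is that checkpointing options are passed once, at enable time, instead of being hard-coded per model. A sketch of the intended usage (`use_reentrant` is one example; any keyword accepted by `torch.utils.checkpoint.checkpoint` should be forwardable):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()
# ... run the training loop with activations recomputed in the backward pass ...
model.gradient_checkpointing_disable()
```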

View File

@ -404,12 +404,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPT_BIGCODE_START_DOCSTRING = r""" GPT_BIGCODE_START_DOCSTRING = r"""
@ -651,7 +645,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
None, None,

View File

@ -384,11 +384,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPT_NEO_START_DOCSTRING = r""" GPT_NEO_START_DOCSTRING = r"""
@ -605,7 +600,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
None, None,

View File

@ -78,11 +78,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
@ -642,7 +637,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
layer.__call__, layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -48,7 +48,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
config_class = GPTNeoXJapaneseConfig config_class = GPTNeoXJapaneseConfig
base_model_prefix = "gpt_neox_japanese" base_model_prefix = "gpt_neox_japanese"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXJapaneseLayer"] _no_split_modules = ["GPTNeoXJapaneseLayer"]
_skip_keys_device_placement = "past_key_values" _skip_keys_device_placement = "past_key_values"
@ -66,11 +65,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXJapaneseModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GPTNeoXJapaneseAttention(nn.Module): class GPTNeoXJapaneseAttention(nn.Module):
def __init__(self, config, use_bias=False): def __init__(self, config, use_bias=False):

View File

@ -363,11 +363,6 @@ class GPTJPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTJModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPTJ_START_DOCSTRING = r""" GPTJ_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@ -670,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
None, None,

View File

@ -759,11 +759,6 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GPTSanJapaneseAttention,)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
def _shift_right(self, input_ids): def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id decoder_start_token_id = self.config.decoder_start_token_id

View File

@ -712,7 +712,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):
config_class = GraphormerConfig config_class = GraphormerConfig
base_model_prefix = "graphormer" base_model_prefix = "graphormer"
supports_gradient_checkpointing = True
main_input_name_nodes = "input_nodes" main_input_name_nodes = "input_nodes"
main_input_name_edges = "input_edges" main_input_name_edges = "input_edges"
@ -772,11 +771,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GraphormerModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GraphormerModel(GraphormerPreTrainedModel): class GraphormerModel(GraphormerPreTrainedModel):
"""The Graphormer model is a graph-encoder model. """The Graphormer model is a graph-encoder model.

View File

@ -804,11 +804,6 @@ class GroupViTPreTrainedModel(PreTrainedModel):
nn.init.normal_(module.fc1.weight, std=fc_std) nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std) nn.init.normal_(module.fc2.weight, std=in_proj_std)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GROUPVIT_START_DOCSTRING = r""" GROUPVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@ -1031,7 +1026,7 @@ class GroupViTTextEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -346,7 +346,7 @@ class HubertFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func( hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__, conv_layer.__call__,
hidden_states, hidden_states,
) )
@ -724,7 +724,7 @@ class HubertEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer.__call__, layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -808,7 +808,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer.__call__, layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -876,11 +876,6 @@ class HubertPreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
""" """
Computes the output length of the convolutional layers Computes the output length of the convolutional layers
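Hubert illustrates why one shared callable matters: both the convolutional feature encoder and the transformer encoder checkpoint through it, so a single enable call configures all of them identically. For reference, the two checkpointing modes the forwarded kwargs can select, shown with plain PyTorch (a sketch, unrelated to the Hubert code itself):

```python
import torch
from torch.utils.checkpoint import checkpoint

conv = torch.nn.Conv1d(1, 4, kernel_size=3)
x = torch.randn(2, 1, 16, requires_grad=True)

# reentrant checkpointing, the historical default
y = checkpoint(conv, x, use_reentrant=True)

# non-reentrant checkpointing, which tolerates keyword arguments and inputs without gradients
y = checkpoint(conv, x, use_reentrant=False)
y.sum().backward()
```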

View File

@ -40,7 +40,7 @@ from ...utils import (
) )
from .configuration_idefics import IdeficsConfig from .configuration_idefics import IdeficsConfig
from .perceiver import IdeficsPerceiverResampler from .perceiver import IdeficsPerceiverResampler
from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer from .vision import IdeficsVisionTransformer
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -978,11 +978,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r""" LLAMA_INPUTS_DOCSTRING = r"""
Args: Args:
@ -1339,7 +1334,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
) )
use_cache = False use_cache = False
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
vblock, vblock,
decoder_layer, decoder_layer,
hidden_states, hidden_states,
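The Idefics hunk is a reminder that the checkpointed target does not have to be a module's `__call__`: here a local wrapper (`vblock`) is checkpointed and the decoder layer is passed in as an argument. A standalone sketch of that shape (toy names, not the Idefics code):

```python
import torch
from torch.utils.checkpoint import checkpoint


def run_block(block, hidden_states, scale):
    # any Python callable can be checkpointed; modules and plain values ride along as arguments
    return block(hidden_states) * scale


block = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(run_block, block, x, 0.5, use_reentrant=False)
out.sum().backward()
```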

View File

@ -401,7 +401,7 @@ class IdeficsVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -525,11 +525,6 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ImageGPTModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
IMAGEGPT_START_DOCSTRING = r""" IMAGEGPT_START_DOCSTRING = r"""
@ -817,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func( outputs = self._gradient_checkpointing_func(
block.__call__, block.__call__,
hidden_states, hidden_states,
None, None,

View File

@ -924,11 +924,6 @@ class InformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InformerDecoder, InformerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
INFORMER_START_DOCSTRING = r""" INFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1216,7 +1211,7 @@ class InformerEncoder(InformerPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -1224,7 +1219,7 @@ class InformerEncoder(InformerPreTrainedModel):
output_attentions, output_attentions,
) )
if conv_layer is not None: if conv_layer is not None:
output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0]) output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:] layer_outputs = (output,) + layer_outputs[1:]
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
@ -1433,7 +1428,7 @@ class InformerDecoder(InformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -304,15 +304,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None: elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
INSTRUCTBLIP_START_DOCSTRING = r""" INSTRUCTBLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -467,7 +458,7 @@ class InstructBlipEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -938,7 +929,7 @@ class InstructBlipQFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
) )
use_cache = False use_cache = False
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
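The QFormer hunk keeps the guard that many decoders in this diff share: caching past key/values is incompatible with recomputing activations, so `use_cache` is forced off with a warning. When training with checkpointing it is tidier to disable the cache up front; a sketch (the model name is only an example of a small causal LM):

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-125m")
config.use_cache = False  # no past key/values while activations are being recomputed
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", config=config)
model.gradient_checkpointing_enable()
```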

View File

@ -487,7 +487,7 @@ class LayoutLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -633,11 +633,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LAYOUTLM_START_DOCSTRING = r""" LAYOUTLM_START_DOCSTRING = r"""
The LayoutLM model was proposed in [LayoutLM: Pre-training of Text and Layout for Document Image The LayoutLM model was proposed in [LayoutLM: Pre-training of Text and Layout for Document Image

View File

@ -439,7 +439,7 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -508,11 +508,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMv2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def my_convert_sync_batchnorm(module, process_group=None): def my_convert_sync_batchnorm(module, process_group=None):
# same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d` # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`

View File

@ -657,7 +657,7 @@ class LayoutLMv3Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -1155,11 +1155,6 @@ class LEDPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LEDDecoder, LEDEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property @property
def dummy_inputs(self): def dummy_inputs(self):
pad_token = self.config.pad_token_id pad_token = self.config.pad_token_id
@ -1877,7 +1872,7 @@ class LEDEncoder(LEDPreTrainedModel):
layer_outputs = (None, None, None) layer_outputs = (None, None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -2138,7 +2133,7 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,

View File

@ -493,7 +493,6 @@ class LevitPreTrainedModel(PreTrainedModel):
config_class = LevitConfig config_class = LevitConfig
base_model_prefix = "levit" base_model_prefix = "levit"
main_input_name = "pixel_values" main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module): def _init_weights(self, module):
"""Initialize the weights""" """Initialize the weights"""
@ -507,11 +506,6 @@ class LevitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LevitModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LEVIT_START_DOCSTRING = r""" LEVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
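Levit is one of several models in this commit (alongside EfficientNet, Encodec, ErnieM, GPTNeoXJapanese and Graphormer) that drop `supports_gradient_checkpointing = True`: nothing in their forward passes ever checkpointed, and the stricter base-class setter would now reject enabling checkpointing on a model whose modules never define the `gradient_checkpointing` flag. A quick class-level check (no weights are downloaded; the attribute is assumed to fall back to the `PreTrainedModel` default):

```python
from transformers import LevitModel

# after this change Levit no longer advertises gradient-checkpointing support
print(LevitModel.supports_gradient_checkpointing)  # expected: False
```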

View File

@ -514,7 +514,7 @@ class LiltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
layout_inputs, layout_inputs,
@ -601,11 +601,6 @@ class LiltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LiltEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LILT_START_DOCSTRING = r""" LILT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

View File

@ -851,11 +851,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LlamaModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r""" LLAMA_INPUTS_DOCSTRING = r"""
Args: Args:
@ -1036,7 +1031,7 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -1304,7 +1304,7 @@ class LongformerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -1434,11 +1434,6 @@ class LongformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LongformerEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LONGFORMER_START_DOCSTRING = r""" LONGFORMER_START_DOCSTRING = r"""

View File

@ -23,7 +23,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import ( from ...modeling_outputs import (
@ -1339,11 +1338,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
mean=0.0, std=factor * ((d_model) ** -0.5) mean=0.0, std=factor * ((d_model) ** -0.5)
) )
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
def _shift_right(self, input_ids): def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id decoder_start_token_id = self.config.decoder_start_token_id
@ -1391,11 +1385,11 @@ class LongT5Stack(LongT5PreTrainedModel):
self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
self.gradient_checkpointing = False
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
self.gradient_checkpointing = False
# Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embed_tokens return self.embed_tokens
@ -1509,7 +1503,7 @@ class LongT5Stack(LongT5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = checkpoint( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, extended_attention_mask,
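LongT5 was one of the few models that imported `checkpoint` at module level and called it directly; the import is dropped above and the call now goes through the same private attribute as everywhere else. For reference, a standalone sketch (independent of transformers) of what that underlying callable does, showing that the wrapped function is executed a second time during backward instead of caching its activations:

import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}

def block(x):
    calls["n"] += 1
    return torch.relu(x) * 2

x = torch.randn(8, requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(calls["n"])  # 2: the block is recomputed in the backward pass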

View File

@ -788,7 +788,7 @@ class LukeEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
word_hidden_states, word_hidden_states,
entity_hidden_states, entity_hidden_states,
@ -914,11 +914,6 @@ class LukePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LukeEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LUKE_START_DOCSTRING = r""" LUKE_START_DOCSTRING = r"""

View File

@ -552,11 +552,6 @@ class M2M100PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (M2M100Decoder, M2M100Encoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
M2M_100_START_DOCSTRING = r""" M2M_100_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -821,7 +816,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -1061,7 +1056,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,

View File

@ -500,11 +500,6 @@ class MarianPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MarianDecoder, MarianEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property @property
def dummy_inputs(self): def dummy_inputs(self):
pad_token = self.config.pad_token_id pad_token = self.config.pad_token_id
@ -789,7 +784,7 @@ class MarianEncoder(MarianPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -1032,7 +1027,7 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -648,7 +648,7 @@ class MarkupLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -1864,7 +1864,7 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
continue continue
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -848,7 +848,7 @@ class DetrDecoder(nn.Module):
continue continue
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,
@ -1613,14 +1613,6 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerPixelLevelModule):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, DetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
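MaskFormer's old override is the one in this diff that had to reach through a wrapper into `module.encoder.gradient_checkpointing`. With the flag living on the module that actually checkpoints, the wrapper no longer needs special casing; the flag is reachable through a standard submodule traversal. A toy illustration with hypothetical names:

import torch.nn as nn

class InnerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(8, 8)
        # The checkpointing flag lives on the module that does the checkpointing.
        self.gradient_checkpointing = False

class PixelLevelWrapper(nn.Module):
    # A composite wrapper in the spirit of MaskFormer's pixel-level module;
    # it carries no checkpointing state of its own.
    def __init__(self):
        super().__init__()
        self.encoder = InnerEncoder()

wrapper = PixelLevelWrapper()
print([name for name, m in wrapper.named_modules() if hasattr(m, "gradient_checkpointing")])
# ['encoder'], so the flag is discoverable without the wrapper special-casing it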
@add_start_docstrings( @add_start_docstrings(
"The bare MaskFormer Model outputting raw hidden-states without any specific head on top.", "The bare MaskFormer Model outputting raw hidden-states without any specific head on top.",

View File

@ -688,7 +688,7 @@ class MaskFormerSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func( layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
@ -748,11 +748,6 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
def __init__(self, config, add_pooling_layer=True): def __init__(self, config, add_pooling_layer=True):

View File

@ -516,11 +516,6 @@ class MBartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MBartDecoder, MBartEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property @property
def dummy_inputs(self): def dummy_inputs(self):
pad_token = self.config.pad_token_id pad_token = self.config.pad_token_id
@ -829,7 +824,7 @@ class MBartEncoder(MBartPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__, encoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -1081,7 +1076,7 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__, decoder_layer.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,

View File

@ -551,7 +551,7 @@ class MegatronBertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, layer_module.__call__,
hidden_states, hidden_states,
attention_mask, attention_mask,
@ -723,11 +723,6 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MegatronBertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert

Some files were not shown because too many files have changed in this diff.
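None of the per-model edits change how checkpointing is switched on from user code. A usage sketch, assuming a transformers release that includes this refactor (where `gradient_checkpointing_enable` accepts `gradient_checkpointing_kwargs`):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
# Extra kwargs are forwarded to the underlying torch.utils.checkpoint.checkpoint call.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()
# ... training steps ...
model.gradient_checkpointing_disable()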