[core/ gradient_checkpointing] Refactor GC - part 2 (#27073)

* fix

* more fixes

* fix other models

* fix long t5

* use `gradient_checkpointing_func` instead

* fix copies

* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* replace it with `is_gradient_checkpointing_set`

* remove default

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Younes Belkada authored on 2023-10-27 16:15:22 +02:00, committed by GitHub
parent 5be1fb6d1f
commit ffff9e70ab
186 changed files with 242 additions and 1145 deletions

View File

@ -1873,7 +1873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
        )
        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

        if getattr(self, "_hf_peft_config_loaded", False):
            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@ -1882,6 +1882,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            # the gradients to make sure the gradient flows.
            self.enable_input_require_grads()

    def _set_gradient_checkpointing(
        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
    ):
        is_gradient_checkpointing_set = False

        # Apply it on the top-level module in case the top-level modules supports it
        # for example, LongT5Stack inherits from `PreTrainedModel`.
        if hasattr(self, "gradient_checkpointing"):
            self._gradient_checkpointing_func = gradient_checkpointing_func
            self.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module._gradient_checkpointing_func = gradient_checkpointing_func
                module.gradient_checkpointing = enable
                is_gradient_checkpointing_set = True

        if not is_gradient_checkpointing_set:
            raise ValueError(
                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
                " `gradient_checkpointing` to modules of the model that uses checkpointing."
            )

    def gradient_checkpointing_disable(self):
        """
        Deactivates gradient checkpointing for the current model.
@ -1890,7 +1914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        activations".
        """
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
            self._set_gradient_checkpointing(enable=False)

        if getattr(self, "_hf_peft_config_loaded", False):
            self.disable_input_require_grads()
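
Not part of the diff, but useful context for reading it: the public entry points are unchanged, so callers still use `gradient_checkpointing_enable` / `gradient_checkpointing_disable`; the refactor only reroutes them through the new `_set_gradient_checkpointing(enable=..., gradient_checkpointing_func=...)` instead of `self.apply`. A minimal usage sketch, with an illustrative checkpoint name and kwargs:

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Binds torch.utils.checkpoint.checkpoint with the given kwargs via functools.partial
    # and stores it as `_gradient_checkpointing_func` on every module that exposes a
    # boolean `gradient_checkpointing` attribute.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    assert model.is_gradient_checkpointing

    # Flips the same flags back off through `_set_gradient_checkpointing(enable=False)`.
    model.gradient_checkpointing_disable()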

View File

@ -1095,7 +1095,7 @@ class AlignTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -1192,11 +1192,6 @@ class AlignPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
"""The text model from ALIGN without any head or projection on top.""",
@ -1331,6 +1326,7 @@ class AlignTextModel(AlignPreTrainedModel):
class AlignVisionModel(AlignPreTrainedModel):
config_class = AlignVisionConfig
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
def __init__(self, config: AlignVisionConfig):
super().__init__(config)
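
The same change repeats across the model files that follow: the per-model `_set_gradient_checkpointing` overrides are dropped, and each encoder/decoder only keeps a boolean `gradient_checkpointing` attribute and calls `self._gradient_checkpointing_func` in its forward pass; the base class discovers those modules via `self.modules()`. A stripped-down, hypothetical module (names invented, not taken from the diff) illustrating that contract:

    import torch
    from torch import nn

    class ToyEncoder(nn.Module):
        def __init__(self, hidden_size: int = 32, num_layers: int = 2):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])
            # Boolean flag that the new PreTrainedModel._set_gradient_checkpointing looks for.
            self.gradient_checkpointing = False

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            for layer in self.layers:
                if self.gradient_checkpointing and self.training:
                    # `_gradient_checkpointing_func` is torch.utils.checkpoint.checkpoint,
                    # possibly partially applied with user-supplied kwargs, and is set on
                    # this module by the base class when checkpointing is enabled.
                    hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states)
                else:
                    hidden_states = layer(hidden_states)
            return hidden_states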

View File

@ -646,7 +646,7 @@ class AltRobertaEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -955,7 +955,7 @@ class AltCLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1078,14 +1078,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, AltCLIPEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, AltRobertaEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
class AltCLIPVisionTransformer(nn.Module):

View File

@ -336,7 +336,7 @@ class ASTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -388,12 +388,6 @@ class ASTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, ASTEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -946,11 +946,6 @@ class AutoformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (AutoformerDecoder, AutoformerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
AUTOFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1208,7 +1203,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1420,7 +1415,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -317,11 +317,6 @@ class BarkPreTrainedModel(PreTrainedModel):
return get_parameter_device(self)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BARK_MODEL_START_DOCSTRING = """
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -642,7 +637,7 @@ class BarkCausalModel(BarkPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -521,11 +521,6 @@ class BartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BartDecoder, BartEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -855,7 +850,7 @@ class BartEncoder(BartPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1105,7 +1100,7 @@ class BartDecoder(BartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -510,7 +510,7 @@ class BeitEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -566,11 +566,6 @@ class BeitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BeitEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BEIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -593,7 +593,7 @@ class BertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -757,11 +757,6 @@ class BertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
class BertForPreTrainingOutput(ModelOutput):

View File

@ -401,7 +401,7 @@ class BertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -602,11 +602,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BERT_GENERATION_START_DOCSTRING = r"""

View File

@ -1617,7 +1617,7 @@ class BigBirdEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -1779,11 +1779,6 @@ class BigBirdPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BigBirdEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BIG_BIRD_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -1609,11 +1609,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -1944,7 +1939,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -2284,7 +2279,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -376,11 +376,6 @@ class BioGptPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BioGptModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BIOGPT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@ -591,7 +586,7 @@ class BioGptModel(BioGptPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -660,7 +660,6 @@ class BitPreTrainedModel(PreTrainedModel):
config_class = BitConfig
base_model_prefix = "bit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
@ -669,11 +668,6 @@ class BitPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BitModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -483,11 +483,6 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -778,7 +773,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1027,7 +1022,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -480,11 +480,6 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -776,7 +771,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1024,7 +1019,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -34,7 +34,7 @@ from ...utils import (
replace_return_docstrings,
)
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel
from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
logger = logging.get_logger(__name__)
@ -461,11 +461,6 @@ class BlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (BlipEncoder, BlipTextEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
BLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -623,7 +618,7 @@ class BlipEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -422,7 +422,7 @@ class BlipTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

View File

@ -297,15 +297,6 @@ class Blip2PreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
BLIP_2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -478,7 +469,7 @@ class Blip2Encoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -943,7 +934,7 @@ class Blip2QFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
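
Worth noting for composite models like this one: the deleted override used to forward the hook to `self.language_model` by hand; after the refactor that is unnecessary because `_set_gradient_checkpointing` walks `self.modules()`, which already reaches nested submodules such as the wrapped language model. A quick inspection sketch (the checkpoint name is only an example, and the expected module names are an assumption, not part of the diff):

    from transformers import Blip2ForConditionalGeneration

    model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
    model.gradient_checkpointing_enable()

    # Every submodule exposing the boolean flag gets toggled, which should include the
    # vision encoder, the Q-Former encoder, and the nested language model's decoder stack.
    flagged = [name for name, m in model.named_modules() if getattr(m, "gradient_checkpointing", False)]
    print(flagged)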

View File

@ -496,11 +496,6 @@ class BloomPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
if isinstance(module, BloomModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@staticmethod
def _convert_to_standard_cache(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
@ -762,7 +757,7 @@ class BloomModel(BloomPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,

View File

@ -804,7 +804,7 @@ class BridgeTowerTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

View File

@ -651,7 +651,7 @@ class BrosEncoder(nn.Module):
"`use_cache=False`..."
)
use_cache = False
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
bbox_pos_emb,

View File

@ -524,7 +524,7 @@ class CamembertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -620,11 +620,6 @@ class CamembertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CamembertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CAMEMBERT_INPUTS_DOCSTRING = r"""
Args:

View File

@ -795,7 +795,7 @@ class CanineEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -913,11 +913,6 @@ class CaninePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CanineEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CANINE_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -742,11 +742,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CHINESE_CLIP_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@ -910,7 +905,7 @@ class ChineseCLIPTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -1014,7 +1009,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
output_attentions,

View File

@ -939,7 +939,7 @@ class ClapAudioEncoder(nn.Module):
input_dimensions = self.input_resolutions[i]
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
@ -1588,7 +1588,7 @@ class ClapTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -1689,11 +1689,6 @@ class ClapPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ClapTextEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class ClapAudioModel(ClapPreTrainedModel):
config_class = ClapAudioConfig

View File

@ -467,11 +467,6 @@ class CLIPPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CLIPEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -640,7 +635,7 @@ class CLIPEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -479,11 +479,6 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CLIPSegEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CLIPSEG_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@ -649,7 +644,7 @@ class CLIPSegEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -339,11 +339,6 @@ class CodeGenPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CodeGenModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CODEGEN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@ -543,7 +538,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -1171,11 +1171,6 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConditionalDetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CONDITIONAL_DETR_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1519,7 +1514,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
# apply transformation
query_sine_embed = query_sine_embed_before_transformation * pos_transformation
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

View File

@ -264,11 +264,6 @@ class ConvBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConvBertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class SeparableConv1D(nn.Module):
"""This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
@ -633,7 +628,7 @@ class ConvBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

View File

@ -282,7 +282,6 @@ class ConvNextPreTrainedModel(PreTrainedModel):
config_class = ConvNextConfig
base_model_prefix = "convnext"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -296,11 +295,6 @@ class ConvNextPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConvNextEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CONVNEXT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -303,7 +303,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
config_class = ConvNextV2Config
base_model_prefix = "convnextv2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -317,11 +316,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConvNextV2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CONVNEXTV2_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -536,7 +536,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
config_class = CpmAntConfig
base_model_prefix = "cpmant"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -556,11 +555,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
elif isinstance(module, CpmAntSegmentPositionEmbedding):
module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CpmAntEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CPMANT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -293,7 +293,7 @@ class Data2VecAudioFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
@ -586,7 +586,7 @@ class Data2VecAudioEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
@ -748,11 +748,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DATA2VEC_AUDIO_START_DOCSTRING = r"""
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and

View File

@ -510,7 +510,7 @@ class Data2VecTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -608,11 +608,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Data2VecTextEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DATA2VECTEXT_START_DOCSTRING = r"""
Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and

View File

@ -522,7 +522,7 @@ class Data2VecVisionEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -579,11 +579,6 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Data2VecVisionEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DATA2VEC_VISION_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -457,7 +457,7 @@ class DebertaEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
hidden_states = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
@ -833,11 +833,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DebertaEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEBERTA_START_DOCSTRING = r"""
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled

View File

@ -501,7 +501,7 @@ class DebertaV2Encoder(nn.Module):
all_hidden_states = all_hidden_states + (output_states,)
if self.gradient_checkpointing and self.training:
output_states = self.gradient_checkpointing_func(
output_states = self._gradient_checkpointing_func(
layer_module.__call__,
next_kv,
attention_mask,
@ -932,11 +932,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DebertaV2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEBERTA_START_DOCSTRING = r"""
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled

View File

@ -469,11 +469,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DecisionTransformerGPT2Model):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
def __init__(self, config):
@ -632,7 +627,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -1088,11 +1088,6 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DeformableDetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEFORMABLE_DETR_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1384,7 +1379,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
encoder_hidden_states,

View File

@ -357,7 +357,7 @@ class DeiTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -409,11 +409,6 @@ class DeiTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, DeiTEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -504,11 +504,6 @@ class MCTCTPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
return attention_mask
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MCTCTEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
MCTCT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@ -617,7 +612,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -456,11 +456,6 @@ class OpenLlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, OpenLlamaModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
OPEN_LLAMA_INPUTS_DOCSTRING = r"""
Args:
@ -666,7 +661,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -163,11 +163,6 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
main_input_name = "trajectories"
supports_gradient_checkpointing = True
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TrajectoryTransformerModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@ -551,7 +546,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
layer_past,

View File

@ -387,11 +387,6 @@ class VanPreTrainedModel(PreTrainedModel):
if module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, VanModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VAN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -979,11 +979,6 @@ class DetaPreTrainedModel(PreTrainedModel):
if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DetaDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DETA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1276,7 +1271,7 @@ class DetaDecoder(DetaPreTrainedModel):
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
encoder_hidden_states,

View File

@ -927,11 +927,6 @@ class DetrPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DETR_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1254,7 +1249,7 @@ class DetrDecoder(DetrPreTrainedModel):
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

View File

@ -660,9 +660,6 @@ class DinatPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None:
pass
DINAT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -447,7 +447,7 @@ class Dinov2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -510,11 +510,6 @@ class Dinov2PreTrainedModel(PreTrainedModel):
std=self.config.initializer_range,
).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, Dinov2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DINOV2_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -358,7 +358,7 @@ class Transformer(nn.Module):
all_hidden_states = all_hidden_states + (hidden_state,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_state,
attn_mask,
@ -424,11 +424,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Transformer):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DISTILBERT_START_DOCSTRING = r"""

View File

@ -749,7 +749,7 @@ class DonutSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
@ -819,11 +819,6 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DonutSwinEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWIN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -30,7 +30,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ..bert.modeling_bert import BertEncoder, BertModel
from ..bert.modeling_bert import BertModel
from .configuration_dpr import DPRConfig
@ -164,11 +164,6 @@ class DPRPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class DPREncoder(DPRPreTrainedModel):
base_model_prefix = "bert_model"

View File

@ -528,7 +528,7 @@ class DPTViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -812,11 +812,6 @@ class DPTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DPTViTEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DPT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

View File

@ -486,7 +486,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
config_class = EfficientNetConfig
base_model_prefix = "efficientnet"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -500,11 +499,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EfficientNetBlock):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
"The bare EfficientNet model outputting raw features without any specific head on top.",

View File

@ -571,7 +571,7 @@ class ElectraEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -687,11 +687,6 @@ class ElectraPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ElectraEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
class ElectraForPreTrainingOutput(ModelOutput):

View File

@ -446,7 +446,6 @@ class EncodecPreTrainedModel(PreTrainedModel):
config_class = EncodecConfig
base_model_prefix = "encodec"
main_input_name = "input_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -473,11 +472,6 @@ class EncodecPreTrainedModel(PreTrainedModel):
elif "bias" in name:
nn.init.constant_(param, 0.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (EncodecEncoder, EncodecDecoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ENCODEC_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

View File

@ -265,11 +265,6 @@ class EncoderDecoderModel(PreTrainedModel):
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
# call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
def get_encoder(self):
return self.encoder

View File

@ -506,7 +506,7 @@ class ErnieEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -675,11 +675,6 @@ class ErniePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie

View File

@ -411,7 +411,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):
config_class = ErnieMConfig
base_model_prefix = "ernie_m"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -429,11 +428,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieMEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ERNIE_M_START_DOCSTRING = r"""

View File

@ -605,7 +605,7 @@ class EsmEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -705,11 +705,6 @@ class EsmPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EsmEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ESM_START_DOCSTRING = r"""

View File

@ -1101,12 +1101,6 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
if isinstance(module, FalconModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@staticmethod
def _convert_cache_to_standard_format(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
@ -1282,7 +1276,7 @@ class FalconModel(FalconPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,

View File

@ -663,7 +663,7 @@ class FlavaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -873,11 +873,6 @@ class FlavaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, FlavaEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
"The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",

View File

@ -292,7 +292,7 @@ class FNetEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states)
layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
else:
layer_outputs = layer_module(hidden_states)
@ -424,11 +424,6 @@ class FNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FNetEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
class FNetForPreTrainingOutput(ModelOutput):

View File

@ -586,7 +586,7 @@ class FocalNetEncoder(nn.Module):
for i, stage_module in enumerate(self.stages):
if self.gradient_checkpointing and self.training:
stage_outputs = self.gradient_checkpointing_func(
stage_outputs = self._gradient_checkpointing_func(
stage_module.__call__,
hidden_states,
input_dimensions,
@ -652,11 +652,6 @@ class FocalNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FocalNetEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
FOCALNET_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

View File

@ -70,11 +70,6 @@ class FuyuPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FuyuForCausalLM):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
FUYU_INPUTS_DOCSTRING = r"""
Args:

View File

@ -452,7 +452,7 @@ class GitEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -528,11 +528,6 @@ class GitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GitEncoder, GitVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GIT_START_DOCSTRING = r"""
@ -874,7 +869,7 @@ class GitVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -480,11 +480,6 @@ class GPT2PreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPT2Model):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
@ -878,7 +873,7 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -404,12 +404,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPT_BIGCODE_START_DOCSTRING = r"""
@ -651,7 +645,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -384,11 +384,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPT_NEO_START_DOCSTRING = r"""
@ -605,7 +600,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -78,11 +78,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GPTNeoXAttention(nn.Module):
def __init__(self, config):
@ -642,7 +637,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,

View File

@ -48,7 +48,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
config_class = GPTNeoXJapaneseConfig
base_model_prefix = "gpt_neox_japanese"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXJapaneseLayer"]
_skip_keys_device_placement = "past_key_values"
@ -66,11 +65,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXJapaneseModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GPTNeoXJapaneseAttention(nn.Module):
def __init__(self, config, use_bias=False):

View File

@ -363,11 +363,6 @@ class GPTJPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTJModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GPTJ_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
@ -670,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -759,11 +759,6 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GPTSanJapaneseAttention,)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id

View File

@ -712,7 +712,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):
config_class = GraphormerConfig
base_model_prefix = "graphormer"
supports_gradient_checkpointing = True
main_input_name_nodes = "input_nodes"
main_input_name_edges = "input_edges"
@ -772,11 +771,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GraphormerModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class GraphormerModel(GraphormerPreTrainedModel):
"""The Graphormer model is a graph-encoder model.

View File

@ -804,11 +804,6 @@ class GroupViTPreTrainedModel(PreTrainedModel):
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
GROUPVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@ -1031,7 +1026,7 @@ class GroupViTTextEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -346,7 +346,7 @@ class HubertFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
@ -724,7 +724,7 @@ class HubertEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
@ -808,7 +808,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
@ -876,11 +876,6 @@ class HubertPreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers

View File

@ -40,7 +40,7 @@ from ...utils import (
)
from .configuration_idefics import IdeficsConfig
from .perceiver import IdeficsPerceiverResampler
from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer
from .vision import IdeficsVisionTransformer
logger = logging.get_logger(__name__)
@ -978,11 +978,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r"""
Args:
@ -1339,7 +1334,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
)
use_cache = False
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
vblock,
decoder_layer,
hidden_states,

View File

@ -401,7 +401,7 @@ class IdeficsVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -525,11 +525,6 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ImageGPTModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
IMAGEGPT_START_DOCSTRING = r"""
@ -817,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

View File

@ -924,11 +924,6 @@ class InformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InformerDecoder, InformerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
INFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -1216,7 +1211,7 @@ class InformerEncoder(InformerPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1224,7 +1219,7 @@ class InformerEncoder(InformerPreTrainedModel):
output_attentions,
)
if conv_layer is not None:
output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0])
output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
else:
layer_outputs = encoder_layer(
@ -1433,7 +1428,7 @@ class InformerDecoder(InformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -304,15 +304,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
INSTRUCTBLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -467,7 +458,7 @@ class InstructBlipEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -938,7 +929,7 @@ class InstructBlipQFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
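The QFormer hunk above keeps the existing guard that turns the KV cache off when checkpointing is active during training. When fine-tuning with checkpointing it is cleanest to disable the cache up front rather than rely on the logged fallback. A small hedged example; the checkpoint name is arbitrary:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint
model.gradient_checkpointing_enable()
model.train()
# Avoids the "`use_cache=True` is incompatible with gradient checkpointing" fallback.
model.config.use_cache = False
```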

View File

@ -487,7 +487,7 @@ class LayoutLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -633,11 +633,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LAYOUTLM_START_DOCSTRING = r"""
The LayoutLM model was proposed in [LayoutLM: Pre-training of Text and Layout for Document Image

View File

@ -439,7 +439,7 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -508,11 +508,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMv2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def my_convert_sync_batchnorm(module, process_group=None):
# same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`

View File

@ -657,7 +657,7 @@ class LayoutLMv3Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

View File

@ -1155,11 +1155,6 @@ class LEDPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LEDDecoder, LEDEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -1877,7 +1872,7 @@ class LEDEncoder(LEDPreTrainedModel):
layer_outputs = (None, None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -2138,7 +2133,7 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

View File

@ -493,7 +493,6 @@ class LevitPreTrainedModel(PreTrainedModel):
config_class = LevitConfig
base_model_prefix = "levit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
@ -507,11 +506,6 @@ class LevitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LevitModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LEVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
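Levit, like Graphormer and GPT-NeoX-Japanese elsewhere in this diff, loses its `supports_gradient_checkpointing = True` flag because nothing in its forward pass ever checkpoints. Requesting checkpointing on such a model should now fail fast instead of silently doing nothing. A hedged sketch of the expected behaviour; the checkpoint name and exact message are illustrative:

```python
from transformers import LevitModel

model = LevitModel.from_pretrained("facebook/levit-128S")  # example checkpoint
try:
    model.gradient_checkpointing_enable()
except ValueError as err:
    print(err)  # e.g. "LevitModel does not support gradient checkpointing."
```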

View File

@ -514,7 +514,7 @@ class LiltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layout_inputs,
@ -601,11 +601,6 @@ class LiltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LiltEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LILT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

View File

@ -851,11 +851,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LlamaModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r"""
Args:
@ -1036,7 +1031,7 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
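The decoder-stack change here is the same pattern applied to LLaMA. An end-to-end sanity check of the refactored path, sketched with an assumed tiny test checkpoint; any causal LM works:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed small test checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

model.gradient_checkpointing_enable()
model.train()

batch = tokenizer("gradient checkpointing trades compute for memory", return_tensors="pt")
# use_cache=False because caching is pointless (and disabled) under checkpointing.
out = model(**batch, labels=batch["input_ids"], use_cache=False)
out.loss.backward()  # decoder-layer activations are recomputed here instead of stored
```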

View File

@ -1304,7 +1304,7 @@ class LongformerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -1434,11 +1434,6 @@ class LongformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LongformerEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LONGFORMER_START_DOCSTRING = r"""

View File

@ -23,7 +23,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from ...activations import ACT2FN
from ...modeling_outputs import (
@ -1339,11 +1338,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
mean=0.0, std=factor * ((d_model) ** -0.5)
)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
@ -1391,11 +1385,11 @@ class LongT5Stack(LongT5PreTrainedModel):
self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
self.gradient_checkpointing = False
# Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
def get_input_embeddings(self):
return self.embed_tokens
@ -1509,7 +1503,7 @@ class LongT5Stack(LongT5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = checkpoint(
layer_outputs = self._gradient_checkpointing_func(
layer_module.forward,
hidden_states,
extended_attention_mask,
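LongT5 previously imported `checkpoint` at module level and called it directly from the stack, which is why that import disappears above. With the indirection through `_gradient_checkpointing_func`, callers can also pick how checkpointing behaves via keyword arguments forwarded to `torch.utils.checkpoint.checkpoint`. A hedged sketch, assuming the `gradient_checkpointing_kwargs` argument introduced in this refactor series; the checkpoint name and kwargs are examples only:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("google/long-t5-local-base")  # example checkpoint
# The kwargs are handed through to torch.utils.checkpoint.checkpoint;
# use_reentrant=False selects the non-reentrant implementation.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()
```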

View File

@ -788,7 +788,7 @@ class LukeEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
word_hidden_states,
entity_hidden_states,
@ -914,11 +914,6 @@ class LukePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LukeEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LUKE_START_DOCSTRING = r"""

View File

@ -552,11 +552,6 @@ class M2M100PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (M2M100Decoder, M2M100Encoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
M2M_100_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@ -821,7 +816,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1061,7 +1056,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

View File

@ -500,11 +500,6 @@ class MarianPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MarianDecoder, MarianEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -789,7 +784,7 @@ class MarianEncoder(MarianPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1032,7 +1027,7 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -648,7 +648,7 @@ class MarkupLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

View File

@ -1864,7 +1864,7 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -848,7 +848,7 @@ class DetrDecoder(nn.Module):
continue
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
@ -1613,14 +1613,6 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerPixelLevelModule):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, DetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
"The bare MaskFormer Model outputting raw hidden-states without any specific head on top.",

View File

@ -688,7 +688,7 @@ class MaskFormerSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func(
layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
@ -748,11 +748,6 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):

View File

@ -516,11 +516,6 @@ class MBartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MBartDecoder, MBartEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@ -829,7 +824,7 @@ class MBartEncoder(MBartPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@ -1081,7 +1076,7 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

View File

@ -551,7 +551,7 @@ class MegatronBertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
@ -723,11 +723,6 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MegatronBertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert

Some files were not shown because too many files have changed in this diff.