Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 12:50:06 +06:00
[core / gradient_checkpointing] Refactor GC - part 2 (#27073)
* fix
* more fixes
* fix other models
* fix long t5
* use `gradient_checkpointing_func` instead
* fix copies
* set `gradient_checkpointing_func` as a private attribute and retrieve previous behaviour
* Update src/transformers/modeling_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* replace it with `is_gradient_checkpointing_set`
* remove default
* Update src/transformers/modeling_utils.py
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fixup
---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent 5be1fb6d1f
commit ffff9e70ab
@@ -1873,7 +1873,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
         )
 
-        self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=gradient_checkpointing_func))
+        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
 
         if getattr(self, "_hf_peft_config_loaded", False):
             # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@@ -1882,6 +1882,30 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # the gradients to make sure the gradient flows.
             self.enable_input_require_grads()
 
+    def _set_gradient_checkpointing(
+        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
+    ):
+        is_gradient_checkpointing_set = False
+
+        # Apply it on the top-level module in case the top-level modules supports it
+        # for example, LongT5Stack inherits from `PreTrainedModel`.
+        if hasattr(self, "gradient_checkpointing"):
+            self._gradient_checkpointing_func = gradient_checkpointing_func
+            self.gradient_checkpointing = enable
+            is_gradient_checkpointing_set = True
+
+        for module in self.modules():
+            if hasattr(module, "gradient_checkpointing"):
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
+                " `gradient_checkpointing` to modules of the model that uses checkpointing."
+            )
+
     def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.
@@ -1890,7 +1914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         activations".
         """
         if self.supports_gradient_checkpointing:
-            self.apply(partial(self._set_gradient_checkpointing, gradient_checkpointing_func=None))
+            self._set_gradient_checkpointing(enable=False)
 
         if getattr(self, "_hf_peft_config_loaded", False):
             self.disable_input_require_grads()
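Taken together, the hunks above replace the old `self.apply(partial(self._set_gradient_checkpointing, ...))` dispatch with a single base-class implementation: `_set_gradient_checkpointing(enable, gradient_checkpointing_func)` walks `self.modules()` and, on every submodule that defines a boolean `gradient_checkpointing` attribute, stores the checkpointing callable as the private `_gradient_checkpointing_func` and flips the flag; if no such module exists it raises. The standalone sketch below mirrors that traversal on a plain `torch.nn.Module` tree so the mechanism can be run in isolation; `ToyBlock`, `ToyEncoder` and `set_gradient_checkpointing` are illustrative names, not part of the commit.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock() for _ in range(2))
        # Opt-in marker: the refactored toggle only touches modules that define this attribute.
        self.gradient_checkpointing = False

    def forward(self, x):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Same call pattern as the model hunks below: checkpoint the layer's __call__.
                x = self._gradient_checkpointing_func(layer.__call__, x)
            else:
                x = layer(x)
        return x


def set_gradient_checkpointing(model, enable=True, gradient_checkpointing_func=checkpoint):
    # Mirrors the new PreTrainedModel._set_gradient_checkpointing: walk every submodule
    # and toggle the flag wherever the attribute exists, otherwise complain.
    found = False
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = enable
            found = True
    if not found:
        raise ValueError(f"{model.__class__.__name__} is not compatible with gradient checkpointing.")


encoder = ToyEncoder().train()
set_gradient_checkpointing(encoder, enable=True)
out = encoder(torch.randn(4, 8, requires_grad=True))
out.sum().backward()
```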
@@ -1095,7 +1095,7 @@ class AlignTextEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -1192,11 +1192,6 @@ class AlignPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (AlignTextModel, AlignVisionModel, AlignTextEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 @add_start_docstrings(
     """The text model from ALIGN without any head or projection on top.""",
@@ -1331,6 +1326,7 @@ class AlignTextModel(AlignPreTrainedModel):
 class AlignVisionModel(AlignPreTrainedModel):
     config_class = AlignVisionConfig
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = False
 
     def __init__(self, config: AlignVisionConfig):
         super().__init__(config)
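The ALIGN hunks above show the two mechanical changes that repeat for every model below: forward passes call the private `self._gradient_checkpointing_func(...)` instead of the old `gradient_checkpointing_func` attribute, and the per-model `_set_gradient_checkpointing` overrides are dropped in favour of the base-class traversal; `AlignVisionModel`, whose submodules never define a `gradient_checkpointing` flag, now opts out explicitly with `supports_gradient_checkpointing = False`. The public API is unchanged. Below is a usage sketch, assuming `transformers` at this commit plus the standard `bert-base-uncased` checkpoint; the token ids are arbitrary.

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")
model.train()

# Internally this now routes through the single PreTrainedModel._set_gradient_checkpointing
# shown above; the extra kwargs are forwarded to torch.utils.checkpoint.checkpoint.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
assert model.is_gradient_checkpointing

outputs = model(input_ids=torch.tensor([[101, 7592, 2088, 102]]))
outputs.last_hidden_state.sum().backward()

model.gradient_checkpointing_disable()
```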
@@ -646,7 +646,7 @@ class AltRobertaEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -955,7 +955,7 @@ class AltCLIPEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
@@ -1078,14 +1078,6 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, AltCLIPEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-        if isinstance(module, AltRobertaEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
 class AltCLIPVisionTransformer(nn.Module):
@@ -336,7 +336,7 @@ class ASTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
@@ -388,12 +388,6 @@ class ASTPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    # Copied from transformers.models.vit.modeling_vit.ViTPreTrainedModel._set_gradient_checkpointing with ViT->AST
-    def _set_gradient_checkpointing(self, module: ASTEncoder, gradient_checkpointing_func=None) -> None:
-        if isinstance(module, ASTEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@@ -946,11 +946,6 @@ class AutoformerPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (AutoformerDecoder, AutoformerEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 AUTOFORMER_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -1208,7 +1203,7 @@ class AutoformerEncoder(AutoformerPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -1420,7 +1415,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
@@ -317,11 +317,6 @@ class BarkPreTrainedModel(PreTrainedModel):
 
         return get_parameter_device(self)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BARK_MODEL_START_DOCSTRING = """
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
@@ -642,7 +637,7 @@ class BarkCausalModel(BarkPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-                outputs = self.gradient_checkpointing_func(
+                outputs = self._gradient_checkpointing_func(
                     block.__call__,
                     hidden_states,
                     None,
@@ -521,11 +521,6 @@ class BartPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, (BartDecoder, BartEncoder)):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -855,7 +850,7 @@ class BartEncoder(BartPreTrainedModel):
                     layer_outputs = (None, None)
                 else:
                     if self.gradient_checkpointing and self.training:
-                        layer_outputs = self.gradient_checkpointing_func(
+                        layer_outputs = self._gradient_checkpointing_func(
                             encoder_layer.__call__,
                             hidden_states,
                             attention_mask,
@@ -1105,7 +1100,7 @@ class BartDecoder(BartPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
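The decoder hunks (see the `AutoformerDecoder` change above, and the analogous decoders that follow) keep the existing guard: when gradient checkpointing is active during training, `use_cache=True` is incompatible and is forced to `False` with a warning. A common training-time setup therefore disables the cache up front; a usage sketch, assuming the public `facebook/bart-base` checkpoint:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
model.config.use_cache = False         # avoid the "`use_cache=True` is incompatible..." warning
model.gradient_checkpointing_enable()  # flips gradient_checkpointing on BartEncoder / BartDecoder
model.train()
```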
@@ -510,7 +510,7 @@ class BeitEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
@@ -566,11 +566,6 @@ class BeitPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BeitEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 BEIT_START_DOCSTRING = r"""
     This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
@@ -593,7 +593,7 @@ class BertEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-                layer_outputs = self.gradient_checkpointing_func(
+                layer_outputs = self._gradient_checkpointing_func(
                     layer_module.__call__,
                     hidden_states,
                     attention_mask,
@@ -757,11 +757,6 @@ class BertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-        if isinstance(module, BertEncoder):
-            module.gradient_checkpointing_func = gradient_checkpointing_func
-            module.gradient_checkpointing = gradient_checkpointing_func is not None
-
 
 @dataclass
 class BertForPreTrainingOutput(ModelOutput):
@ -401,7 +401,7 @@ class BertEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -602,11 +602,6 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, BertEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
BERT_GENERATION_START_DOCSTRING = r"""
|
||||
|
||||
|
@ -1617,7 +1617,7 @@ class BigBirdEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -1779,11 +1779,6 @@ class BigBirdPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, BigBirdEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
BIG_BIRD_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
|
@ -1609,11 +1609,6 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
pad_token = self.config.pad_token_id
|
||||
@ -1944,7 +1939,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -2284,7 +2279,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -376,11 +376,6 @@ class BioGptPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, BioGptModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
BIOGPT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
@ -591,7 +586,7 @@ class BioGptModel(BioGptPreTrainedModel):
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -660,7 +660,6 @@ class BitPreTrainedModel(PreTrainedModel):
|
||||
config_class = BitConfig
|
||||
base_model_prefix = "bit"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Conv2d):
|
||||
@ -669,11 +668,6 @@ class BitPreTrainedModel(PreTrainedModel):
|
||||
nn.init.constant_(module.weight, 1)
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, BitModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
BIT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -483,11 +483,6 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
pad_token = self.config.pad_token_id
|
||||
@ -778,7 +773,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -1027,7 +1022,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -480,11 +480,6 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
pad_token = self.config.pad_token_id
|
||||
@ -776,7 +771,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -1024,7 +1019,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -34,7 +34,7 @@ from ...utils import (
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
|
||||
from .modeling_blip_text import BlipTextEncoder, BlipTextLMHeadModel, BlipTextModel
|
||||
from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -461,11 +461,6 @@ class BlipPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (BlipEncoder, BlipTextEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
BLIP_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -623,7 +618,7 @@ class BlipEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -422,7 +422,7 @@ class BlipTextEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -297,15 +297,6 @@ class Blip2PreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (Blip2Encoder, Blip2QFormerEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
# Enable / disable GC for the language model as well
|
||||
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
|
||||
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
|
||||
|
||||
|
||||
BLIP_2_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -478,7 +469,7 @@ class Blip2Encoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -943,7 +934,7 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -496,11 +496,6 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, BloomModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_standard_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
|
||||
@ -762,7 +757,7 @@ class BloomModel(BloomPreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self.gradient_checkpointing_func(
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
alibi,
|
||||
|
@ -804,7 +804,7 @@ class BridgeTowerTextEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -651,7 +651,7 @@ class BrosEncoder(nn.Module):
|
||||
"`use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
bbox_pos_emb,
|
||||
|
@ -524,7 +524,7 @@ class CamembertEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -620,11 +620,6 @@ class CamembertPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, CamembertEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CAMEMBERT_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
|
@ -795,7 +795,7 @@ class CanineEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -913,11 +913,6 @@ class CaninePreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, CanineEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CANINE_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
|
@ -742,11 +742,6 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CHINESE_CLIP_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
@ -910,7 +905,7 @@ class ChineseCLIPTextEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -1014,7 +1009,7 @@ class ChineseCLIPVisionEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
output_attentions,
|
||||
|
@ -939,7 +939,7 @@ class ClapAudioEncoder(nn.Module):
|
||||
input_dimensions = self.input_resolutions[i]
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
|
||||
)
|
||||
else:
|
||||
@ -1588,7 +1588,7 @@ class ClapTextEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -1689,11 +1689,6 @@ class ClapPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ClapTextEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
class ClapAudioModel(ClapPreTrainedModel):
|
||||
config_class = ClapAudioConfig
|
||||
|
@ -467,11 +467,6 @@ class CLIPPreTrainedModel(PreTrainedModel):
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, CLIPEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CLIP_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -640,7 +635,7 @@ class CLIPEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -479,11 +479,6 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, CLIPSegEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CLIPSEG_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
@ -649,7 +644,7 @@ class CLIPSegEncoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -339,11 +339,6 @@ class CodeGenPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, CodeGenModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CODEGEN_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
@ -543,7 +538,7 @@ class CodeGenModel(CodeGenPreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self.gradient_checkpointing_func(
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
|
@ -1171,11 +1171,6 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ConditionalDetrDecoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CONDITIONAL_DETR_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -1519,7 +1514,7 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
|
||||
# apply transformation
|
||||
query_sine_embed = query_sine_embed_before_transformation * pos_transformation
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
combined_attention_mask,
|
||||
|
@ -264,11 +264,6 @@ class ConvBertPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ConvBertEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
class SeparableConv1D(nn.Module):
|
||||
"""This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
|
||||
@ -633,7 +628,7 @@ class ConvBertEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -282,7 +282,6 @@ class ConvNextPreTrainedModel(PreTrainedModel):
|
||||
config_class = ConvNextConfig
|
||||
base_model_prefix = "convnext"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -296,11 +295,6 @@ class ConvNextPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ConvNextEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CONVNEXT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -303,7 +303,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
|
||||
config_class = ConvNextV2Config
|
||||
base_model_prefix = "convnextv2"
|
||||
main_input_name = "pixel_values"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -317,11 +316,6 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, ConvNextV2Encoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CONVNEXTV2_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -536,7 +536,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = CpmAntConfig
|
||||
base_model_prefix = "cpmant"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
@ -556,11 +555,6 @@ class CpmAntPreTrainedModel(PreTrainedModel):
|
||||
elif isinstance(module, CpmAntSegmentPositionEmbedding):
|
||||
module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, CpmAntEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
CPMANT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
|
@ -293,7 +293,7 @@ class Data2VecAudioFeatureEncoder(nn.Module):
|
||||
|
||||
for conv_layer in self.conv_layers:
|
||||
if self._requires_grad and self.gradient_checkpointing and self.training:
|
||||
hidden_states = self.gradient_checkpointing_func(
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
conv_layer.__call__,
|
||||
hidden_states,
|
||||
)
|
||||
@ -586,7 +586,7 @@ class Data2VecAudioEncoder(nn.Module):
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -748,11 +748,6 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
return attention_mask
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DATA2VEC_AUDIO_START_DOCSTRING = r"""
|
||||
Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
|
||||
|
@ -510,7 +510,7 @@ class Data2VecTextEncoder(nn.Module):
|
||||
past_key_value = past_key_values[i] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
@ -608,11 +608,6 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
|
||||
if hasattr(module, "weight") and module.weight is not None:
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, Data2VecTextEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DATA2VECTEXT_START_DOCSTRING = r"""
|
||||
Data2VecText was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
|
||||
|
@ -522,7 +522,7 @@ class Data2VecVisionEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
@ -579,11 +579,6 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, Data2VecVisionEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DATA2VEC_VISION_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -457,7 +457,7 @@ class DebertaEncoder(nn.Module):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
hidden_states = self.gradient_checkpointing_func(
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
next_kv,
|
||||
attention_mask,
|
||||
@ -833,11 +833,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, DebertaEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DEBERTA_START_DOCSTRING = r"""
|
||||
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
|
||||
|
@ -501,7 +501,7 @@ class DebertaV2Encoder(nn.Module):
|
||||
all_hidden_states = all_hidden_states + (output_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
output_states = self.gradient_checkpointing_func(
|
||||
output_states = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
next_kv,
|
||||
attention_mask,
|
||||
@ -932,11 +932,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, DebertaV2Encoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DEBERTA_START_DOCSTRING = r"""
|
||||
The DeBERTa model was proposed in [DeBERTa: Decoding-enhanced BERT with Disentangled
|
||||
|
@ -469,11 +469,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, DecisionTransformerGPT2Model):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
@ -632,7 +627,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self.gradient_checkpointing_func(
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
None,
|
||||
|
@ -1088,11 +1088,6 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
|
||||
if hasattr(module, "level_embed"):
|
||||
nn.init.normal_(module.level_embed)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, DeformableDetrDecoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DEFORMABLE_DETR_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -1384,7 +1379,7 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
|
@ -357,7 +357,7 @@ class DeiTEncoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
@ -409,11 +409,6 @@ class DeiTPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None:
|
||||
if isinstance(module, DeiTEncoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DEIT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -504,11 +504,6 @@ class MCTCTPreTrainedModel(PreTrainedModel):
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
|
||||
return attention_mask
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, (MCTCTEncoder)):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
MCTCT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
@ -617,7 +612,7 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
|
||||
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
||||
# under deepspeed zero3 all gpus must run in sync
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
encoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -456,11 +456,6 @@ class OpenLlamaPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, OpenLlamaModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
OPEN_LLAMA_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
@ -666,7 +661,7 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
|
@ -163,11 +163,6 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
|
||||
main_input_name = "trajectories"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, TrajectoryTransformerModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||
@ -551,7 +546,7 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
outputs = self.gradient_checkpointing_func(
|
||||
outputs = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
hidden_states,
|
||||
layer_past,
|
||||
|
@ -387,11 +387,6 @@ class VanPreTrainedModel(PreTrainedModel):
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, VanModel):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
VAN_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -979,11 +979,6 @@ class DetaPreTrainedModel(PreTrainedModel):
|
||||
if hasattr(module, "level_embed"):
|
||||
nn.init.normal_(module.level_embed)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, DetaDecoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DETA_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -1276,7 +1271,7 @@ class DetaDecoder(DetaPreTrainedModel):
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
|
@ -927,11 +927,6 @@ class DetrPreTrainedModel(PreTrainedModel):
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, DetrDecoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DETR_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
@ -1254,7 +1249,7 @@ class DetrDecoder(DetrPreTrainedModel):
|
||||
continue
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
combined_attention_mask,
|
||||
|
@ -660,9 +660,6 @@ class DinatPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
DINAT_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
|
||||
|
@ -447,7 +447,7 @@ class Dinov2Encoder(nn.Module):
|
||||
layer_head_mask = head_mask[i] if head_mask is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_states,
|
||||
layer_head_mask,
|
||||
@ -510,11 +510,6 @@ class Dinov2PreTrainedModel(PreTrainedModel):
|
||||
std=self.config.initializer_range,
|
||||
).to(module.cls_token.dtype)
|
||||
|
||||
def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None:
|
||||
if isinstance(module, Dinov2Encoder):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DINOV2_START_DOCSTRING = r"""
|
||||
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
|
||||
|
@ -358,7 +358,7 @@ class Transformer(nn.Module):
|
||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self.gradient_checkpointing_func(
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
layer_module.__call__,
|
||||
hidden_state,
|
||||
attn_mask,
|
||||
@ -424,11 +424,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
|
||||
if isinstance(module, Transformer):
|
||||
module.gradient_checkpointing_func = gradient_checkpointing_func
|
||||
module.gradient_checkpointing = gradient_checkpointing_func is not None
|
||||
|
||||
|
||||
DISTILBERT_START_DOCSTRING = r"""
|
||||
|
||||
|
@ -749,7 +749,7 @@ class DonutSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:

@ -819,11 +819,6 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DonutSwinEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

SWIN_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

@ -30,7 +30,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ..bert.modeling_bert import BertEncoder, BertModel
from ..bert.modeling_bert import BertModel
from .configuration_dpr import DPRConfig

@ -164,11 +164,6 @@ class DPRPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

class DPREncoder(DPRPreTrainedModel):
base_model_prefix = "bert_model"

@ -528,7 +528,7 @@ class DPTViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,

@ -812,11 +812,6 @@ class DPTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DPTViTEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

DPT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

@ -486,7 +486,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
config_class = EfficientNetConfig
base_model_prefix = "efficientnet"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""

@ -500,11 +499,6 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EfficientNetBlock):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@add_start_docstrings(
"The bare EfficientNet model outputting raw features without any specific head on top.",

@ -571,7 +571,7 @@ class ElectraEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -687,11 +687,6 @@ class ElectraPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ElectraEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@dataclass
class ElectraForPreTrainingOutput(ModelOutput):

@ -446,7 +446,6 @@ class EncodecPreTrainedModel(PreTrainedModel):
config_class = EncodecConfig
base_model_prefix = "encodec"
main_input_name = "input_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""

@ -473,11 +472,6 @@ class EncodecPreTrainedModel(PreTrainedModel):
elif "bias" in name:
nn.init.constant_(param, 0.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (EncodecEncoder, EncodecDecoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

ENCODEC_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

@ -265,11 +265,6 @@ class EncoderDecoderModel(PreTrainedModel):
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
# call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)

def get_encoder(self):
return self.encoder
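The `EncoderDecoderModel` override removed above illustrates why the per-model hooks could go: a setter that walks `self.modules()` reaches the wrapped encoder and decoder stacks on its own, since PyTorch module traversal recurses into submodules. A small demonstration with plain `torch.nn` modules (names invented for the example):

```python
import torch.nn as nn


class InnerStack(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False  # opt-in flag, as in the encoder stacks above


class Composite(nn.Module):
    """Rough analogue of a wrapper holding an encoder and a decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = InnerStack()
        self.decoder = InnerStack()


model = Composite()
flagged = [m for m in model.modules() if hasattr(m, "gradient_checkpointing")]
print(len(flagged))  # 2 -> both nested stacks are visited, no per-wrapper delegation needed
```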
@ -506,7 +506,7 @@ class ErnieEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -675,11 +675,6 @@ class ErniePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->Ernie

@ -411,7 +411,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):

config_class = ErnieMConfig
base_model_prefix = "ernie_m"
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""

@ -429,11 +428,6 @@ class ErnieMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieMEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

ERNIE_M_START_DOCSTRING = r"""

@ -605,7 +605,7 @@ class EsmEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -705,11 +705,6 @@ class EsmPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EsmEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

ESM_START_DOCSTRING = r"""

@ -1101,12 +1101,6 @@ class FalconPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->FalconModel
def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
if isinstance(module, FalconModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@staticmethod
def _convert_cache_to_standard_format(
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int

@ -1282,7 +1276,7 @@ class FalconModel(FalconPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,

@ -663,7 +663,7 @@ class FlavaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -873,11 +873,6 @@ class FlavaPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module: FlavaEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, FlavaEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@add_start_docstrings(
"The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",

@ -292,7 +292,7 @@ class FNetEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(layer_module.__call__, hidden_states)
layer_outputs = self._gradient_checkpointing_func(layer_module.__call__, hidden_states)
else:
layer_outputs = layer_module(hidden_states)

@ -424,11 +424,6 @@ class FNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FNetEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@dataclass
class FNetForPreTrainingOutput(ModelOutput):

@ -586,7 +586,7 @@ class FocalNetEncoder(nn.Module):

for i, stage_module in enumerate(self.stages):
if self.gradient_checkpointing and self.training:
stage_outputs = self.gradient_checkpointing_func(
stage_outputs = self._gradient_checkpointing_func(
stage_module.__call__,
hidden_states,
input_dimensions,

@ -652,11 +652,6 @@ class FocalNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FocalNetEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

FOCALNET_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

@ -70,11 +70,6 @@ class FuyuPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, FuyuForCausalLM):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

FUYU_INPUTS_DOCSTRING = r"""
Args:

@ -452,7 +452,7 @@ class GitEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -528,11 +528,6 @@ class GitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GitEncoder, GitVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

GIT_START_DOCSTRING = r"""

@ -874,7 +869,7 @@ class GitVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -480,11 +480,6 @@ class GPT2PreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPT2Model):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):

@ -878,7 +873,7 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
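The decoder-only call sites (GPT-2 above, and the GPT-BigCode/GPT-Neo/GPT-J blocks below) hand the block plus its full positional argument list, `None` placeholders included, straight to the checkpointing function, which replays those arguments during recomputation. A minimal sketch with an invented two-argument block, using `torch.utils.checkpoint.checkpoint` directly:

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(hidden_states, attention_mask=None):
    # Invented block: apply an optional mask, then a nonlinearity.
    if attention_mask is not None:
        hidden_states = hidden_states * attention_mask
    return torch.tanh(hidden_states)


hidden_states = torch.randn(2, 8, requires_grad=True)
output = checkpoint(block, hidden_states, None, use_reentrant=False)  # None is forwarded like any other argument
output.sum().backward()
```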
@ -404,12 +404,6 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

GPT_BIGCODE_START_DOCSTRING = r"""

@ -651,7 +645,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

@ -384,11 +384,6 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

GPT_NEO_START_DOCSTRING = r"""

@ -605,7 +600,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

@ -78,11 +78,6 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

class GPTNeoXAttention(nn.Module):
def __init__(self, config):

@ -642,7 +637,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,

@ -48,7 +48,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):

config_class = GPTNeoXJapaneseConfig
base_model_prefix = "gpt_neox_japanese"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXJapaneseLayer"]
_skip_keys_device_placement = "past_key_values"

@ -66,11 +65,6 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTNeoXJapaneseModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

class GPTNeoXJapaneseAttention(nn.Module):
def __init__(self, config, use_bias=False):

@ -363,11 +363,6 @@ class GPTJPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GPTJModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

GPTJ_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use

@ -670,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

@ -759,11 +759,6 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GPTSanJapaneseAttention,)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id

@ -712,7 +712,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):

config_class = GraphormerConfig
base_model_prefix = "graphormer"
supports_gradient_checkpointing = True
main_input_name_nodes = "input_nodes"
main_input_name_edges = "input_edges"

@ -772,11 +771,6 @@ class GraphormerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, GraphormerModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

class GraphormerModel(GraphormerPreTrainedModel):
"""The Graphormer model is a graph-encoder model.

@ -804,11 +804,6 @@ class GroupViTPreTrainedModel(PreTrainedModel):
nn.init.normal_(module.fc1.weight, std=fc_std)
nn.init.normal_(module.fc2.weight, std=in_proj_std)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (GroupViTTextEncoder, GroupViTVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

GROUPVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

@ -1031,7 +1026,7 @@ class GroupViTTextEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -346,7 +346,7 @@ class HubertFeatureEncoder(nn.Module):

for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
hidden_states = self._gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)

@ -724,7 +724,7 @@ class HubertEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,

@ -808,7 +808,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,

@ -876,11 +876,6 @@ class HubertPreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (HubertFeatureEncoder, HubertEncoder, HubertEncoderStableLayerNorm)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
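As the Hubert hunks show, checkpointing here is not limited to transformer layers: the convolutional feature encoder wraps each conv block the same way, trading stored conv activations for recomputation. A standalone sketch of per-layer conv checkpointing with toy shapes (not the Hubert modules):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

conv_layers = nn.ModuleList([nn.Conv1d(1, 4, kernel_size=3), nn.Conv1d(4, 4, kernel_size=3)])
hidden_states = torch.randn(2, 1, 64, requires_grad=True)

for conv_layer in conv_layers:
    # Each conv block is checkpointed separately, mirroring the feature-encoder loop above.
    hidden_states = checkpoint(conv_layer.__call__, hidden_states, use_reentrant=False)

hidden_states.sum().backward()
```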
@ -40,7 +40,7 @@ from ...utils import (
)
from .configuration_idefics import IdeficsConfig
from .perceiver import IdeficsPerceiverResampler
from .vision import IdeficsVisionEncoder, IdeficsVisionTransformer
from .vision import IdeficsVisionTransformer

logger = logging.get_logger(__name__)

@ -978,11 +978,6 @@ class IdeficsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (IdeficsModel, IdeficsVisionEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LLAMA_INPUTS_DOCSTRING = r"""
Args:

@ -1339,7 +1334,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
)
use_cache = False

layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
vblock,
decoder_layer,
hidden_states,

@ -401,7 +401,7 @@ class IdeficsVisionEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -525,11 +525,6 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ImageGPTModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

IMAGEGPT_START_DOCSTRING = r"""

@ -817,7 +812,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,

@ -924,11 +924,6 @@ class InformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InformerDecoder, InformerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

INFORMER_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

@ -1216,7 +1211,7 @@ class InformerEncoder(InformerPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -1224,7 +1219,7 @@ class InformerEncoder(InformerPreTrainedModel):
output_attentions,
)
if conv_layer is not None:
output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0])
output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
else:
layer_outputs = encoder_layer(

@ -1433,7 +1428,7 @@ class InformerDecoder(InformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

@ -304,15 +304,6 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)

INSTRUCTBLIP_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

@ -467,7 +458,7 @@ class InstructBlipEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -938,7 +929,7 @@ class InstructBlipQFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
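The QFormer hunk keeps the existing guard around caching: reusing past key/values does not mix with recomputing layer activations, so checkpointed training logs a warning and switches the cache off. The same guard, reduced to a few lines outside any model (illustrative only):

```python
import logging

logger = logging.getLogger(__name__)

use_cache, gradient_checkpointing, training = True, True, True
if gradient_checkpointing and training and use_cache:
    logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
    use_cache = False
```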
@ -487,7 +487,7 @@ class LayoutLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -633,11 +633,6 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LAYOUTLM_START_DOCSTRING = r"""
The LayoutLM model was proposed in [LayoutLM: Pre-training of Text and Layout for Document Image

@ -439,7 +439,7 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -508,11 +508,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMv2Encoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

def my_convert_sync_batchnorm(module, process_group=None):
# same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`

@ -657,7 +657,7 @@ class LayoutLMv3Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -1155,11 +1155,6 @@ class LEDPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LEDDecoder, LEDEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id

@ -1877,7 +1872,7 @@ class LEDEncoder(LEDPreTrainedModel):
layer_outputs = (None, None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -2138,7 +2133,7 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

@ -493,7 +493,6 @@ class LevitPreTrainedModel(PreTrainedModel):
config_class = LevitConfig
base_model_prefix = "levit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""

@ -507,11 +506,6 @@ class LevitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LevitModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LEVIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it

@ -514,7 +514,7 @@ class LiltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layout_inputs,

@ -601,11 +601,6 @@ class LiltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LiltEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LILT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

@ -851,11 +851,6 @@ class LlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LlamaModel):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LLAMA_INPUTS_DOCSTRING = r"""
Args:

@ -1036,7 +1031,7 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

@ -1304,7 +1304,7 @@ class LongformerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -1434,11 +1434,6 @@ class LongformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LongformerEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LONGFORMER_START_DOCSTRING = r"""

@ -23,7 +23,6 @@ from typing import Any, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint

from ...activations import ACT2FN
from ...modeling_outputs import (

@ -1339,11 +1338,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
mean=0.0, std=factor * ((d_model) ** -0.5)
)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id

@ -1391,11 +1385,11 @@ class LongT5Stack(LongT5PreTrainedModel):
self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)

self.gradient_checkpointing = False

# Initialize weights and apply final processing
self.post_init()

self.gradient_checkpointing = False

# Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
def get_input_embeddings(self):
return self.embed_tokens

@ -1509,7 +1503,7 @@ class LongT5Stack(LongT5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
layer_outputs = checkpoint(
layer_outputs = self._gradient_checkpointing_func(
layer_module.forward,
hidden_states,
extended_attention_mask,
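LongT5 was the one stack here that still imported `checkpoint` directly; its forward now goes through the injected `_gradient_checkpointing_func` like every other model. Presumably the options passed to `gradient_checkpointing_enable` are bound once up front (for example via `functools.partial`) rather than being repeated at every call site; a standalone sketch of that binding:

```python
from functools import partial

import torch

# Bind the checkpointing options once; every call site then uses the bound function.
gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)


def layer(hidden_states, extended_attention_mask):
    # Invented layer body standing in for a LongT5 block.
    return torch.relu(hidden_states + extended_attention_mask)


hidden_states = torch.randn(2, 8, requires_grad=True)
extended_attention_mask = torch.zeros(2, 8)
layer_outputs = gradient_checkpointing_func(layer, hidden_states, extended_attention_mask)
layer_outputs.sum().backward()
```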
@ -788,7 +788,7 @@ class LukeEncoder(nn.Module):

layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
word_hidden_states,
entity_hidden_states,

@ -914,11 +914,6 @@ class LukePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LukeEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

LUKE_START_DOCSTRING = r"""

@ -552,11 +552,6 @@ class M2M100PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (M2M100Decoder, M2M100Encoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

M2M_100_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the

@ -821,7 +816,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
# under deepspeed zero3 all gpus must run in sync

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -1061,7 +1056,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

@ -500,11 +500,6 @@ class MarianPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MarianDecoder, MarianEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id

@ -789,7 +784,7 @@ class MarianEncoder(MarianPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -1032,7 +1027,7 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

@ -648,7 +648,7 @@ class MarkupLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -1864,7 +1864,7 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
continue

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

@ -848,7 +848,7 @@ class DetrDecoder(nn.Module):
continue

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,

@ -1613,14 +1613,6 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerPixelLevelModule):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, DetrDecoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@add_start_docstrings(
"The bare MaskFormer Model outputting raw hidden-states without any specific head on top.",

@ -688,7 +688,7 @@ class MaskFormerSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None

if self.gradient_checkpointing and self.training:
layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func(
layer_hidden_states, output_dimensions, layer_all_hidden_states = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,

@ -748,11 +748,6 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
def __init__(self, config, add_pooling_layer=True):

@ -516,11 +516,6 @@ class MBartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MBartDecoder, MBartEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id

@ -829,7 +824,7 @@ class MBartEncoder(MBartPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,

@ -1081,7 +1076,7 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,

@ -551,7 +551,7 @@ class MegatronBertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
layer_outputs = self._gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,

@ -723,11 +723,6 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MegatronBertEncoder):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None

@dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert
Some files were not shown because too many files have changed in this diff.