mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[core
/ gradient_checkpointing
] add support for old GC method (#27610)
* add support for old GC method * add also disable * up * oops
This commit is contained in:
parent
8eb9e29d8d
commit
0e6794ff1c
@ -1876,7 +1876,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
|
||||
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
# For old GC format (transformers < 4.35.0) for models that live on the Hub
|
||||
# we will fall back to the overwritten `_set_gradient_checkpointing` methid
|
||||
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
|
||||
|
||||
if not _is_using_old_format:
|
||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
|
||||
else:
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||
logger.warn(
|
||||
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
|
||||
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
|
||||
)
|
||||
|
||||
if getattr(self, "_hf_peft_config_loaded", False):
|
||||
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
|
||||
@ -1915,7 +1926,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
activations".
|
||||
"""
|
||||
if self.supports_gradient_checkpointing:
|
||||
self._set_gradient_checkpointing(enable=False)
|
||||
# For old GC format (transformers < 4.35.0) for models that live on the Hub
|
||||
# we will fall back to the overwritten `_set_gradient_checkpointing` methid
|
||||
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
|
||||
if not _is_using_old_format:
|
||||
self._set_gradient_checkpointing(enable=False)
|
||||
else:
|
||||
logger.warn(
|
||||
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
|
||||
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
|
||||
)
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||
|
||||
if getattr(self, "_hf_peft_config_loaded", False):
|
||||
self.disable_input_require_grads()
|
||||
|
Loading…
Reference in New Issue
Block a user