[core / gradient_checkpointing] add support for old GC method (#27610)

* add support for old GC method

* add also disable

* up

* oops
This commit is contained in:
Younes Belkada 2023-11-21 11:03:30 +01:00 committed by GitHub
parent 8eb9e29d8d
commit 0e6794ff1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1876,7 +1876,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warn(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
@ -1915,7 +1926,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
activations".
"""
if self.supports_gradient_checkpointing:
self._set_gradient_checkpointing(enable=False)
# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=False)
else:
logger.warn(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
self.apply(partial(self._set_gradient_checkpointing, value=False))
if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()