Enable requires_grad on input embedding to train on top of frozen layers (#21598)

* v1

* make fixup

* add more methods
Younes Belkada, 2023-02-14 09:43:06 +01:00, committed by GitHub
parent 8c5026628a
commit 41fa672df1


@@ -1148,6 +1148,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             return False
         return True
 
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+        the model weights fixed.
+        """
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
+    def disable_input_require_grads(self):
+        """
+        Removes the `_require_grads_hook`.
+        """
+        self._require_grads_hook.remove()
+
     def get_input_embeddings(self) -> nn.Module:
         """
         Returns the model's input embeddings.
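
For context beyond the diff, here is a minimal usage sketch of the two new methods. The checkpoint name (facebook/opt-350m) and the frozen-base-model setup are illustrative assumptions, not part of the commit; enable_input_require_grads and disable_input_require_grads are the methods added above.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any PreTrainedModel behaves the same way.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Freeze every base weight, e.g. to train only adapter layers on top.
for param in model.parameters():
    param.requires_grad = False

# With everything frozen, the embedding output has requires_grad=False, so
# no gradient flows to trainable modules inserted downstream (this also
# breaks gradient checkpointing). The forward hook registered by this
# method flips requires_grad back on for the embedding output.
model.enable_input_require_grads()

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.requires_grad)  # True, even though all weights are frozen

# Remove the hook once it is no longer needed.
model.disable_input_require_grads()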