Mirror of https://github.com/huggingface/transformers.git
Enable requires_grad on input embedding to train on top of frozen layers (#21598)

* v1
* make fixup
* add more methods
parent 8c5026628a
commit 41fa672df1
@@ -1148,6 +1148,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             return False
         return True
 
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+        the model weights fixed.
+        """
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
+    def disable_input_require_grads(self):
+        """
+        Removes the `_require_grads_hook`.
+        """
+        self._require_grads_hook.remove()
+
     def get_input_embeddings(self) -> nn.Module:
         """
         Returns the model's input embeddings.
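Not part of the diff: a minimal usage sketch of the new methods, illustrating the frozen-backbone scenario the docstring describes. The checkpoint name "gpt2", the use of AutoModelForCausalLM, and the token ids below are placeholder assumptions for the example, not taken from the commit.

# Minimal usage sketch (not part of this commit); "gpt2" and the token ids
# are arbitrary placeholders.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Freeze every pretrained parameter, as in the adapter-tuning setup the
# docstring mentions.
for param in model.parameters():
    param.requires_grad = False

# Register the forward hook added in this commit: the input-embedding output
# now has requires_grad=True even though the embedding weights stay frozen,
# so a backward graph is still built through the frozen layers.
model.enable_input_require_grads()

input_ids = torch.tensor([[464, 2159, 318, 257]])
loss = model(input_ids, labels=input_ids).loss
loss.backward()  # succeeds: the graph starts from the hooked embedding output

# Remove the hook again once it is no longer needed.
model.disable_input_require_grads()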