Fix resize_token_embeddings (#26861) (#26865)

* Fix `resize_token_embeddings` handling of `requires_grad`

The method `resize_token_embeddings` should keep `requires_grad`
unchanged for all parameters in the embeddings.

Previously, `resize_token_embeddings` always set `requires_grad`
to `True`. With this fix, `resize_token_embeddings` copies the
`requires_grad` attribute from the old embeddings.
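
To illustrate the intended behavior, here is a minimal sketch (the
checkpoint name is illustrative; any model with resizable input
embeddings behaves the same):

    from transformers import AutoModelForMaskedLM

    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

    # Freeze the input embeddings before resizing.
    model.get_input_embeddings().requires_grad_(False)

    # Grow the vocabulary, e.g. after adding new special tokens.
    model.resize_token_embeddings(model.config.vocab_size + 2)

    # With this fix the resized embeddings stay frozen; previously
    # `requires_grad` was silently reset to True here.
    assert model.get_input_embeddings().weight.requires_grad is False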
Author: Ziyu Chen
Date:   2023-11-22 01:51:48 +08:00
Parent: d2a980ec74
Commit: c5be38cd27

@@ -1586,6 +1586,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if hasattr(old_embeddings, "_hf_hook"):
             hook = old_embeddings._hf_hook
             add_hook_to_module(new_embeddings, hook)
+        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
+        new_embeddings.requires_grad_(old_embeddings_requires_grad)
         self.set_input_embeddings(new_embeddings)

         # Update new_num_tokens with the actual size of new_embeddings
@@ -1605,6 +1607,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
+            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
+            new_lm_head.requires_grad_(old_lm_head_requires_grad)
             self.set_output_embeddings(new_lm_head)

         return self.get_input_embeddings()
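
For reference, the pattern both hunks apply, sketched standalone in
plain PyTorch (the module names here are illustrative, not taken from
the patch):

    import torch.nn as nn

    old_embeddings = nn.Embedding(10, 4)
    old_embeddings.weight.requires_grad_(False)  # e.g. frozen by the user

    new_embeddings = nn.Embedding(12, 4)  # the resized replacement

    # Copy the old flag instead of leaving the default `True`.
    new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
    assert new_embeddings.weight.requires_grad is False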