Mirror of https://github.com/huggingface/transformers.git
[resize_embedding] Introduce pad_to_multiple_of and guidance (#25088)
* fix
* revert changes and update resizing of embedding layer
* use warning
* fixup
* more styling nits
* fix all tests that overload the embedding tests
* 👀👀 remove breakpoint
* remove useless overload + overload correctly where needed
* resize lm head with new vocab size
* revert unnecessary changes
* style
* fix CIs!
* fix last CI tests, adapt bark and Marian
* fixup
Commit: d6bf08f7f6 (parent: d2871b2975)
@@ -1382,7 +1382,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
             output_embeddings.out_features = input_embeddings.num_embeddings
 
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+    ) -> nn.Embedding:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
 
@@ -1393,11 +1395,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                 vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                 returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set, will pad the embedding matrix to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
 
         Return:
             `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
         """
-        model_embeds = self._resize_token_embeddings(new_num_tokens)
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
         if new_num_tokens is None:
             return model_embeds
 
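As a quick usage sketch of the new argument (hypothetical added tokens, checkpoint chosen purely for illustration, assuming a standard transformers setup), the padding is requested directly through resize_token_embeddings:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model/tokenizer pair works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Add a couple of domain-specific tokens, then resize the embedding matrix.
tokenizer.add_tokens(["<mol>", "</mol>"])

# Padding the vocab dimension to a multiple of 64 keeps it Tensor Core friendly;
# the returned module reflects the padded size, not just len(tokenizer).
embeddings = model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
print(embeddings.weight.shape[0])  # 50304: GPT-2's 50257 tokens + 2 new ones, rounded up to a multiple of 64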
@@ -1410,21 +1419,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         return model_embeds
 
-    def _resize_token_embeddings(self, new_num_tokens):
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         old_embeddings = self.get_input_embeddings()
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
         self.set_input_embeddings(new_embeddings)
 
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head = self.get_output_embeddings()
-            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
+            new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
             self.set_output_embeddings(new_lm_head)
 
         return self.get_input_embeddings()
 
     def _get_resized_embeddings(
-        self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
+        self,
+        old_embeddings: nn.Embedding,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
     ) -> nn.Embedding:
         """
         Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
@@ -1439,11 +1451,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                 vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                 `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set, will pad the embedding matrix to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+
 
         Return:
             `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
             `new_num_tokens` is `None`
         """
+
+        if pad_to_multiple_of is not None:
+            if not isinstance(pad_to_multiple_of, int):
+                raise ValueError(
+                    f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not an integer. Please make sure to pass an integer"
+                )
+            if new_num_tokens is None:
+                new_num_tokens = old_embeddings.weight.shape[0]
+            new_num_tokens = ((new_num_tokens // pad_to_multiple_of) + 1) * pad_to_multiple_of
+        else:
+            logger.warning(
+                "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
+                f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
+                " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
+                " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
+            )
+
         if new_num_tokens is None:
             return old_embeddings
 
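Note that the rounding step above always moves to the next multiple, even when the requested size is already aligned; a minimal sketch of the arithmetic (illustrative numbers only):

# Mirrors the padding formula used in _get_resized_embeddings.
def padded_vocab_size(new_num_tokens: int, pad_to_multiple_of: int) -> int:
    return ((new_num_tokens // pad_to_multiple_of) + 1) * pad_to_multiple_of

print(padded_vocab_size(50_259, 64))  # 50304 -> rounded up to the next multiple of 64
print(padded_vocab_size(50_304, 64))  # 50368 -> an exact multiple is still bumped by one full block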
@@ -1077,18 +1077,25 @@ class BarkFineModel(BarkPreTrainedModel):
         # one lm_head for each codebook
         self.lm_heads = new_output_embeddings
 
-    def _resize_token_embeddings(self, new_num_tokens):
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         old_embeddings_list = self.get_input_embeddings()
         new_embeddings_list = nn.ModuleList(
-            [self._get_resized_embeddings(old_embeddings, new_num_tokens) for old_embeddings in old_embeddings_list]
+            [
+                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+                for old_embeddings in old_embeddings_list
+            ]
         )
         self.set_input_embeddings(new_embeddings_list)
+        new_num_tokens = [embed.weight.shape[0] for embed in new_embeddings_list]
 
         # if word embeddings are not tied, make sure that lm head is resized as well
         if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
             old_lm_head_list = self.get_output_embeddings()
             new_lm_head_list = nn.ModuleList(
-                [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
+                [
+                    self._get_resized_lm_head(old_lm_head, new_num_token)
+                    for old_lm_head, new_num_token in zip(old_lm_head_list, new_num_tokens)
+                ]
             )
             self.set_output_embeddings(new_lm_head_list)
 
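The per-codebook pattern above means each embedding can come back with its own padded size, so each lm head is rebuilt to match that realized size rather than the originally requested new_num_tokens. A standalone sketch of the invariant this preserves (plain torch.nn modules, hypothetical sizes, not Bark's actual dimensions):

import torch.nn as nn

# Hypothetical per-codebook embeddings as they might look after a padded resize.
new_embeddings_list = nn.ModuleList([nn.Embedding(10_048, 512), nn.Embedding(10_048, 512)])
new_num_tokens = [embed.weight.shape[0] for embed in new_embeddings_list]

# Rebuild one head per codebook with the realized vocab size (illustration only).
new_lm_head_list = nn.ModuleList([nn.Linear(512, tokens, bias=False) for tokens in new_num_tokens])

assert all(
    head.out_features == embed.num_embeddings
    for head, embed in zip(new_lm_head_list, new_embeddings_list)
)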
@@ -1324,9 +1324,9 @@ class BartForConditionalGeneration(BartPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -2508,9 +2508,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1277,9 +1277,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1244,9 +1244,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1313,9 +1313,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
         return self._shift_right(labels)
 
     # Copied from transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration.resize_token_embeddings with MBart->GPTSanJapanese
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     # Copied from transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration._resize_final_logits_bias with MBart->GPTSanJapanese
@@ -2352,9 +2352,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
     def get_decoder(self):
         return self.led.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1267,10 +1267,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        return new_embeddings
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -1316,17 +1316,18 @@ class MarianMTModel(MarianPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
         if self.config.share_encoder_decoder_embeddings:
             self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
 
-    def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
+    def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None) -> nn.Embedding:
         old_embeddings = self.get_input_embeddings()
-        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
         self.set_input_embeddings(new_embeddings)
 
+        new_num_tokens = new_embeddings.weight.shape[0]
         # update config.decoder_vocab_size if embeddings are tied
         if self.config.share_encoder_decoder_embeddings:
             self.config.decoder_vocab_size = new_num_tokens
@@ -1294,9 +1294,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1453,8 +1453,8 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
         self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
 
@@ -1652,10 +1652,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        return new_embeddings
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -1327,9 +1327,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1552,10 +1552,6 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        return new_embeddings
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -1267,9 +1267,9 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        self._resize_final_logits_bias(new_num_tokens)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
     def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
@@ -1282,10 +1282,6 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        return new_embeddings
-
     def get_output_embeddings(self):
         return self.lm_head
 
@@ -2359,10 +2359,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
         """
         self.get_encoder().prenet.freeze_feature_encoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        return new_embeddings
-
     def get_output_embeddings(self):
         return self.text_decoder_postnet.get_output_embeddings()
 
@@ -1410,10 +1410,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens)
-        return new_embeddings
-
     def get_output_embeddings(self):
         return self.proj_out
 
@@ -1413,6 +1413,26 @@ class ModelTesterMixin:
 
         self.assertTrue(models_equal)
 
+        config = copy.deepcopy(original_config)
+        model = model_class(config)
+        model.to(torch_device)
+
+        model_vocab_size = config.vocab_size
+        model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
+        self.assertGreaterEqual(model.config.vocab_size, model_vocab_size + 10)
+
+        model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
+        self.assertEqual(model_embed.weight.shape[0] % 64, 0)
+
+        model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
+        self.assertEqual(model_embed.weight.shape[0] % 64, 0)
+
+        with self.assertRaisesRegex(
+            ValueError,
+            "Asking to pad the embedding matrix to a multiple of `1.3`, which is not an integer. Please make sure to pass an integer",
+        ):
+            model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
+
     def test_resize_embeddings_untied(self):
         (
             original_config,