[resize_embedding] Introduce pad_to_multiple_of and guidance (#25088)

* fix

* revert changes and update resizing of embedding layer

* use warning

* fixup

* more styling nits

* fix all tests that overload the embedding tests

* 👀👀 remove breakpoint

* remove useless overload + overload correctly where needed

* resize lm head with new vocab size

* revert unnecessary changes

* style

* fix CIs!

* fix last CI tests, adapt Bark and Marian

* fixup
Arthur 2023-08-17 17:00:32 +02:00 committed by GitHub
parent d2871b2975
commit d6bf08f7f6
20 changed files with 107 additions and 66 deletions


@@ -1382,7 +1382,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
output_embeddings.out_features = input_embeddings.num_embeddings
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
def resize_token_embeddings(
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
) -> nn.Embedding:
"""
Resizes the input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -1393,11 +1395,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set, pads the embedding matrix to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.0` (Volta), or on TPUs, which benefit from having dimensions that are multiples of 128. For more
details about this, or for help choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
Return:
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
model_embeds = self._resize_token_embeddings(new_num_tokens)
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if new_num_tokens is None:
return model_embeds
@@ -1410,21 +1419,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model_embeds
def _resize_token_embeddings(self, new_num_tokens):
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
self.set_input_embeddings(new_embeddings)
# if word embeddings are not tied, make sure that lm head is resized as well
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
old_lm_head = self.get_output_embeddings()
new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
new_lm_head = self._get_resized_lm_head(old_lm_head, new_embeddings.weight.shape[0])
self.set_output_embeddings(new_lm_head)
return self.get_input_embeddings()
def _get_resized_embeddings(
self, old_embeddings: nn.Embedding, new_num_tokens: Optional[int] = None
self,
old_embeddings: nn.Embedding,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
) -> nn.Embedding:
"""
Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
@@ -1439,11 +1451,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
`torch.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set, pads the embedding matrix to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.0` (Volta), or on TPUs, which benefit from having dimensions that are multiples of 128. For more
details about this, or for help choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
Return:
`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
`new_num_tokens` is `None`
"""
if pad_to_multiple_of is not None:
if not isinstance(pad_to_multiple_of, int):
raise ValueError(
f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
)
if new_num_tokens is None:
new_num_tokens = old_embeddings.weight.shape[0]
new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
else:
logger.warning(
"You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embeding"
f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
" For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
" https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
)
if new_num_tokens is None:
return old_embeddings
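A minimal usage sketch of the new argument (hedged: the "gpt2" checkpoint and the added tokens below are illustrative placeholders, not part of this commit). With the rounding above, a request for 50267 rows padded to a multiple of 64 yields ((50267 + 63) // 64) * 64 = 50304 rows.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_tokens(["<special_a>", "<special_b>"])  # grow the vocabulary
# Pad the resized embedding matrix up to the next multiple of 64 (a Tensor Core friendly shape).
embeddings = model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
assert embeddings.weight.shape[0] % 64 == 0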


@@ -1077,18 +1077,25 @@ class BarkFineModel(BarkPreTrainedModel):
# one lm_head for each codebook
self.lm_heads = new_output_embeddings
def _resize_token_embeddings(self, new_num_tokens):
def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
old_embeddings_list = self.get_input_embeddings()
new_embeddings_list = nn.ModuleList(
[self._get_resized_embeddings(old_embeddings, new_num_tokens) for old_embeddings in old_embeddings_list]
[
self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
for old_embeddings in old_embeddings_list
]
)
self.set_input_embeddings(new_embeddings_list)
new_num_tokens = [embed.weight.shape[0] for embed in new_embeddings_list]
# if word embeddings are not tied, make sure that lm head is resized as well
if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
old_lm_head_list = self.get_output_embeddings()
new_lm_head_list = nn.ModuleList(
[self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
[
self._get_resized_lm_head(old_lm_head, new_num_token)
for old_lm_head, new_num_token in zip(old_lm_head_list, new_num_tokens)
]
)
self.set_output_embeddings(new_lm_head_list)
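A hedged toy sketch of the per-codebook pairing above: BarkFineModel keeps one embedding and one lm_head per codebook, so each head is rebuilt with the row count of its own resized embedding (the sizes below are made up, not Bark's real configuration).

import torch.nn as nn

new_embeddings_list = nn.ModuleList([nn.Embedding(10048, 32), nn.Embedding(10112, 32)])
new_num_tokens = [embed.weight.shape[0] for embed in new_embeddings_list]
old_lm_head_list = nn.ModuleList([nn.Linear(32, 10000, bias=False) for _ in range(2)])
new_lm_head_list = nn.ModuleList(
    [nn.Linear(32, new_num_token, bias=False) for _, new_num_token in zip(old_lm_head_list, new_num_tokens)]
)
assert [head.out_features for head in new_lm_head_list] == [10048, 10112]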


@@ -1324,9 +1324,9 @@ class BartForConditionalGeneration(BartPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
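Hedged illustration of why the call above now passes new_embeddings.weight.shape[0] instead of new_num_tokens (the numbers are hypothetical): with pad_to_multiple_of set, the returned embedding matrix can be larger than the requested size, and final_logits_bias has to match the lm_head output dimension.

# Hypothetical sizes, for illustration only.
new_num_tokens, pad_to_multiple_of = 50265, 64
padded = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
print(padded)  # 50304 -> the bias needs 50304 entries, not 50265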


@@ -2508,9 +2508,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1277,9 +1277,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1244,9 +1244,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1313,9 +1313,9 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
return self._shift_right(labels)
# Copied from transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration.resize_token_embeddings with MBart->GPTSanJapanese
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
# Copied from transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration._resize_final_logits_bias with MBart->GPTSanJapanese


@@ -2352,9 +2352,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
def get_decoder(self):
return self.led.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1267,10 +1267,6 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.lm_head


@@ -1316,17 +1316,18 @@ class MarianMTModel(MarianPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if self.config.share_encoder_decoder_embeddings:
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings
def _resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
def _resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of=None) -> nn.Embedding:
old_embeddings = self.get_input_embeddings()
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
self.set_input_embeddings(new_embeddings)
new_num_tokens = new_embeddings.weight.shape[0]
# update config.decoder_vocab_size if embeddings are tied
if self.config.share_encoder_decoder_embeddings:
self.config.decoder_vocab_size = new_num_tokens


@@ -1294,9 +1294,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1453,8 +1453,8 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings


@@ -1652,10 +1652,6 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.lm_head


@@ -1327,9 +1327,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1552,10 +1552,6 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.lm_head


@@ -1267,9 +1267,9 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
self._resize_final_logits_bias(new_num_tokens)
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings
def _resize_final_logits_bias(self, new_num_tokens: int) -> None:


@@ -1282,10 +1282,6 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.lm_head


@@ -2359,10 +2359,6 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel):
"""
self.get_encoder().prenet.freeze_feature_encoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.text_decoder_postnet.get_output_embeddings()


@@ -1410,10 +1410,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def get_decoder(self):
return self.model.get_decoder()
def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
new_embeddings = super().resize_token_embeddings(new_num_tokens)
return new_embeddings
def get_output_embeddings(self):
return self.proj_out


@@ -1413,6 +1413,26 @@ class ModelTesterMixin:
self.assertTrue(models_equal)
config = copy.deepcopy(original_config)
model = model_class(config)
model.to(torch_device)
model_vocab_size = config.vocab_size
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
self.assertEqual(model_embed.weight.shape[0] % 64, 0)
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
self.assertEqual(model_embed.weight.shape[0] % 64, 0)
with self.assertRaisesRegex(
ValueError,
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not an integer. Please make sure to pass an integer",
):
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
def test_resize_embeddings_untied(self):
(
original_config,