diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py
index f07ac7f3348..eb23aa08bff 100644
--- a/examples/modular-transformers/modeling_new_task_model.py
+++ b/examples/modular-transformers/modeling_new_task_model.py
@@ -455,8 +455,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         self,
         new_num_tokens: Optional[int] = None,
         pad_to_multiple_of=None,
+        mean_resizing=True
     ) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
 
         # Update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings
diff --git a/examples/modular-transformers/modular_new_task_model.py b/examples/modular-transformers/modular_new_task_model.py
index 877fba00a50..a67cf2752fb 100644
--- a/examples/modular-transformers/modular_new_task_model.py
+++ b/examples/modular-transformers/modular_new_task_model.py
@@ -73,8 +73,9 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
         self,
         new_num_tokens: Optional[int] = None,
         pad_to_multiple_of=None,
+        mean_resizing=True
     ) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
 
         # Update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 36a278263b5..56f8ce4d100 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -1189,11 +1189,11 @@ class BarkFineModel(BarkPreTrainedModel):
         # one lm_head for each codebook
         self.lm_heads = new_output_embeddings
 
-    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
         old_embeddings_list = self.get_input_embeddings()
         new_embeddings_list = nn.ModuleList(
             [
-                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
                 for old_embeddings in old_embeddings_list
             ]
         )
@@ -1211,7 +1211,10 @@ class BarkFineModel(BarkPreTrainedModel):
         return self.get_input_embeddings()
 
     def resize_token_embeddings(
-        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -1230,11 +1233,19 @@ class BarkFineModel(BarkPreTrainedModel):
                 `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                 details about this, or help on choosing the correct value for resizing, refer to this guide:
                 https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
 
         Return:
             `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
         """
-        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
 
         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 4e1f0b389d4..e64ab3b2d04 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1577,8 +1577,10 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index fd52e4b8bb7..61634901289 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2457,8 +2457,10 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index ace9470d01e..16bea0a09f4 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1232,8 +1232,10 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMi
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 8564fbf3115..dec50328b76 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -1184,8 +1184,10 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, Ge
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py
index c7a195dbea0..7274f5c02c3 100644
--- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py
+++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py
@@ -1295,8 +1295,10 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index b5e8b43fbaa..e72ed197645 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -2319,8 +2319,10 @@ class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.led.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py
index b97d78d1505..36dae0ee1d7 100644
--- a/src/transformers/models/lxmert/modeling_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_lxmert.py
@@ -1072,9 +1072,11 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
         }
         self.visual_losses = visual_losses
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
         # Adding the following steps to resize bias to match the shape of resized embeddings
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
         return new_embeddings
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index c9865256636..b64970e8063 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -1252,8 +1252,10 @@ class MarianMTModel(MarianPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         if self.config.share_encoder_decoder_embeddings:
             self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index b2a0b38107e..8b42755ce35 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -1546,8 +1546,10 @@ class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py
index 7348493f218..ea1d12af0c8 100644
--- a/src/transformers/models/mvp/modeling_mvp.py
+++ b/src/transformers/models/mvp/modeling_mvp.py
@@ -1370,8 +1370,10 @@ class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings
diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
index 15d7c1f0591..31fcf1e44f1 100644
--- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
+++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -1658,9 +1658,11 @@ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_backbone.model.set_input_embeddings(value)
 
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
+    ) -> nn.Embedding:
         model_embeds = self.language_backbone.model.resize_token_embeddings(
-            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
+            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
         )
         self.config.text_config.vocab_size = model_embeds.num_embeddings
         self.vocab_size = model_embeds.num_embeddings
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 3b1cd70404c..fb560452a94 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -1265,8 +1265,10 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
    def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 9d387207a90..e2f11d97b8f 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1274,8 +1274,10 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
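
For reference, a minimal usage sketch of the new argument (not part of the diff). It assumes a `transformers` version that includes this change and `torch` installed; the BART checkpoint name is only illustrative, and the other classes touched above behave the same way.

```python
# Illustrative sketch: exercising the `mean_resizing` argument added in this diff.
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Add new tokens to the tokenizer, then grow the embedding matrix to match.
tokenizer.add_tokens(["<new_token_1>", "<new_token_2>"])

# Default (mean_resizing=True): new rows are sampled from a multivariate normal
# fitted to the mean and covariance of the existing embedding matrix.
model.resize_token_embeddings(len(tokenizer))

# Opting out: new rows are instead drawn from a zero-mean normal with
# std equal to config.initializer_range.
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)
```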