Add mean_resizing to every VLM's resize_token_embeddings() (#35717)

* refine all resize_token_embeddings()

* ruff format

* hotfix
Gar 2025-02-03 22:03:49 +08:00 committed by GitHub
parent 7eecdf2a86
commit 9d2056f12b
16 changed files with 71 additions and 32 deletions
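
The commit threads a new mean_resizing argument through each resize_token_embeddings override below, so callers can choose how newly added embedding rows are initialized. A typical call after extending a tokenizer might look like this (the checkpoint name is purely illustrative):

# Illustrative usage; "facebook/bart-large" stands in for any affected checkpoint.
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
tokenizer.add_tokens(["<new_token>"])
# mean_resizing=True (the default) initializes the new embedding rows from the
# mean and covariance of the existing rows; mean_resizing=False falls back to a
# normal init with std = config.initializer_range.
model.resize_token_embeddings(len(tokenizer), mean_resizing=True)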


@@ -455,8 +455,9 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
         self,
         new_num_tokens: Optional[int] = None,
         pad_to_multiple_of=None,
+        mean_resizing=True,
     ) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         # Update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings


@@ -73,8 +73,9 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
         self,
         new_num_tokens: Optional[int] = None,
         pad_to_multiple_of=None,
+        mean_resizing=True,
     ) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         # Update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings


@@ -1189,11 +1189,11 @@ class BarkFineModel(BarkPreTrainedModel):
         # one lm_head for each codebook
         self.lm_heads = new_output_embeddings

-    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
         old_embeddings_list = self.get_input_embeddings()
         new_embeddings_list = nn.ModuleList(
             [
-                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
                 for old_embeddings in old_embeddings_list
             ]
         )
@@ -1211,7 +1211,10 @@ class BarkFineModel(BarkPreTrainedModel):
         return self.get_input_embeddings()

     def resize_token_embeddings(
-        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
     ) -> nn.Embedding:
         """
         Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -1230,11 +1233,19 @@ class BarkFineModel(BarkPreTrainedModel):
                 `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                 details about this, or help on choosing the correct value for resizing, refer to this guide:
                 https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

         Return:
             `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
         """
-        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

         if new_num_tokens is None and pad_to_multiple_of is None:
             return model_embeds
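
The docstring above captures the rationale; conceptually, mean resizing estimates the mean and covariance of the existing embedding matrix and samples the added rows from that multivariate normal. A minimal sketch of the idea (mean_resize and the small ridge term are illustrative, not the library's exact implementation):

import torch

def mean_resize(old_weight: torch.Tensor, new_num_tokens: int) -> torch.Tensor:
    # Estimate mean/covariance of the existing rows and sample the added rows
    # from the resulting multivariate normal (sketch only).
    old_num_tokens, dim = old_weight.shape
    mean = old_weight.to(torch.float64).mean(dim=0)
    centered = old_weight.to(torch.float64) - mean
    # A small ridge keeps the empirical covariance positive-definite.
    cov = centered.T @ centered / old_num_tokens + 1e-5 * torch.eye(dim, dtype=torch.float64)
    dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
    new_rows = dist.sample((new_num_tokens - old_num_tokens,))
    return torch.cat([old_weight, new_rows.to(old_weight.dtype)], dim=0)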


@@ -1577,8 +1577,10 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
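
This override pattern repeats in the conditional-generation models below: each keeps a final_logits_bias buffer whose length must track the vocabulary size, so it is resized alongside the embeddings. For context, Bart's bias-resizing helper looks essentially like this (paraphrased from the existing source, not part of this diff):

    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
        old_num_tokens = self.final_logits_bias.shape[-1]
        if new_num_tokens <= old_num_tokens:
            # Shrinking: truncate the bias buffer.
            new_bias = self.final_logits_bias[:, :new_num_tokens]
        else:
            # Growing: pad the bias buffer with zeros for the added tokens.
            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
        self.register_buffer("final_logits_bias", new_bias)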


@@ -2457,8 +2457,10 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -1232,8 +1232,10 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMi
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -1184,8 +1184,10 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, Ge
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -1295,8 +1295,10 @@ class GPTSanJapaneseForConditionalGeneration(GPTSanJapanesePreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return self._shift_right(labels)

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -2319,8 +2319,10 @@ class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.led.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -1072,9 +1072,11 @@ class LxmertForPreTraining(LxmertPreTrainedModel):
         }
         self.visual_losses = visual_losses

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
         # Adding the following steps to resize bias to match the shape of resized embeddings
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
         return new_embeddings


@@ -1252,8 +1252,10 @@ class MarianMTModel(MarianPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         if self.config.share_encoder_decoder_embeddings:
             self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings


@@ -1546,8 +1546,10 @@ class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -1370,8 +1370,10 @@ class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_num_tokens)
         return new_embeddings


@@ -1658,9 +1658,11 @@ class OmDetTurboForObjectDetection(OmDetTurboPreTrainedModel):
     def set_input_embeddings(self, value):
         self.language_backbone.model.set_input_embeddings(value)

-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+    def resize_token_embeddings(
+        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
+    ) -> nn.Embedding:
         model_embeds = self.language_backbone.model.resize_token_embeddings(
-            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
+            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
         )
         self.config.text_config.vocab_size = model_embeds.num_embeddings
         self.vocab_size = model_embeds.num_embeddings


@@ -1265,8 +1265,10 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings


@@ -1274,8 +1274,10 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.model.get_decoder()

-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(
+        self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+    ) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings