[VLMs] add helpers for get/set embedding (#38144)

* add helpers in VLMs

* fix tied weight key test
This commit is contained in:
Raushan Turganbay 2025-05-26 09:50:32 +02:00 committed by GitHub
parent 6e3063422c
commit cba279f46c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 110 additions and 6 deletions

View File

@ -1219,6 +1219,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -389,6 +389,12 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -1438,9 +1438,6 @@ class Emu3Model(Emu3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3TextModel._from_config(config.text_config)
if self.text_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@ -1561,6 +1558,7 @@ class Emu3Model(Emu3PreTrainedModel):
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
base_model_prefix = ""
_tied_weights_keys = ["lm_head.weight"]
_checkpoint_conversion_mapping = {
"^text_model.model": "model.text_model",
"^vqmodel": "model.vqmodel",
@ -1581,6 +1579,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
    """Delegate replacement of the input embedding table to the wrapped model."""
    self.model.set_input_embeddings(value)

def get_output_embeddings(self) -> nn.Module:
    """Return the output (LM head) projection module."""
    return self.lm_head

def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def text_model(self):

View File

@ -925,9 +925,6 @@ class Emu3Model(Emu3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.text_model = Emu3TextModel._from_config(config.text_config)
if self.text_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
self.vqmodel = Emu3VQVAE(config.vq_config)
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
@ -1048,6 +1045,7 @@ class Emu3Model(Emu3PreTrainedModel):
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
base_model_prefix = ""
_tied_weights_keys = ["lm_head.weight"]
_checkpoint_conversion_mapping = {
"^text_model.model": "model.text_model",
"^vqmodel": "model.vqmodel",
@ -1068,6 +1066,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
def set_input_embeddings(self, value):
    """Delegate replacement of the input embedding table to the wrapped model."""
    self.model.set_input_embeddings(value)

def get_output_embeddings(self) -> nn.Module:
    """Return the output (LM head) projection module."""
    return self.lm_head

def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def text_model(self):

View File

@ -1008,6 +1008,12 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -755,6 +755,12 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -868,6 +868,12 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -359,6 +359,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -567,6 +567,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -698,6 +698,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -725,6 +725,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -401,6 +401,12 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -1795,6 +1795,12 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -416,6 +416,12 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -443,6 +443,12 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):

View File

@ -320,6 +320,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
def set_output_embeddings(self, new_embeddings):
    """Replace the output (LM head) projection module with `new_embeddings`."""
    self.lm_head = new_embeddings

def set_decoder(self, decoder):
    """Swap in `decoder` as the wrapped base model."""
    self.model = decoder

def get_decoder(self):
    """Return the wrapped base model acting as the decoder."""
    return self.model
# Make modules available through the conditional class for BC
@property
def language_model(self):