From cba279f46ca7d6af738b0d67812866859ffcfda3 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 26 May 2025 09:50:32 +0200 Subject: [PATCH] [VLMs] add helpers for get/set embedding (#38144) * add helpers in VLMs * fix tied weight key test --- src/transformers/models/aria/modeling_aria.py | 6 ++++++ .../models/aya_vision/modeling_aya_vision.py | 6 ++++++ src/transformers/models/emu3/modeling_emu3.py | 16 +++++++++++++--- src/transformers/models/emu3/modular_emu3.py | 16 +++++++++++++--- .../models/gemma3/modeling_gemma3.py | 6 ++++++ .../models/got_ocr2/modeling_got_ocr2.py | 6 ++++++ .../models/internvl/modeling_internvl.py | 6 ++++++ src/transformers/models/llava/modeling_llava.py | 6 ++++++ .../models/llava_next/modeling_llava_next.py | 6 ++++++ .../modeling_llava_next_video.py | 6 ++++++ .../llava_onevision/modeling_llava_onevision.py | 6 ++++++ .../models/mistral3/modeling_mistral3.py | 6 ++++++ .../models/mllama/modeling_mllama.py | 6 ++++++ .../models/paligemma/modeling_paligemma.py | 6 ++++++ .../models/video_llava/modeling_video_llava.py | 6 ++++++ .../models/vipllava/modeling_vipllava.py | 6 ++++++ 16 files changed, 110 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index abb751ab7df..8f552cfc815 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1219,6 +1219,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index a851d4d0a0f..e074d4b1193 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -389,6 +389,12 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index c13eb25d9a6..31f01db1b5a 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1438,9 +1438,6 @@ class Emu3Model(Emu3PreTrainedModel): def __init__(self, config): super().__init__(config) self.text_model = Emu3TextModel._from_config(config.text_config) - if self.text_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys] - self.vqmodel = Emu3VQVAE(config.vq_config) self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) @@ -1561,6 +1558,7 @@ class Emu3Model(Emu3PreTrainedModel): class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] _checkpoint_conversion_mapping = { "^text_model.model": "model.text_model", "^vqmodel": "model.vqmodel", @@ -1581,6 +1579,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def text_model(self): diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 9a141d61dba..8c86f81d523 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -925,9 +925,6 @@ class Emu3Model(Emu3PreTrainedModel): def __init__(self, config): super().__init__(config) self.text_model = Emu3TextModel._from_config(config.text_config) - if self.text_model._tied_weights_keys is not None: - self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys] - self.vqmodel = Emu3VQVAE(config.vq_config) self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map) @@ -1048,6 +1045,7 @@ class Emu3Model(Emu3PreTrainedModel): class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): base_model_prefix = "" + _tied_weights_keys = ["lm_head.weight"] _checkpoint_conversion_mapping = { "^text_model.model": "model.text_model", "^vqmodel": "model.vqmodel", @@ -1068,6 +1066,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin): def set_input_embeddings(self, value): self.model.set_input_embeddings(value) + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def text_model(self): diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 4a9fbfcd319..92fbeff3303 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -1008,6 +1008,12 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 6da4405fad5..0d6b44214ba 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -755,6 +755,12 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 6f06b32c168..c73f84d9222 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -868,6 +868,12 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 448879ec06f..1fcb00e6e58 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -359,6 +359,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 496049e3123..fb56a9b1748 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -567,6 +567,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3cb81ada8ac..c38ce78bc99 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -698,6 +698,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 1a60c092ed9..9d40b2cef4d 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -725,6 +725,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 625e1c3185e..4da2570090b 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -401,6 +401,12 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index fcccd2b9ea6..97dd32b7215 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1795,6 +1795,12 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 15180b91b96..4e508ef3331 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -416,6 +416,12 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 937d44a7817..ed7a19ca664 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -443,6 +443,12 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 375278b5d85..dbe93bb6160 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -320,6 +320,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + # Make modules available throught conditional class for BC @property def language_model(self):