mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
[VLMs] add helpers for get/set embedding (#38144)
* add helpers in VLMs * fix tied weight key test
This commit is contained in:
parent
6e3063422c
commit
cba279f46c
@ -1219,6 +1219,12 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -389,6 +389,12 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -1438,9 +1438,6 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.text_model = Emu3TextModel._from_config(config.text_config)
|
||||
if self.text_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
|
||||
|
||||
self.vqmodel = Emu3VQVAE(config.vq_config)
|
||||
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
|
||||
|
||||
@ -1561,6 +1558,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
|
||||
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = ""
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^text_model.model": "model.text_model",
|
||||
"^vqmodel": "model.vqmodel",
|
||||
@ -1581,6 +1579,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self) -> nn.Module:
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def text_model(self):
|
||||
|
@ -925,9 +925,6 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.text_model = Emu3TextModel._from_config(config.text_config)
|
||||
if self.text_model._tied_weights_keys is not None:
|
||||
self._tied_weights_keys = [f"text_model.{k}" for k in self.text_model._tied_weights_keys]
|
||||
|
||||
self.vqmodel = Emu3VQVAE(config.vq_config)
|
||||
self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)
|
||||
|
||||
@ -1048,6 +1045,7 @@ class Emu3Model(Emu3PreTrainedModel):
|
||||
|
||||
class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
base_model_prefix = ""
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_checkpoint_conversion_mapping = {
|
||||
"^text_model.model": "model.text_model",
|
||||
"^vqmodel": "model.vqmodel",
|
||||
@ -1068,6 +1066,18 @@ class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self) -> nn.Module:
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def text_model(self):
|
||||
|
@ -1008,6 +1008,12 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -755,6 +755,12 @@ class GotOcr2ForConditionalGeneration(GotOcr2PreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -868,6 +868,12 @@ class InternVLForConditionalGeneration(InternVLPreTrainedModel, GenerationMixin)
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -359,6 +359,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -567,6 +567,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -698,6 +698,12 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -725,6 +725,12 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -401,6 +401,12 @@ class Mistral3ForConditionalGeneration(Mistral3PreTrainedModel, GenerationMixin)
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -1795,6 +1795,12 @@ class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -416,6 +416,12 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -443,6 +443,12 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
@ -320,6 +320,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin)
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
# Make modules available throught conditional class for BC
|
||||
@property
|
||||
def language_model(self):
|
||||
|
Loading…
Reference in New Issue
Block a user